diff --git a/pkg/audit/__init__.py b/pkg/audit/__init__.py index e69de29..c1a8353 100644 --- a/pkg/audit/__init__.py +++ b/pkg/audit/__init__.py @@ -0,0 +1,3 @@ +""" +审计相关操作 +""" \ No newline at end of file diff --git a/pkg/audit/gatherer.py b/pkg/audit/gatherer.py index c18b187..6237b8c 100644 --- a/pkg/audit/gatherer.py +++ b/pkg/audit/gatherer.py @@ -1,3 +1,7 @@ +""" +使用量统计以及数据上报功能实现 +""" + import hashlib import json import logging @@ -10,8 +14,11 @@ import pkg.utils.updater class DataGatherer: """数据收集器""" + usage = {} - """以key值md5为key,{ + """各api-key的使用量 + + 以key值md5为key,{ "text": { "text-davinci-003": 文字量:int, }, @@ -25,11 +32,16 @@ class DataGatherer: def __init__(self): self.load_from_db() try: - self.version_str = pkg.utils.updater.get_current_tag() + self.version_str = pkg.utils.updater.get_current_tag() # 从updater模块获取版本号 except: pass def report_to_server(self, subservice_name: str, count: int): + """向中央服务器报告使用量 + + 只会报告此次请求的使用量,不会报告总量。 + 不包含除版本号、使用类型、使用量以外的任何信息,仅供开发者分析使用情况。 + """ try: config = pkg.utils.context.get_config() if hasattr(config, "report_usage") and not config.report_usage: @@ -44,7 +56,9 @@ class DataGatherer: return self.usage[key_md5] if key_md5 in self.usage else {} def report_text_model_usage(self, model, total_tokens): - key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5() + """调用方报告文字模型请求文字使用量""" + + key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5() # 以key的md5进行储存 if key_md5 not in self.usage: self.usage[key_md5] = {} @@ -62,6 +76,8 @@ class DataGatherer: self.report_to_server("text", length) def report_image_model_usage(self, size): + """调用方报告图片模型请求图片使用量""" + key_md5 = pkg.utils.context.get_openai_manager().key_mgr.get_using_key_md5() if key_md5 not in self.usage: @@ -79,6 +95,7 @@ class DataGatherer: self.report_to_server("image", 1) def get_text_length_of_key(self, key): + """获取指定api-key (明文) 的文字总使用量(本地记录)""" key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest() if key_md5 not in self.usage: return 0 @@ -88,6 +105,8 @@ class DataGatherer: return sum(self.usage[key_md5]["text"].values()) def get_image_count_of_key(self, key): + """获取指定api-key (明文) 的图片总使用量(本地记录)""" + key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest() if key_md5 not in self.usage: return 0 @@ -97,6 +116,7 @@ class DataGatherer: return sum(self.usage[key_md5]["image"].values()) def get_total_text_length(self): + """获取所有api-key的文字总使用量(本地记录)""" total = 0 for key in self.usage: if "text" not in self.usage[key]: diff --git a/pkg/database/__init__.py b/pkg/database/__init__.py index e69de29..c40dc21 100644 --- a/pkg/database/__init__.py +++ b/pkg/database/__init__.py @@ -0,0 +1,3 @@ +""" +数据库操作封装 +""" \ No newline at end of file diff --git a/pkg/database/manager.py b/pkg/database/manager.py index 519893b..5fde3c2 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -1,3 +1,6 @@ +""" +数据库管理模块 +""" import hashlib import json import logging @@ -9,9 +12,9 @@ import sqlite3 import pkg.utils.context -# 数据库管理 -# 为其他模块提供数据库操作接口 class DatabaseManager: + """封装数据库底层操作,并提供方法给上层使用""" + conn = None cursor = None @@ -23,13 +26,14 @@ class DatabaseManager: # 连接到数据库文件 def reconnect(self): + """连接到数据库""" self.conn = sqlite3.connect('database.db', check_same_thread=False) self.cursor = self.conn.cursor() def close(self): self.conn.close() - def execute(self, *args, **kwargs) -> Cursor: + def __execute__(self, *args, **kwargs) -> Cursor: # logging.debug('SQL: {}'.format(sql)) c = self.cursor.execute(*args, **kwargs) self.conn.commit() @@ -37,7 +41,9 @@ class DatabaseManager: # 初始化数据库的函数 def initialize_database(self): - self.execute(""" + """创建数据表""" + + self.__execute__(""" create table if not exists `sessions` ( `id` INTEGER PRIMARY KEY AUTOINCREMENT, `name` varchar(255) not null, @@ -50,7 +56,7 @@ class DatabaseManager: ) """) - self.execute(""" + self.__execute__(""" create table if not exists `account_fee`( `id` INTEGER PRIMARY KEY AUTOINCREMENT, `key_md5` varchar(255) not null, @@ -59,7 +65,7 @@ class DatabaseManager: ) """) - self.execute(""" + self.__execute__(""" create table if not exists `account_usage`( `id` INTEGER PRIMARY KEY AUTOINCREMENT, `json` text not null @@ -70,10 +76,12 @@ class DatabaseManager: # session持久化 def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int, last_interact_timestamp: int, prompt: str): + """持久化指定session""" + # 检查是否已经有了此name和create_timestamp的session # 如果有,就更新prompt和last_interact_timestamp # 如果没有,就插入一条新的记录 - self.execute(""" + self.__execute__(""" select count(*) from `sessions` where `type` = '{}' and `number` = {} and `create_timestamp` = {} """.format(subject_type, subject_number, create_timestamp)) count = self.cursor.fetchone()[0] @@ -84,8 +92,8 @@ class DatabaseManager: values (?, ?, ?, ?, ?, ?) """ - self.execute(sql, - ("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp, + self.__execute__(sql, + ("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp, last_interact_timestamp, prompt)) else: sql = """ @@ -93,23 +101,23 @@ class DatabaseManager: where `type` = ? and `number` = ? and `create_timestamp` = ? """ - self.execute(sql, (last_interact_timestamp, prompt, subject_type, - subject_number, create_timestamp)) + self.__execute__(sql, (last_interact_timestamp, prompt, subject_type, + subject_number, create_timestamp)) # 显式关闭一个session def explicit_close_session(self, session_name: str, create_timestamp: int): - self.execute(""" + self.__execute__(""" update `sessions` set `status` = 'explicitly_closed' where `name` = '{}' and `create_timestamp` = {} """.format(session_name, create_timestamp)) def set_session_ongoing(self, session_name: str, create_timestamp: int): - self.execute(""" + self.__execute__(""" update `sessions` set `status` = 'on_going' where `name` = '{}' and `create_timestamp` = {} """.format(session_name, create_timestamp)) # 设置session为过期 def set_session_expired(self, session_name: str, create_timestamp: int): - self.execute(""" + self.__execute__(""" update `sessions` set `status` = 'expired' where `name` = '{}' and `create_timestamp` = {} """.format(session_name, create_timestamp)) @@ -117,7 +125,7 @@ class DatabaseManager: def load_valid_sessions(self) -> dict: # 从数据库中加载所有还没过期的session config = pkg.utils.context.get_config() - self.execute(""" + self.__execute__(""" select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` from `sessions` where `last_interact_timestamp` > {} """.format(int(time.time()) - config.session_expire_time)) @@ -150,7 +158,7 @@ class DatabaseManager: # 获取此session_name前一个session的数据 def last_session(self, session_name: str, cursor_timestamp: int): - self.execute(""" + self.__execute__(""" select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc limit 1 @@ -179,7 +187,7 @@ class DatabaseManager: # 获取此session_name后一个session的数据 def next_session(self, session_name: str, cursor_timestamp: int): - self.execute(""" + self.__execute__(""" select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc limit 1 @@ -207,7 +215,7 @@ class DatabaseManager: # 列出与某个对象的所有对话session def list_history(self, session_name: str, capacity: int, page: int): - self.execute(""" + self.__execute__(""" select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {} """.format(session_name, capacity, capacity * page)) @@ -246,22 +254,22 @@ class DatabaseManager: usage_count = usage[key_md5] # 将使用量存进数据库 # 先检查是否已存在 - self.execute(""" + self.__execute__(""" select count(*) from `api_key_usage` where `key_md5` = '{}'""".format(key_md5)) result = self.cursor.fetchone() if result[0] == 0: # 不存在则插入 - self.execute(""" + self.__execute__(""" insert into `api_key_usage` (`key_md5`, `usage`,`timestamp`) values ('{}', {}, {}) """.format(key_md5, usage_count, int(time.time()))) else: # 存在则更新,timestamp设置为当前 - self.execute(""" + self.__execute__(""" update `api_key_usage` set `usage` = {}, `timestamp` = {} where `key_md5` = '{}' """.format(usage_count, int(time.time()), key_md5)) def load_api_key_usage(self): - self.execute(""" + self.__execute__(""" select `key_md5`, `usage` from `api_key_usage` """) results = self.cursor.fetchall() @@ -273,23 +281,24 @@ class DatabaseManager: return usage def dump_usage_json(self, usage: dict): + json_str = json.dumps(usage) - self.execute(""" + self.__execute__(""" select count(*) from `account_usage`""") result = self.cursor.fetchone() if result[0] == 0: # 不存在则插入 - self.execute(""" + self.__execute__(""" insert into `account_usage` (`json`) values ('{}') """.format(json_str)) else: # 存在则更新 - self.execute(""" + self.__execute__(""" update `account_usage` set `json` = '{}' where `id` = 1 """.format(json_str)) def load_usage_json(self): - self.execute(""" + self.__execute__(""" select `json` from `account_usage` order by id desc limit 1 """) result = self.cursor.fetchone() diff --git a/pkg/openai/__init__.py b/pkg/openai/__init__.py index e69de29..e6a669c 100644 --- a/pkg/openai/__init__.py +++ b/pkg/openai/__init__.py @@ -0,0 +1,2 @@ +"""OpenAI 接口处理及会话管理相关 +""" diff --git a/pkg/openai/dprompt.py b/pkg/openai/dprompt.py index 29d9a88..3aba31c 100644 --- a/pkg/openai/dprompt.py +++ b/pkg/openai/dprompt.py @@ -1,8 +1,13 @@ # 多情景预设值管理 __current__ = "default" +"""当前默认使用的情景预设的名称 + +由管理员使用`!default <名称>`指令切换 +""" __prompts_from_files__ = {} +"""从文件中读取的情景预设值""" def read_prompt_from_file() -> str: diff --git a/pkg/openai/keymgr.py b/pkg/openai/keymgr.py index 78162fa..7127db8 100644 --- a/pkg/openai/keymgr.py +++ b/pkg/openai/keymgr.py @@ -5,18 +5,26 @@ import logging import pkg.plugin.host as plugin_host import pkg.plugin.models as plugin_models + class KeysManager: api_key = {} + """所有api-key""" - # api-key的使用量 - # 其中键为api-key的md5值,值为使用量 using_key = "" + """当前使用的api-key + """ alerted = [] + """已提示过超额的key + + 记录在此以避免重复提示 + """ - # 在此list中的都是经超额报错标记过的api-key - # 记录的是key值,仅在运行时有效 exceeded = [] + """已超额的key + + 供自动切换功能识别 + """ def get_using_key(self): return self.using_key @@ -25,8 +33,6 @@ class KeysManager: return hashlib.md5(self.using_key.encode('utf-8')).hexdigest() def __init__(self, api_key): - # if hasattr(config, 'api_key_usage_threshold'): - # self.api_key_usage_threshold = config.api_key_usage_threshold if type(api_key) is dict: self.api_key = api_key @@ -42,9 +48,13 @@ class KeysManager: self.auto_switch() - # 根据tested自动切换到可用的api-key - # 返回是否切换成功, 切换后的api-key的别名 def auto_switch(self) -> (bool, str): + """尝试切换api-key + + Returns: + 是否切换成功, 切换后的api-key的别名 + """ + for key_name in self.api_key: if self.api_key[key_name] not in self.exceeded: self.using_key = self.api_key[key_name] @@ -68,12 +78,9 @@ class KeysManager: def add(self, key_name, key): self.api_key[key_name] = key - # 设置当前使用的api-key使用量超限 - # 这是在尝试调用api时发生超限异常时调用的 def set_current_exceeded(self): - # md5 = hashlib.md5(self.using_key.encode('utf-8')).hexdigest() - # self.usage[md5] = self.api_key_usage_threshold - # self.fee[md5] = self.api_key_fee_threshold + """设置当前使用的api-key使用量超限 + """ self.exceeded.append(self.using_key) def get_key_name(self, api_key): diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index e5cef33..4a3ceab 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -7,9 +7,12 @@ import pkg.utils.context import pkg.audit.gatherer from pkg.openai.modelmgr import ModelRequest, create_openai_model_request -# 为其他模块提供与OpenAI交互的接口 + class OpenAIInteract: - api_params = {} + """OpenAI 接口封装 + + 将文字接口和图片接口封装供调用方使用 + """ key_mgr: pkg.openai.keymgr.KeysManager = None @@ -20,7 +23,6 @@ class OpenAIInteract: } def __init__(self, api_key: str): - # self.api_key = api_key self.key_mgr = pkg.openai.keymgr.KeysManager(api_key) self.audit_mgr = pkg.audit.gatherer.DataGatherer() @@ -32,7 +34,16 @@ class OpenAIInteract: pkg.utils.context.set_openai_manager(self) # 请求OpenAI Completion - def request_completion(self, prompts): + def request_completion(self, prompts) -> str: + """请求补全接口回复 + + Parameters: + prompts (str): 提示语 + + Returns: + str: 回复 + """ + config = pkg.utils.context.get_config() # 根据模型选择使用的接口 @@ -58,8 +69,15 @@ class OpenAIInteract: return ai.get_message() - def request_image(self, prompt): + def request_image(self, prompt) -> dict: + """请求图片接口回复 + Parameters: + prompt (str): 提示语 + + Returns: + dict: 响应 + """ config = pkg.utils.context.get_config() params = config.image_api_params if hasattr(config, "image_api_params") else self.default_image_api_params diff --git a/pkg/openai/modelmgr.py b/pkg/openai/modelmgr.py index fe41cbe..e67f98c 100644 --- a/pkg/openai/modelmgr.py +++ b/pkg/openai/modelmgr.py @@ -1,4 +1,10 @@ -# 提供与模型交互的抽象接口 +"""OpenAI 接口底层封装 + +目前使用的对话接口有: +ChatCompletion - gpt-3.5-turbo 等模型 +Completion - text-davinci-003 等模型 +此模块封装此两个接口的请求实现,为上层提供统一的调用方式 +""" import openai, logging, threading, asyncio import openai.error as aiE @@ -26,14 +32,15 @@ IMAGE_MODELS = { } -class ModelRequest(): - """GPT父类""" +class ModelRequest: + """模型接口请求父类""" + can_chat = False - runtime:threading.Thread = None + runtime: threading.Thread = None ret = {} - proxy:str = None + proxy: str = None request_ready = True - error_info:str = "若在没有任何错误的情况下看到这句话,请带着配置文件上报Issues" + error_info: str = "若在没有任何错误的情况下看到这句话,请带着配置文件上报Issues" def __init__(self, model_name, user_name, request_fun, http_proxy:str = None, time_out = None): self.model_name = model_name @@ -46,6 +53,8 @@ class ModelRequest(): self.request_ready = False async def __a_request__(self, **kwargs): + """异步请求""" + try: self.ret:dict = await self.request_fun(**kwargs) self.request_ready = True @@ -59,6 +68,8 @@ class ModelRequest(): raise Exception(self.error_info) def request(self, **kwargs): + """向接口发起请求""" + if self.proxy != None: #异步请求 self.request_ready = False loop = asyncio.new_event_loop() @@ -97,8 +108,10 @@ class ModelRequest(): def get_response(self): return self.ret + class ChatCompletionModel(ModelRequest): - """ChatCompletion类模型""" + """ChatCompletion接口的请求实现""" + Chat_role = ['system', 'user', 'assistant'] def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs): if http_proxy == None: @@ -126,7 +139,8 @@ class ChatCompletionModel(ModelRequest): class CompletionModel(ModelRequest): - """Completion类模型""" + """Completion接口的请求实现""" + def __init__(self, model_name, user_name, http_proxy:str = None, **kwargs): if http_proxy == None: request_fun = openai.Completion.create diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 6932a85..38a629d 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -1,3 +1,8 @@ +"""主线使用的会话管理模块 + +每个人、每个群单独一个session,session内部保留了对话的上下文, +""" + import logging import threading import time @@ -19,6 +24,7 @@ class SessionOfflineStatus: ON_GOING = 'on_going' EXPLICITLY_CLOSED = 'explicitly_closed' + # 重置session.prompt def reset_session_prompt(session_name, prompt): # 备份原始数据 @@ -43,11 +49,14 @@ def reset_session_prompt(session_name, prompt): 用户[{}]的数据已被重置,有可能是因为数据版本过旧或存储错误 原始数据将备份在: {}""".format(session_name, bak_path) - ) + ) # 为保证多行文本格式正确故无缩进 return prompt + # 从数据加载session def load_sessions(): + """从数据库加载sessions""" + global sessions db_inst = pkg.utils.context.get_database_manager() @@ -93,10 +102,13 @@ class Session: name = '' prompt = [] + """使用list来保存会话中的回合""" create_timestamp = 0 + """会话创建时间""" last_interact_timestamp = 0 + """上次交互(产生回复)时间""" just_switched_to_exist_session = False @@ -116,7 +128,7 @@ class Session: logging.debug('{},lock release successfully,{}'.format(self.name, self.response_lock)) # 从配置文件获取会话预设信息 - def get_default_prompt(self, use_default: str=None): + def get_default_prompt(self, use_default: str = None): config = pkg.utils.context.get_config() import pkg.openai.dprompt as dprompt @@ -130,7 +142,7 @@ class Session: { 'role': 'user', 'content': current_default_prompt - },{ + }, { 'role': 'assistant', 'content': 'ok' } @@ -182,6 +194,8 @@ class Session: # 请求回复 # 这个函数是阻塞的 def append(self, text: str) -> str: + """向session中添加一条消息,返回接口回复""" + self.last_interact_timestamp = int(time.time()) # 触发插件事件 @@ -215,14 +229,14 @@ class Session: res_ans = '\n\n'.join(res_ans_spt) # 将此次对话的双方内容加入到prompt中 - self.prompt.append({'role':'user', 'content':text}) - self.prompt.append({'role':'assistant', 'content':res_ans}) + self.prompt.append({'role': 'user', 'content': text}) + self.prompt.append({'role': 'assistant', 'content': res_ans}) if self.just_switched_to_exist_session: self.just_switched_to_exist_session = False self.set_ongoing() - return res_ans if res_ans[0]!='\n' else res_ans[1:] + return res_ans if res_ans[0] != '\n' else res_ans[1:] # 删除上一回合并返回上一回合的问题 def undo(self) -> str: @@ -231,10 +245,10 @@ class Session: # 删除最后两个消息 if len(self.prompt) < 2: raise Exception('之前无对话,无法撤销') - + question = self.prompt[-2]['content'] self.prompt = self.prompt[:-2] - + # 返回上一回合的问题 return question @@ -242,13 +256,13 @@ class Session: def cut_out(self, msg: str, max_tokens: int) -> list: """将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens""" # 如果用户消息长度超过max_tokens,直接返回 - + temp_prompt = [ - { - 'role': 'user', - 'content': msg - } - ] + { + 'role': 'user', + 'content': msg + } + ] token_count = len(msg) # 倒序遍历prompt diff --git a/pkg/plugin/__init__.py b/pkg/plugin/__init__.py index e69de29..c543161 100644 --- a/pkg/plugin/__init__.py +++ b/pkg/plugin/__init__.py @@ -0,0 +1,4 @@ +"""插件支持包 + +包含插件基类、插件宿主以及部分API接口 +""" \ No newline at end of file diff --git a/pkg/plugin/host.py b/pkg/plugin/host.py index 962670a..a8163f1 100644 --- a/pkg/plugin/host.py +++ b/pkg/plugin/host.py @@ -116,7 +116,9 @@ def initialize_plugins(): def unload_plugins(): - """ 卸载插件 """ + """ 卸载插件 + """ + # 不再显式卸载插件,因为当程序结束时,插件的析构函数会被系统执行 # for plugin in __plugins__.values(): # if plugin['enabled'] and plugin['instance'] is not None: # if not hasattr(plugin['instance'], '__del__'): diff --git a/pkg/plugin/models.py b/pkg/plugin/models.py index c190678..180f074 100644 --- a/pkg/plugin/models.py +++ b/pkg/plugin/models.py @@ -145,6 +145,7 @@ __current_registering_plugin__ = "" class Plugin: + """插件基类""" host: host.PluginHost """插件宿主,提供插件的一些基础功能"""