From 77ec1c7ff0df4d467995d949fe7988c9885d8d3a Mon Sep 17 00:00:00 2001 From: Rock Chin Date: Thu, 8 Dec 2022 12:06:04 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=8C=E6=88=90=E5=9F=BA=E6=9C=AC?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config-template.py | 1 + main.py | 17 +++++++++++++++++ pkg/database/manager.py | 6 ++++-- pkg/openai/session.py | 13 +++++++++++++ pkg/qqbot/manager.py | 32 ++++++++++++++++++-------------- 5 files changed, 53 insertions(+), 16 deletions(-) diff --git a/config-template.py b/config-template.py index 7740158..533202f 100644 --- a/config-template.py +++ b/config-template.py @@ -1,5 +1,6 @@ mirai_http_api_config = { "host": "", + "port": 8080, "verifyKey": "", "qq": 0 } diff --git a/main.py b/main.py index 18c002f..0b5acf6 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,13 @@ import os import shutil import sys +import threading +import time import pkg.openai.manager import pkg.database.manager import pkg.openai.session +import pkg.qqbot.manager def init_db(): @@ -32,9 +35,23 @@ def main(): # 加载所有未超时的session pkg.openai.session.load_sessions() + # 初始化qq机器人 + qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config, + timeout=config.process_message_timeout, retry=config.retry_times) + + qq_bot_thread = threading.Thread(target=qqbot.bot.run, args=(), daemon=True) + qq_bot_thread.start() + if __name__ == '__main__': if len(sys.argv) > 1 and sys.argv[1] == 'init_db': init_db() sys.exit(0) main() + + while True: + try: + time.sleep(86400) + except KeyboardInterrupt: + print("程序退出") + break diff --git a/pkg/database/manager.py b/pkg/database/manager.py index d5716f0..be805f0 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -1,6 +1,7 @@ import time import pymysql +from pymysql.converters import escape_string import config @@ -61,12 +62,13 @@ class DatabaseManager: insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`) values ('{}', '{}', {}, {}, {}, '{}') """.format("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp, - last_interact_timestamp, prompt)) + last_interact_timestamp, escape_string(prompt))) else: self.cursor.execute(""" update `sessions` set `last_interact_timestamp` = {}, `prompt` = '{}' where `type` = '{}' and `number` = {} and `create_timestamp` = {} - """.format(last_interact_timestamp, prompt, subject_type, subject_number, create_timestamp)) + """.format(last_interact_timestamp, escape_string(prompt), subject_type, + subject_number, create_timestamp)) # 记载还没过期的session数据 def load_valid_sessions(self) -> dict: diff --git a/pkg/openai/session.py b/pkg/openai/session.py index a3e9bcd..aa0f100 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -88,3 +88,16 @@ class Session: db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp, self.prompt) + + def reset(self): + if self.prompt != '': + self.persistence() + self.prompt = '' + self.create_timestamp = int(time.time()) + self.last_interact_timestamp = 0 + + def last_session(self): + pass + + def next_session(self): + pass diff --git a/pkg/qqbot/manager.py b/pkg/qqbot/manager.py index 2074d01..aa9a4d3 100644 --- a/pkg/qqbot/manager.py +++ b/pkg/qqbot/manager.py @@ -2,7 +2,7 @@ from mirai import At, GroupMessage, MessageEvent, Mirai, Plain, StrangerMessage, import pkg.openai.session from func_timeout import func_set_timeout, FunctionTimedOut -help_text = """ +help_text = """帮助信息: !help - 显示帮助 !reset - 重置会话 !last - 切换到上一次的对话 @@ -13,6 +13,7 @@ inst = None processing = [] + class QQBotManager: timeout = 60 retry = 3 @@ -29,7 +30,7 @@ class QQBotManager: adapter=WebSocketAdapter( verify_key=mirai_http_api_config['verifyKey'], host=mirai_http_api_config['host'], - port=8080 + port=mirai_http_api_config['port'] ) ) @@ -56,26 +57,27 @@ class QQBotManager: reply = '' session_name = "{}_{}".format(launcher_type, launcher_id) - if text_message.startswith('!'): # 指令 - cmd = text_message + if text_message.startswith('!') or text_message.startswith("!"): # 指令 + cmd = text_message[1:].strip() - if cmd == '!help': - reply = help_text - elif cmd == '!reset': + if cmd == 'help': + reply = "[bot]" + help_text + elif cmd == 'reset': + pkg.openai.session.get_session(session_name).reset() + reply = "[bot]会话已重置" + elif cmd == 'last': pass - elif cmd == '!last': - pass - elif cmd == '!next': + elif cmd == 'next': pass else: # 消息 session = pkg.openai.session.get_session(session_name) - reply = session.append(text_message) + reply = "[GPT]" + session.append(text_message) return reply async def on_person_message(self, event: MessageEvent): if "person_{}".format(event.sender.id) in processing: - return + return await self.bot.send(event, "err:正在处理中,请稍后再试") reply = '' @@ -107,7 +109,7 @@ class QQBotManager: async def on_group_message(self, event: GroupMessage): if "group_{}".format(event.group.id) in processing: - return + return await self.bot.send(event, "err:正在处理中,请稍后再试") reply = '' @@ -116,13 +118,15 @@ class QQBotManager: elif At(self.bot.qq) not in event.message_chain: pass else: + event.message_chain.remove(At(self.bot.qq)) + processing.append("group_{}".format(event.sender.id)) # 超时则重试,重试超过次数则放弃 failed = 0 for i in range(self.retry): try: - reply = self.process_message('group', event.group.id, str(event.message_chain)) + reply = self.process_message('group', event.group.id, str(event.message_chain).strip()) break except FunctionTimedOut: failed += 1