diff --git a/config-template.py b/config-template.py index dc3b250..f0613aa 100644 --- a/config-template.py +++ b/config-template.py @@ -166,6 +166,22 @@ alter_tip_message = '出错了,请稍后再试' # 默认值20分钟 session_expire_time = 60 * 20 +# 会话限速 +# 单会话内每分钟可进行的对话次数 +# 若不需要限速,可以设置为一个很大的值 +# 默认值60次,基本上不会触发限速 +rate_limitation = 60 + +# 会话限速策略 +# - "wait": 每次对话获取到回复时,等待一定时间再发送回复,保证其不会超过限速均值 +# - "drop": 此分钟内,若对话次数超过限速次数,则丢弃之后的对话,每自然分钟重置 +rate_limit_strategy = "wait" + +# drop策略时,超过限速均值时,丢弃的对话的提示信息 +# 仅当rate_limitation_strategy为"drop"时生效 +# 若设置为空字符串,则不发送提示信息 +rate_limit_drop_tip = "本分钟对话次数超过限速次数,此对话被丢弃" + # 是否上报统计信息 # 用于统计机器人的使用情况,不会收集任何用户信息 # 仅上报时间、字数使用量、绘图使用量,其他信息不会上报 diff --git a/pkg/qqbot/message.py b/pkg/qqbot/message.py index 2b2cc6c..864fda2 100644 --- a/pkg/qqbot/message.py +++ b/pkg/qqbot/message.py @@ -1,5 +1,6 @@ # 普通消息处理模块 import logging +import time import openai import pkg.utils.context import pkg.openai.session diff --git a/pkg/qqbot/process.py b/pkg/qqbot/process.py index b5fada0..00785b6 100644 --- a/pkg/qqbot/process.py +++ b/pkg/qqbot/process.py @@ -1,5 +1,6 @@ # 此模块提供了消息处理的具体逻辑的接口 import asyncio +import time import mirai import logging @@ -19,6 +20,7 @@ import pkg.utils.updater import pkg.utils.context import pkg.qqbot.message import pkg.qqbot.command +import pkg.qqbot.ratelimit as ratelimit import pkg.plugin.host as plugin_host import pkg.plugin.models as plugin_models @@ -99,6 +101,14 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes mgr, config, launcher_type, launcher_id, sender_id) else: # 消息 + # 限速丢弃检查 + # print(ratelimit.__crt_minute_usage__[session_name]) + if hasattr(config, "rate_limitation") and config.rate_limit_strategy == "drop": + if ratelimit.is_reach_limit(session_name): + logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message)) + return MessageChain(["[bot]"+config.rate_limit_drop_tip]) if hasattr(config, "rate_limit_drop_tip") and config.rate_limit_drop_tip != "" else [] + + before = time.time() # 触发插件事件 args = { "launcher_type": launcher_type, @@ -121,6 +131,13 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes reply = pkg.qqbot.message.process_normal_message(text_message, mgr, config, launcher_type, launcher_id, sender_id) + # 限速等待时间 + if hasattr(config, "rate_limitation") and config.rate_limit_strategy == "wait": + time.sleep(ratelimit.get_rest_wait_time(session_name, time.time() - before)) + + if hasattr(config, "rate_limitation"): + ratelimit.add_usage(session_name) + if reply is not None and (type(reply[0]) == str or type(reply[0]) == mirai.Plain): logging.info( "回复[{}]文字消息:{}".format(session_name, @@ -135,4 +152,4 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes finally: pkg.openai.session.get_session(session_name).release_response_lock() - return MessageChain(reply) + return MessageChain(reply) diff --git a/pkg/qqbot/ratelimit.py b/pkg/qqbot/ratelimit.py new file mode 100644 index 0000000..2a759b6 --- /dev/null +++ b/pkg/qqbot/ratelimit.py @@ -0,0 +1,86 @@ +# 限速相关模块 +import time +import logging +import threading + +__crt_minute_usage__ = {} +"""当前分钟每个会话的对话次数""" + + +__timer_thr__: threading.Thread = None + + +def add_usage(session_name: str): + """增加会话的对话次数""" + global __crt_minute_usage__ + if session_name in __crt_minute_usage__: + __crt_minute_usage__[session_name] += 1 + else: + __crt_minute_usage__[session_name] = 1 + + +def start_timer(): + """启动定时器""" + global __timer_thr__ + __timer_thr__ = threading.Thread(target=run_timer, daemon=True) + __timer_thr__.start() + + +def run_timer(): + """启动定时器,每分钟清空一次对话次数""" + global __crt_minute_usage__ + global __timer_thr__ + + # 等待直到整分钟 + time.sleep(60 - time.time() % 60) + + while True: + if __timer_thr__ != threading.current_thread(): + break + + logging.debug("清空当前分钟的对话次数") + __crt_minute_usage__ = {} + time.sleep(60) + + +def get_usage(session_name: str) -> int: + """获取会话的对话次数""" + global __crt_minute_usage__ + if session_name in __crt_minute_usage__: + return __crt_minute_usage__[session_name] + else: + return 0 + + +def get_rest_wait_time(session_name: str, spent: float) -> float: + """获取会话此回合的剩余等待时间""" + global __crt_minute_usage__ + + import config + + if not hasattr(config, 'rate_limitation'): + return 0 + + min_seconds_per_round = 60.0 / config.rate_limitation + + if session_name in __crt_minute_usage__: + return max(0, min_seconds_per_round - spent) + else: + return 0 + + +def is_reach_limit(session_name: str) -> bool: + """判断会话是否超过限制""" + global __crt_minute_usage__ + + import config + + if not hasattr(config, 'rate_limitation'): + return False + + if session_name in __crt_minute_usage__: + return __crt_minute_usage__[session_name] >= config.rate_limitation + else: + return False + +start_timer()