mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 11:42:44 +08:00
feat: 限制机器人回复的频率 (#72)
This commit is contained in:
parent
eb60c1b0a0
commit
e7c79a5156
|
@ -166,6 +166,22 @@ alter_tip_message = '出错了,请稍后再试'
|
||||||
# 默认值20分钟
|
# 默认值20分钟
|
||||||
session_expire_time = 60 * 20
|
session_expire_time = 60 * 20
|
||||||
|
|
||||||
|
# 会话限速
|
||||||
|
# 单会话内每分钟可进行的对话次数
|
||||||
|
# 若不需要限速,可以设置为一个很大的值
|
||||||
|
# 默认值60次,基本上不会触发限速
|
||||||
|
rate_limitation = 60
|
||||||
|
|
||||||
|
# 会话限速策略
|
||||||
|
# - "wait": 每次对话获取到回复时,等待一定时间再发送回复,保证其不会超过限速均值
|
||||||
|
# - "drop": 此分钟内,若对话次数超过限速次数,则丢弃之后的对话,每自然分钟重置
|
||||||
|
rate_limit_strategy = "wait"
|
||||||
|
|
||||||
|
# drop策略时,超过限速均值时,丢弃的对话的提示信息
|
||||||
|
# 仅当rate_limitation_strategy为"drop"时生效
|
||||||
|
# 若设置为空字符串,则不发送提示信息
|
||||||
|
rate_limit_drop_tip = "本分钟对话次数超过限速次数,此对话被丢弃"
|
||||||
|
|
||||||
# 是否上报统计信息
|
# 是否上报统计信息
|
||||||
# 用于统计机器人的使用情况,不会收集任何用户信息
|
# 用于统计机器人的使用情况,不会收集任何用户信息
|
||||||
# 仅上报时间、字数使用量、绘图使用量,其他信息不会上报
|
# 仅上报时间、字数使用量、绘图使用量,其他信息不会上报
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# 普通消息处理模块
|
# 普通消息处理模块
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
import openai
|
import openai
|
||||||
import pkg.utils.context
|
import pkg.utils.context
|
||||||
import pkg.openai.session
|
import pkg.openai.session
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# 此模块提供了消息处理的具体逻辑的接口
|
# 此模块提供了消息处理的具体逻辑的接口
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
import mirai
|
import mirai
|
||||||
import logging
|
import logging
|
||||||
|
@ -19,6 +20,7 @@ import pkg.utils.updater
|
||||||
import pkg.utils.context
|
import pkg.utils.context
|
||||||
import pkg.qqbot.message
|
import pkg.qqbot.message
|
||||||
import pkg.qqbot.command
|
import pkg.qqbot.command
|
||||||
|
import pkg.qqbot.ratelimit as ratelimit
|
||||||
|
|
||||||
import pkg.plugin.host as plugin_host
|
import pkg.plugin.host as plugin_host
|
||||||
import pkg.plugin.models as plugin_models
|
import pkg.plugin.models as plugin_models
|
||||||
|
@ -99,6 +101,14 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
||||||
mgr, config, launcher_type, launcher_id, sender_id)
|
mgr, config, launcher_type, launcher_id, sender_id)
|
||||||
|
|
||||||
else: # 消息
|
else: # 消息
|
||||||
|
# 限速丢弃检查
|
||||||
|
# print(ratelimit.__crt_minute_usage__[session_name])
|
||||||
|
if hasattr(config, "rate_limitation") and config.rate_limit_strategy == "drop":
|
||||||
|
if ratelimit.is_reach_limit(session_name):
|
||||||
|
logging.info("根据限速策略丢弃[{}]消息: {}".format(session_name, text_message))
|
||||||
|
return MessageChain(["[bot]"+config.rate_limit_drop_tip]) if hasattr(config, "rate_limit_drop_tip") and config.rate_limit_drop_tip != "" else []
|
||||||
|
|
||||||
|
before = time.time()
|
||||||
# 触发插件事件
|
# 触发插件事件
|
||||||
args = {
|
args = {
|
||||||
"launcher_type": launcher_type,
|
"launcher_type": launcher_type,
|
||||||
|
@ -121,6 +131,13 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
||||||
reply = pkg.qqbot.message.process_normal_message(text_message,
|
reply = pkg.qqbot.message.process_normal_message(text_message,
|
||||||
mgr, config, launcher_type, launcher_id, sender_id)
|
mgr, config, launcher_type, launcher_id, sender_id)
|
||||||
|
|
||||||
|
# 限速等待时间
|
||||||
|
if hasattr(config, "rate_limitation") and config.rate_limit_strategy == "wait":
|
||||||
|
time.sleep(ratelimit.get_rest_wait_time(session_name, time.time() - before))
|
||||||
|
|
||||||
|
if hasattr(config, "rate_limitation"):
|
||||||
|
ratelimit.add_usage(session_name)
|
||||||
|
|
||||||
if reply is not None and (type(reply[0]) == str or type(reply[0]) == mirai.Plain):
|
if reply is not None and (type(reply[0]) == str or type(reply[0]) == mirai.Plain):
|
||||||
logging.info(
|
logging.info(
|
||||||
"回复[{}]文字消息:{}".format(session_name,
|
"回复[{}]文字消息:{}".format(session_name,
|
||||||
|
@ -135,4 +152,4 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
|
||||||
finally:
|
finally:
|
||||||
pkg.openai.session.get_session(session_name).release_response_lock()
|
pkg.openai.session.get_session(session_name).release_response_lock()
|
||||||
|
|
||||||
return MessageChain(reply)
|
return MessageChain(reply)
|
||||||
|
|
86
pkg/qqbot/ratelimit.py
Normal file
86
pkg/qqbot/ratelimit.py
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
# 限速相关模块
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
|
||||||
|
__crt_minute_usage__ = {}
|
||||||
|
"""当前分钟每个会话的对话次数"""
|
||||||
|
|
||||||
|
|
||||||
|
__timer_thr__: threading.Thread = None
|
||||||
|
|
||||||
|
|
||||||
|
def add_usage(session_name: str):
|
||||||
|
"""增加会话的对话次数"""
|
||||||
|
global __crt_minute_usage__
|
||||||
|
if session_name in __crt_minute_usage__:
|
||||||
|
__crt_minute_usage__[session_name] += 1
|
||||||
|
else:
|
||||||
|
__crt_minute_usage__[session_name] = 1
|
||||||
|
|
||||||
|
|
||||||
|
def start_timer():
|
||||||
|
"""启动定时器"""
|
||||||
|
global __timer_thr__
|
||||||
|
__timer_thr__ = threading.Thread(target=run_timer, daemon=True)
|
||||||
|
__timer_thr__.start()
|
||||||
|
|
||||||
|
|
||||||
|
def run_timer():
|
||||||
|
"""启动定时器,每分钟清空一次对话次数"""
|
||||||
|
global __crt_minute_usage__
|
||||||
|
global __timer_thr__
|
||||||
|
|
||||||
|
# 等待直到整分钟
|
||||||
|
time.sleep(60 - time.time() % 60)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if __timer_thr__ != threading.current_thread():
|
||||||
|
break
|
||||||
|
|
||||||
|
logging.debug("清空当前分钟的对话次数")
|
||||||
|
__crt_minute_usage__ = {}
|
||||||
|
time.sleep(60)
|
||||||
|
|
||||||
|
|
||||||
|
def get_usage(session_name: str) -> int:
|
||||||
|
"""获取会话的对话次数"""
|
||||||
|
global __crt_minute_usage__
|
||||||
|
if session_name in __crt_minute_usage__:
|
||||||
|
return __crt_minute_usage__[session_name]
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def get_rest_wait_time(session_name: str, spent: float) -> float:
|
||||||
|
"""获取会话此回合的剩余等待时间"""
|
||||||
|
global __crt_minute_usage__
|
||||||
|
|
||||||
|
import config
|
||||||
|
|
||||||
|
if not hasattr(config, 'rate_limitation'):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
min_seconds_per_round = 60.0 / config.rate_limitation
|
||||||
|
|
||||||
|
if session_name in __crt_minute_usage__:
|
||||||
|
return max(0, min_seconds_per_round - spent)
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def is_reach_limit(session_name: str) -> bool:
|
||||||
|
"""判断会话是否超过限制"""
|
||||||
|
global __crt_minute_usage__
|
||||||
|
|
||||||
|
import config
|
||||||
|
|
||||||
|
if not hasattr(config, 'rate_limitation'):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if session_name in __crt_minute_usage__:
|
||||||
|
return __crt_minute_usage__[session_name] >= config.rate_limitation
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
start_timer()
|
Loading…
Reference in New Issue
Block a user