feat: 支持设置回复时引用原消息 #73

This commit is contained in:
Rock Chin 2023-01-01 17:20:54 +08:00
parent 88ec74c0a4
commit b085c133bf
3 changed files with 37 additions and 17 deletions

View File

@ -84,6 +84,12 @@ image_api_params = {
"size": "256x256", # 图片尺寸支持256x256, 512x512, 1024x1024
}
# 回复消息时是否引用原消息
quote_origin = True
# 回复绘图时是否包含图片描述
include_image_description = True
# 消息处理的超时时间,单位为秒
process_message_timeout = 15

View File

@ -7,6 +7,8 @@ import openai.error
from mirai import At, GroupMessage, MessageEvent, Mirai, Plain, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
FriendMessage, Image
from mirai.models.message import Quote
import config
import pkg.openai.session
import pkg.openai.manager
@ -108,8 +110,9 @@ class QQBotManager:
global inst
inst = self
def send(self, event, msg):
asyncio.run(self.bot.send(event, msg))
def send(self, event, msg, check_quote=True):
asyncio.run(
self.bot.send(event, msg, quote=True if hasattr(config, "quote_origin") and config.quote_origin and check_quote else False))
# 私聊消息处理
def on_person_message(self, event: MessageEvent):
@ -126,7 +129,9 @@ class QQBotManager:
failed = 0
for i in range(self.retry):
try:
reply = processor.process_message('person', event.sender.id, str(event.message_chain))
reply = processor.process_message('person', event.sender.id, str(event.message_chain),
event.message_chain,
event.sender.id)
break
except FunctionTimedOut:
pkg.openai.session.get_session('person_{}'.format(event.sender.id)).release_response_lock()
@ -139,7 +144,7 @@ class QQBotManager:
reply = ["[bot]err:请求超时"]
if reply:
return self.send(event, reply)
return self.send(event, reply, check_quote=False)
# 群消息处理
def on_group_message(self, event: GroupMessage):
@ -156,7 +161,9 @@ class QQBotManager:
for i in range(self.retry):
try:
replys = processor.process_message('group', event.group.id,
str(event.message_chain).strip() if text is None else text)
str(event.message_chain).strip() if text is None else text,
event.message_chain,
event.sender.id)
break
except FunctionTimedOut:
failed += 1

View File

@ -6,7 +6,8 @@ from func_timeout import func_set_timeout
import logging
import openai
from mirai import Image
from mirai import Image, MessageChain
from mirai.models.message import Quote
import config
@ -17,7 +18,8 @@ processing = []
@func_set_timeout(config.process_message_timeout)
def process_message(launcher_type: str, launcher_id: int, text_message: str) -> []:
def process_message(launcher_type: str, launcher_id: int, text_message: str, message_chain: MessageChain,
sender_id: int) -> MessageChain:
global processing
mgr = pkg.qqbot.manager.get_inst()
@ -118,13 +120,14 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str) ->
using_key_name = ""
for api_key in api_keys:
reply_str += "{}:\n - {}美元 {}%\n".format(api_key,
round(pkg.openai.manager.get_inst().key_mgr.get_fee(
api_keys[api_key]), 6),
round(
pkg.openai.manager.get_inst().key_mgr.get_fee(
api_keys[
api_key]) / pkg.openai.manager.get_inst().key_mgr.api_key_fee_threshold * 100,
3))
round(
pkg.openai.manager.get_inst().key_mgr.get_fee(
api_keys[api_key]), 6),
round(
pkg.openai.manager.get_inst().key_mgr.get_fee(
api_keys[
api_key]) / pkg.openai.manager.get_inst().key_mgr.api_key_fee_threshold * 100,
3))
if api_keys[api_key] == pkg.openai.manager.get_inst().key_mgr.using_key:
using_key_name = api_key
reply_str += "\n当前使用:{}".format(using_key_name)
@ -140,7 +143,10 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str) ->
res = session.draw_image(" ".join(params))
logging.debug("draw_image result:{}".format(res))
reply = [Image(url=res['data'][0]['url']), " ".join(params)]
reply = [Image(url=res['data'][0]['url'])]
if not (hasattr(config, 'include_image_description')
and not config.include_image_description):
reply.append(" ".join(params))
except Exception as e:
mgr.notify_admin("{}指令执行失败:{}".format(session_name, e))
logging.exception(e)
@ -186,7 +192,8 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str) ->
if reply is not None and type(reply[0]) == str:
logging.info(
"回复[{}]文字消息:{}".format(session_name,
reply[0][:min(100, len(reply[0]))] + ("..." if len(reply[0]) > 100 else "")))
reply[0][:min(100, len(reply[0]))] + (
"..." if len(reply[0]) > 100 else "")))
reply = [mgr.reply_filter.process(reply[0])]
else:
logging.info("回复[{}]图片消息:{}".format(session_name, reply))
@ -196,4 +203,4 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str) ->
finally:
pkg.openai.session.get_session(session_name).release_response_lock()
return reply
return MessageChain(reply)