Merge pull request #773 from RockChinQ/feat/multi-modal

Feat: 多模态
This commit is contained in:
Junyan Qin 2024-05-16 21:13:15 +08:00 committed by GitHub
commit a7f830dd73
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
42 changed files with 650 additions and 238 deletions

View File

@ -13,11 +13,16 @@ class CommandReturn(pydantic.BaseModel):
"""命令返回值 """命令返回值
""" """
text: typing.Optional[str] text: typing.Optional[str] = None
"""文本 """文本
""" """
image: typing.Optional[mirai.Image] image: typing.Optional[mirai.Image] = None
"""弃用"""
image_url: typing.Optional[str] = None
"""图片链接
"""
error: typing.Optional[errors.CommandError]= None error: typing.Optional[errors.CommandError]= None
"""错误 """错误

View File

@ -24,7 +24,7 @@ class DefaultOperator(operator.CommandOperator):
content = "" content = ""
for msg in prompt.messages: for msg in prompt.messages:
content += f" {msg.role}: {msg.content}" content += f" {msg.readable_str()}\n"
reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n" reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n"
@ -45,18 +45,18 @@ class DefaultSetOperator(operator.CommandOperator):
context: entities.ExecuteContext context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称')) yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
else: else:
prompt_name = context.crt_params[0] prompt_name = context.crt_params[0]
try: try:
prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name) prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
if prompt is None: if prompt is None:
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name))) yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
else: else:
context.session.use_prompt_name = prompt.name context.session.use_prompt_name = prompt.name
yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效") yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e))) yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))

View File

@ -30,7 +30,7 @@ class LastOperator(operator.CommandOperator):
context.session.using_conversation = context.session.conversations[index-1] context.session.using_conversation = context.session.conversations[index-1]
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S") time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}") yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}")
return return
else: else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))

View File

@ -42,7 +42,7 @@ class ListOperator(operator.CommandOperator):
using_conv_index = index using_conv_index = index
if index >= page * record_per_page and index < (page + 1) * record_per_page: if index >= page * record_per_page and index < (page + 1) * record_per_page:
content += f"{index} {time_str}: {conv.messages[0].content if len(conv.messages) > 0 else '无内容'}\n" content += f"{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else '无内容'}\n"
index += 1 index += 1
if content == '': if content == '':
@ -51,6 +51,6 @@ class ListOperator(operator.CommandOperator):
if context.session.using_conversation is None: if context.session.using_conversation is None:
content += "\n当前处于新会话" content += "\n当前处于新会话"
else: else:
content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].content if len(context.session.using_conversation.messages) > 0 else '无内容'}" content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else '无内容'}"
yield entities.CommandReturn(text=f"{page + 1} 页 (时间倒序):\n{content}") yield entities.CommandReturn(text=f"{page + 1} 页 (时间倒序):\n{content}")

View File

@ -27,7 +27,7 @@ class JSONConfigFile(file_model.ConfigFile):
else: else:
raise ValueError("template_file_name or template_data must be provided") raise ValueError("template_file_name or template_data must be provided")
async def load(self) -> dict: async def load(self, completion: bool=True) -> dict:
if not self.exists(): if not self.exists():
await self.create() await self.create()
@ -39,9 +39,11 @@ class JSONConfigFile(file_model.ConfigFile):
with open(self.config_file_name, "r", encoding="utf-8") as f: with open(self.config_file_name, "r", encoding="utf-8") as f:
cfg = json.load(f) cfg = json.load(f)
for key in self.template_data: if completion:
if key not in cfg:
cfg[key] = self.template_data[key] for key in self.template_data:
if key not in cfg:
cfg[key] = self.template_data[key]
return cfg return cfg

View File

@ -25,7 +25,7 @@ class PythonModuleConfigFile(file_model.ConfigFile):
async def create(self): async def create(self):
shutil.copyfile(self.template_file_name, self.config_file_name) shutil.copyfile(self.template_file_name, self.config_file_name)
async def load(self) -> dict: async def load(self, completion: bool=True) -> dict:
module_name = os.path.splitext(os.path.basename(self.config_file_name))[0] module_name = os.path.splitext(os.path.basename(self.config_file_name))[0]
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
@ -43,18 +43,19 @@ class PythonModuleConfigFile(file_model.ConfigFile):
cfg[key] = getattr(module, key) cfg[key] = getattr(module, key)
# 从模板模块文件中进行补全 # 从模板模块文件中进行补全
module_name = os.path.splitext(os.path.basename(self.template_file_name))[0] if completion:
module = importlib.import_module(module_name) module_name = os.path.splitext(os.path.basename(self.template_file_name))[0]
module = importlib.import_module(module_name)
for key in dir(module): for key in dir(module):
if key.startswith('__'): if key.startswith('__'):
continue continue
if not isinstance(getattr(module, key), allowed_types): if not isinstance(getattr(module, key), allowed_types):
continue continue
if key not in cfg: if key not in cfg:
cfg[key] = getattr(module, key) cfg[key] = getattr(module, key)
return cfg return cfg

View File

@ -20,8 +20,8 @@ class ConfigManager:
self.file = cfg_file self.file = cfg_file
self.data = {} self.data = {}
async def load_config(self): async def load_config(self, completion: bool=True):
self.data = await self.file.load() self.data = await self.file.load(completion=completion)
async def dump_config(self): async def dump_config(self):
await self.file.save(self.data) await self.file.save(self.data)
@ -30,7 +30,7 @@ class ConfigManager:
self.file.save_sync(self.data) self.file.save_sync(self.data)
async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager: async def load_python_module_config(config_name: str, template_name: str, completion: bool=True) -> ConfigManager:
"""加载Python模块配置文件""" """加载Python模块配置文件"""
cfg_inst = pymodule.PythonModuleConfigFile( cfg_inst = pymodule.PythonModuleConfigFile(
config_name, config_name,
@ -38,12 +38,12 @@ async def load_python_module_config(config_name: str, template_name: str) -> Con
) )
cfg_mgr = ConfigManager(cfg_inst) cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config() await cfg_mgr.load_config(completion=completion)
return cfg_mgr return cfg_mgr
async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None) -> ConfigManager: async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager:
"""加载JSON配置文件""" """加载JSON配置文件"""
cfg_inst = json_file.JSONConfigFile( cfg_inst = json_file.JSONConfigFile(
config_name, config_name,
@ -52,6 +52,6 @@ async def load_json_config(config_name: str, template_name: str=None, template_d
) )
cfg_mgr = ConfigManager(cfg_inst) cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config() await cfg_mgr.load_config(completion=completion)
return cfg_mgr return cfg_mgr

View File

@ -0,0 +1,19 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("vision-config", 6)
class VisionConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return "enable-vision" not in self.ap.provider_cfg.data
async def run(self):
"""执行迁移"""
if "enable-vision" not in self.ap.provider_cfg.data:
self.ap.provider_cfg.data["enable-vision"] = False
await self.ap.provider_cfg.dump_config()

View File

@ -22,7 +22,7 @@ class ConfigFile(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def load(self) -> dict: async def load(self, completion: bool=True) -> dict:
pass pass
@abc.abstractmethod @abc.abstractmethod

View File

@ -14,6 +14,7 @@ required_deps = {
"yaml": "pyyaml", "yaml": "pyyaml",
"aiohttp": "aiohttp", "aiohttp": "aiohttp",
"psutil": "psutil", "psutil": "psutil",
"async_lru": "async-lru",
} }

View File

@ -70,7 +70,7 @@ class Query(pydantic.BaseModel):
resp_messages: typing.Optional[list[llm_entities.Message]] = [] resp_messages: typing.Optional[list[llm_entities.Message]] = []
"""由Process阶段生成的回复消息对象列表""" """由Process阶段生成的回复消息对象列表"""
resp_message_chain: typing.Optional[mirai.MessageChain] = None resp_message_chain: typing.Optional[list[mirai.MessageChain]] = None
"""回复消息链从resp_messages包装而得""" """回复消息链从resp_messages包装而得"""
class Config: class Config:

View File

@ -15,7 +15,6 @@ from ...provider.sysprompt import sysprompt as llm_prompt_mgr
from ...provider.tools import toolmgr as llm_tool_mgr from ...provider.tools import toolmgr as llm_tool_mgr
from ...platform import manager as im_mgr from ...platform import manager as im_mgr
@stage.stage_class("BuildAppStage") @stage.stage_class("BuildAppStage")
class BuildAppStage(stage.BootingStage): class BuildAppStage(stage.BootingStage):
"""构建应用阶段 """构建应用阶段
@ -83,7 +82,6 @@ class BuildAppStage(stage.BootingStage):
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap) llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
await llm_tool_mgr_inst.initialize() await llm_tool_mgr_inst.initialize()
ap.tool_mgr = llm_tool_mgr_inst ap.tool_mgr = llm_tool_mgr_inst
im_mgr_inst = im_mgr.PlatformManager(ap=ap) im_mgr_inst = im_mgr.PlatformManager(ap=ap)
await im_mgr_inst.initialize() await im_mgr_inst.initialize()
ap.platform_mgr = im_mgr_inst ap.platform_mgr = im_mgr_inst
@ -92,5 +90,6 @@ class BuildAppStage(stage.BootingStage):
await stage_mgr.initialize() await stage_mgr.initialize()
ap.stage_mgr = stage_mgr ap.stage_mgr = stage_mgr
ctrl = controller.Controller(ap) ctrl = controller.Controller(ap)
ap.ctrl = ctrl ap.ctrl = ctrl

View File

@ -12,11 +12,11 @@ class LoadConfigStage(stage.BootingStage):
async def run(self, ap: app.Application): async def run(self, ap: app.Application):
"""启动 """启动
""" """
ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json") ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json", completion=False)
ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json") ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json", completion=False)
ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json") ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json", completion=False)
ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json") ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json", completion=False)
ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json") ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json", completion=False)
ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json") ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json")
await ap.plugin_setting_meta.dump_config() await ap.plugin_setting_meta.dump_config()

View File

@ -4,7 +4,7 @@ import importlib
from .. import stage, app from .. import stage, app
from ...config import migration from ...config import migration
from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion, m006_vision_config
from ...config.migrations import m005_deepseek_cfg_completion from ...config.migrations import m005_deepseek_cfg_completion

View File

@ -8,7 +8,10 @@ from ...config import manager as cfg_mgr
@stage.stage_class('BanSessionCheckStage') @stage.stage_class('BanSessionCheckStage')
class BanSessionCheckStage(stage.PipelineStage): class BanSessionCheckStage(stage.PipelineStage):
"""访问控制处理阶段""" """访问控制处理阶段
仅检查query中群号或个人号是否在访问控制列表中
"""
async def initialize(self): async def initialize(self):
pass pass

View File

@ -9,12 +9,24 @@ from ...core import entities as core_entities
from ...config import manager as cfg_mgr from ...config import manager as cfg_mgr
from . import filter as filter_model, entities as filter_entities from . import filter as filter_model, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine from .filters import cntignore, banwords, baiduexamine
from ...provider import entities as llm_entities
@stage.stage_class('PostContentFilterStage') @stage.stage_class('PostContentFilterStage')
@stage.stage_class('PreContentFilterStage') @stage.stage_class('PreContentFilterStage')
class ContentFilterStage(stage.PipelineStage): class ContentFilterStage(stage.PipelineStage):
"""内容过滤阶段""" """内容过滤阶段
前置
检查消息是否符合规则不符合则拦截
改写
message_chain
后置
检查AI回复消息是否符合规则可能进行改写不符合则拦截
改写
query.resp_messages
"""
filter_chain: list[filter_model.ContentFilter] filter_chain: list[filter_model.ContentFilter]
@ -130,6 +142,21 @@ class ContentFilterStage(stage.PipelineStage):
"""处理 """处理
""" """
if stage_inst_name == 'PreContentFilterStage': if stage_inst_name == 'PreContentFilterStage':
contain_non_text = False
for me in query.message_chain:
if not isinstance(me, mirai.Plain):
contain_non_text = True
break
if contain_non_text:
self.ap.logger.debug(f"消息中包含非文本消息,跳过内容过滤器检查。")
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
return await self._pre_process( return await self._pre_process(
str(query.message_chain).strip(), str(query.message_chain).strip(),
query query

View File

@ -4,6 +4,8 @@ import enum
import pydantic import pydantic
from ...provider import entities as llm_entities
class ResultLevel(enum.Enum): class ResultLevel(enum.Enum):
"""结果等级""" """结果等级"""
@ -38,7 +40,7 @@ class FilterResult(pydantic.BaseModel):
""" """
replacement: str replacement: str
"""替换后的消息 """替换后的文本消息
内容过滤器可以进行一些遮掩处理然后把遮掩后的消息返回 内容过滤器可以进行一些遮掩处理然后把遮掩后的消息返回
若没有修改内容也需要返回原消息 若没有修改内容也需要返回原消息

View File

@ -5,6 +5,7 @@ import typing
from ...core import app from ...core import app
from . import entities from . import entities
from ...provider import entities as llm_entities
preregistered_filters: list[typing.Type[ContentFilter]] = [] preregistered_filters: list[typing.Type[ContentFilter]] = []
@ -63,7 +64,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def process(self, message: str) -> entities.FilterResult: async def process(self, message: str=None, image_url=None) -> entities.FilterResult:
"""处理消息 """处理消息
分为前后阶段具体取决于 enable_stages 的值 分为前后阶段具体取决于 enable_stages 的值
@ -71,6 +72,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
Args: Args:
message (str): 需要检查的内容 message (str): 需要检查的内容
image_url (str): 要检查的图片的 URL
Returns: Returns:
entities.FilterResult: 过滤结果具体内容请查看 entities.FilterResult 类的文档 entities.FilterResult: 过滤结果具体内容请查看 entities.FilterResult 类的文档

View File

@ -8,7 +8,7 @@ from ....config import manager as cfg_mgr
@filter_model.filter_class("ban-word-filter") @filter_model.filter_class("ban-word-filter")
class BanWordFilter(filter_model.ContentFilter): class BanWordFilter(filter_model.ContentFilter):
"""根据内容禁言""" """根据内容过滤"""
async def initialize(self): async def initialize(self):
pass pass

View File

@ -16,6 +16,9 @@ from ...config import manager as cfg_mgr
@stage.stage_class("LongTextProcessStage") @stage.stage_class("LongTextProcessStage")
class LongTextProcessStage(stage.PipelineStage): class LongTextProcessStage(stage.PipelineStage):
"""长消息处理阶段 """长消息处理阶段
改写
- resp_message_chain
""" """
strategy_impl: strategy.LongTextStrategy strategy_impl: strategy.LongTextStrategy
@ -59,15 +62,15 @@ class LongTextProcessStage(stage.PipelineStage):
# 检查是否包含非 Plain 组件 # 检查是否包含非 Plain 组件
contains_non_plain = False contains_non_plain = False
for msg in query.resp_message_chain: for msg in query.resp_message_chain[-1]:
if not isinstance(msg, Plain): if not isinstance(msg, Plain):
contains_non_plain = True contains_non_plain = True
break break
if contains_non_plain: if contains_non_plain:
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。") self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
elif len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']: elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']:
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query)) query.resp_message_chain[-1] = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
return entities.StageProcessResult( return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,

View File

@ -43,7 +43,7 @@ class QueryPool:
message_event=message_event, message_event=message_event,
message_chain=message_chain, message_chain=message_chain,
resp_messages=[], resp_messages=[],
resp_message_chain=None, resp_message_chain=[],
adapter=adapter adapter=adapter
) )
self.queries.append(query) self.queries.append(query)

View File

@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import mirai
from .. import stage, entities, stagemgr from .. import stage, entities, stagemgr
from ...core import entities as core_entities from ...core import entities as core_entities
from ...provider import entities as llm_entities from ...provider import entities as llm_entities
@ -9,6 +11,16 @@ from ...plugin import events
@stage.stage_class("PreProcessor") @stage.stage_class("PreProcessor")
class PreProcessor(stage.PipelineStage): class PreProcessor(stage.PipelineStage):
"""请求预处理阶段 """请求预处理阶段
签出会话prompt上文模型内容函数
改写
- session
- prompt
- messages
- user_message
- use_model
- use_funcs
""" """
async def process( async def process(
@ -27,21 +39,42 @@ class PreProcessor(stage.PipelineStage):
query.prompt = conversation.prompt.copy() query.prompt = conversation.prompt.copy()
query.messages = conversation.messages.copy() query.messages = conversation.messages.copy()
query.user_message = llm_entities.Message(
role='user',
content=str(query.message_chain).strip()
)
query.use_model = conversation.use_model query.use_model = conversation.use_model
query.use_funcs = conversation.use_funcs query.use_funcs = conversation.use_funcs if query.use_model.tool_call_supported else None
# 检查vision是否启用没启用就删除所有图片
if not self.ap.provider_cfg.data['enable-vision'] or not query.use_model.vision_supported:
for msg in query.messages:
if isinstance(msg.content, list):
for me in msg.content:
if me.type == 'image_url':
msg.content.remove(me)
content_list = []
for me in query.message_chain:
if isinstance(me, mirai.Plain):
content_list.append(
llm_entities.ContentElement.from_text(me.text)
)
elif isinstance(me, mirai.Image):
if self.ap.provider_cfg.data['enable-vision'] and query.use_model.vision_supported:
if me.url is not None:
content_list.append(
llm_entities.ContentElement.from_image_url(str(me.url))
)
query.user_message = llm_entities.Message( # TODO 适配多模态输入
role='user',
content=content_list
)
# =========== 触发事件 PromptPreProcessing # =========== 触发事件 PromptPreProcessing
session = query.session
event_ctx = await self.ap.plugin_mgr.emit_event( event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PromptPreProcessing( event=events.PromptPreProcessing(
session_name=f'{session.launcher_type.value}_{session.launcher_id}', session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
default_prompt=query.prompt.messages, default_prompt=query.prompt.messages,
prompt=query.messages, prompt=query.messages,
query=query query=query

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import typing import typing
import time import time
import traceback import traceback
import json
import mirai import mirai
@ -70,17 +71,13 @@ class ChatMessageHandler(handler.MessageHandler):
mirai.Plain(event_ctx.event.alter) mirai.Plain(event_ctx.event.alter)
]) ])
query.messages.append(
query.user_message
)
text_length = 0 text_length = 0
start_time = time.time() start_time = time.time()
try: try:
async for result in query.use_model.requester.request(query): async for result in self.runner(query):
query.resp_messages.append(result) query.resp_messages.append(result)
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}') self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
@ -92,6 +89,9 @@ class ChatMessageHandler(handler.MessageHandler):
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query
) )
query.session.using_conversation.messages.append(query.user_message)
query.session.using_conversation.messages.extend(query.resp_messages)
except Exception as e: except Exception as e:
self.ap.logger.error(f'对话({query.query_id})请求失败: {str(e)}') self.ap.logger.error(f'对话({query.query_id})请求失败: {str(e)}')
@ -104,8 +104,6 @@ class ChatMessageHandler(handler.MessageHandler):
debug_notice=traceback.format_exc() debug_notice=traceback.format_exc()
) )
finally: finally:
query.session.using_conversation.messages.append(query.user_message)
query.session.using_conversation.messages.extend(query.resp_messages)
await self.ap.ctr_mgr.usage.post_query_record( await self.ap.ctr_mgr.usage.post_query_record(
session_type=query.session.launcher_type.value, session_type=query.session.launcher_type.value,
@ -115,4 +113,65 @@ class ChatMessageHandler(handler.MessageHandler):
model_name=query.use_model.name, model_name=query.use_model.name,
response_seconds=int(time.time() - start_time), response_seconds=int(time.time() - start_time),
retry_times=-1, retry_times=-1,
) )
async def runner(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""执行一个请求处理过程中的LLM接口请求、函数调用的循环
这是临时处理方案后续可能改为使用LangChain或者自研的工作流处理器
"""
await query.use_model.requester.preprocess(query)
pending_tool_calls = []
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
# 首次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg)
# 持续请求,只要还有待处理的工具调用就继续处理调用
while pending_tool_calls:
for tool_call in pending_tool_calls:
try:
func = tool_call.function
parameters = json.loads(func.arguments)
func_ret = await self.ap.tool_mgr.execute_func_call(
query, func.name, parameters
)
msg = llm_entities.Message(
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
)
yield msg
req_messages.append(msg)
except Exception as e:
# 工具调用出错,添加一个报错信息到 req_messages
err_msg = llm_entities.Message(
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
)
yield err_msg
req_messages.append(err_msg)
# 处理完所有调用,再次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg)

View File

@ -80,9 +80,6 @@ class CommandHandler(handler.MessageHandler):
session=session session=session
): ):
if ret.error is not None: if ret.error is not None:
# query.resp_message_chain = mirai.MessageChain([
# mirai.Plain(str(ret.error))
# ])
query.resp_messages.append( query.resp_messages.append(
llm_entities.Message( llm_entities.Message(
role='command', role='command',
@ -96,18 +93,28 @@ class CommandHandler(handler.MessageHandler):
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query
) )
elif ret.text is not None: elif ret.text is not None or ret.image_url is not None:
# query.resp_message_chain = mirai.MessageChain([
# mirai.Plain(ret.text) content: list[llm_entities.ContentElement]= []
# ])
if ret.text is not None:
content.append(
llm_entities.ContentElement.from_text(ret.text)
)
if ret.image_url is not None:
content.append(
llm_entities.ContentElement.from_image_url(ret.image_url)
)
query.resp_messages.append( query.resp_messages.append(
llm_entities.Message( llm_entities.Message(
role='command', role='command',
content=ret.text, content=content,
) )
) )
self.ap.logger.info(f'命令返回: {self.cut_str(ret.text)}') self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}')
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,

View File

@ -11,7 +11,13 @@ from ...config import manager as cfg_mgr
@stage.stage_class("MessageProcessor") @stage.stage_class("MessageProcessor")
class Processor(stage.PipelineStage): class Processor(stage.PipelineStage):
"""请求实际处理阶段""" """请求实际处理阶段
通过命令处理器和聊天处理器处理消息
改写
- resp_messages
"""
cmd_handler: handler.MessageHandler cmd_handler: handler.MessageHandler

View File

@ -11,7 +11,10 @@ from ...core import entities as core_entities
@stage.stage_class("RequireRateLimitOccupancy") @stage.stage_class("RequireRateLimitOccupancy")
@stage.stage_class("ReleaseRateLimitOccupancy") @stage.stage_class("ReleaseRateLimitOccupancy")
class RateLimit(stage.PipelineStage): class RateLimit(stage.PipelineStage):
"""限速器控制阶段""" """限速器控制阶段
不改写query只检查是否需要限速
"""
algo: algo.ReteLimitAlgo algo: algo.ReteLimitAlgo

View File

@ -31,7 +31,7 @@ class SendResponseBackStage(stage.PipelineStage):
await self.ap.platform_mgr.send( await self.ap.platform_mgr.send(
query.message_event, query.message_event,
query.resp_message_chain, query.resp_message_chain[-1],
adapter=query.adapter adapter=query.adapter
) )

View File

@ -14,9 +14,12 @@ from ...config import manager as cfg_mgr
@stage.stage_class("GroupRespondRuleCheckStage") @stage.stage_class("GroupRespondRuleCheckStage")
class GroupRespondRuleCheckStage(stage.PipelineStage): class GroupRespondRuleCheckStage(stage.PipelineStage):
"""群组响应规则检查器 """群组响应规则检查器
仅检查群消息是否符合规则
""" """
rule_matchers: list[rule.GroupRespondRule] rule_matchers: list[rule.GroupRespondRule]
"""检查器实例"""
async def initialize(self): async def initialize(self):
"""初始化检查器 """初始化检查器
@ -31,7 +34,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if query.launcher_type.value != 'group': if query.launcher_type.value != 'group': # 只处理群消息
return entities.StageProcessResult( return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query

View File

@ -17,17 +17,17 @@ from .ratelimit import ratelimit
# 请求处理阶段顺序 # 请求处理阶段顺序
stage_order = [ stage_order = [
"GroupRespondRuleCheckStage", "GroupRespondRuleCheckStage", # 群响应规则检查
"BanSessionCheckStage", "BanSessionCheckStage", # 封禁会话检查
"PreContentFilterStage", "PreContentFilterStage", # 内容过滤前置阶段
"PreProcessor", "PreProcessor", # 预处理器
"RequireRateLimitOccupancy", "RequireRateLimitOccupancy", # 请求速率限制占用
"MessageProcessor", "MessageProcessor", # 处理器
"ReleaseRateLimitOccupancy", "ReleaseRateLimitOccupancy", # 释放速率限制占用
"PostContentFilterStage", "PostContentFilterStage", # 内容过滤后置阶段
"ResponseWrapper", "ResponseWrapper", # 响应包装器
"LongTextProcessStage", "LongTextProcessStage", # 长文本处理
"SendResponseBackStage", "SendResponseBackStage", # 发送响应
] ]

View File

@ -14,6 +14,13 @@ from ...plugin import events
@stage.stage_class("ResponseWrapper") @stage.stage_class("ResponseWrapper")
class ResponseWrapper(stage.PipelineStage): class ResponseWrapper(stage.PipelineStage):
"""回复包装阶段
把回复的 message 包装成人类识读的形式
改写
- resp_message_chain
"""
async def initialize(self): async def initialize(self):
pass pass
@ -27,17 +34,19 @@ class ResponseWrapper(stage.PipelineStage):
""" """
if query.resp_messages[-1].role == 'command': if query.resp_messages[-1].role == 'command':
query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content) # query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content))
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] '))
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query
) )
elif query.resp_messages[-1].role == 'plugin': elif query.resp_messages[-1].role == 'plugin':
if not isinstance(query.resp_messages[-1].content, mirai.MessageChain): # if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content) # query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
else: # else:
query.resp_message_chain = query.resp_messages[-1].content # query.resp_message_chain.append(query.resp_messages[-1].content)
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain())
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
@ -52,7 +61,7 @@ class ResponseWrapper(stage.PipelineStage):
reply_text = '' reply_text = ''
if result.content is not None: # 有内容 if result.content is not None: # 有内容
reply_text = result.content reply_text = str(result.get_content_mirai_message_chain())
# ============= 触发插件事件 =============== # ============= 触发插件事件 ===============
event_ctx = await self.ap.plugin_mgr.emit_event( event_ctx = await self.ap.plugin_mgr.emit_event(
@ -76,11 +85,11 @@ class ResponseWrapper(stage.PipelineStage):
else: else:
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
else: else:
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) query.resp_message_chain.append(result.get_content_mirai_message_chain())
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
@ -93,7 +102,7 @@ class ResponseWrapper(stage.PipelineStage):
reply_text = f'调用函数 {".".join(function_names)}...' reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
if self.ap.platform_cfg.data['track-function-calls']: if self.ap.platform_cfg.data['track-function-calls']:
@ -119,13 +128,13 @@ class ResponseWrapper(stage.PipelineStage):
else: else:
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
else: else:
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE, result_type=entities.ResultType.CONTINUE,
new_query=query new_query=query
) )

View File

@ -21,6 +21,39 @@ class ToolCall(pydantic.BaseModel):
function: FunctionCall function: FunctionCall
class ImageURLContentObject(pydantic.BaseModel):
url: str
def __str__(self):
return self.url[:128] + ('...' if len(self.url) > 128 else '')
class ContentElement(pydantic.BaseModel):
type: str
"""内容类型"""
text: typing.Optional[str] = None
image_url: typing.Optional[ImageURLContentObject] = None
def __str__(self):
if self.type == 'text':
return self.text
elif self.type == 'image_url':
return f'[图片]({self.image_url})'
else:
return '未知内容'
@classmethod
def from_text(cls, text: str):
return cls(type='text', text=text)
@classmethod
def from_image_url(cls, image_url: str):
return cls(type='image_url', image_url=ImageURLContentObject(url=image_url))
class Message(pydantic.BaseModel): class Message(pydantic.BaseModel):
"""消息""" """消息"""
@ -30,12 +63,9 @@ class Message(pydantic.BaseModel):
name: typing.Optional[str] = None name: typing.Optional[str] = None
"""名称,仅函数调用返回时设置""" """名称,仅函数调用返回时设置"""
content: typing.Optional[str] | typing.Optional[mirai.MessageChain] = None content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None
"""内容""" """内容"""
function_call: typing.Optional[FunctionCall] = None
"""函数调用不再受支持请使用tool_calls"""
tool_calls: typing.Optional[list[ToolCall]] = None tool_calls: typing.Optional[list[ToolCall]] = None
"""工具调用""" """工具调用"""
@ -43,10 +73,38 @@ class Message(pydantic.BaseModel):
def readable_str(self) -> str: def readable_str(self) -> str:
if self.content is not None: if self.content is not None:
return str(self.content) return str(self.role) + ": " + str(self.get_content_mirai_message_chain())
elif self.function_call is not None:
return f'{self.function_call.name}({self.function_call.arguments})'
elif self.tool_calls is not None: elif self.tool_calls is not None:
return f'调用工具: {self.tool_calls[0].id}' return f'调用工具: {self.tool_calls[0].id}'
else: else:
return '未知消息' return '未知消息'
def get_content_mirai_message_chain(self, prefix_text: str="") -> mirai.MessageChain | None:
"""将内容转换为 Mirai MessageChain 对象
Args:
prefix_text (str): 首个文字组件的前缀文本
"""
if self.content is None:
return None
elif isinstance(self.content, str):
return mirai.MessageChain([mirai.Plain(prefix_text+self.content)])
elif isinstance(self.content, list):
mc = []
for ce in self.content:
if ce.type == 'text':
mc.append(mirai.Plain(ce.text))
elif ce.type == 'image':
mc.append(mirai.Image(url=ce.image_url))
# 找第一个文字组件
if prefix_text:
for i, c in enumerate(mc):
if isinstance(c, mirai.Plain):
mc[i] = mirai.Plain(prefix_text+c.text)
break
else:
mc.insert(0, mirai.Plain(prefix_text))
return mirai.MessageChain(mc)

View File

@ -6,6 +6,8 @@ import typing
from ...core import app from ...core import app
from ...core import entities as core_entities from ...core import entities as core_entities
from .. import entities as llm_entities from .. import entities as llm_entities
from . import entities as modelmgr_entities
from ..tools import entities as tools_entities
preregistered_requesters: list[typing.Type[LLMAPIRequester]] = [] preregistered_requesters: list[typing.Type[LLMAPIRequester]] = []
@ -33,20 +35,31 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
async def initialize(self): async def initialize(self):
pass pass
@abc.abstractmethod async def preprocess(
async def request(
self, self,
query: core_entities.Query, query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]: ):
"""请求API """预处理
在这里处理特定API对Query对象的兼容性问题
"""
pass
对话前文可以从 query 对象中获取 @abc.abstractmethod
可以多次yield消息对象 async def call(
self,
model: modelmgr_entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
"""调用API
Args: Args:
query (core_entities.Query): 本次请求的上下文对象 model (modelmgr_entities.LLMModelInfo): 使用的模型信息
messages (typing.List[llm_entities.Message]): 消息对象列表
funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None.
Yields: Returns:
pkg.provider.entities.Message: 返回消息对象 llm_entities.Message: 返回消息对象
""" """
raise NotImplementedError pass

View File

@ -27,47 +27,60 @@ class AnthropicMessages(api.LLMAPIRequester):
proxies=self.ap.proxy_mgr.get_forward_proxies() proxies=self.ap.proxy_mgr.get_forward_proxies()
) )
async def request( async def call(
self, self,
query: core_entities.Query, model: entities.LLMModelInfo,
) -> typing.AsyncGenerator[llm_entities.Message, None]: messages: typing.List[llm_entities.Message],
self.client.api_key = query.use_model.token_mgr.get_token() funcs: typing.List[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
self.client.api_key = model.token_mgr.get_token()
args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy() args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
args["model"] = query.use_model.name if query.use_model.model_name is None else query.use_model.model_name args["model"] = model.name if model.model_name is None else model.model_name
req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行 # 处理消息
m.dict(exclude_none=True) for m in query.prompt.messages if m.content.strip() != ""
] + [m.dict(exclude_none=True) for m in query.messages]
# 删除所有 role=system & content='' 的消息 # system
req_messages = [ system_role_message = None
m for m in req_messages if not (m["role"] == "system" and m["content"].strip() == "")
]
# 检查是否有 role=system 的消息,若有,改为 role=user并在后面加一个 role=assistant 的消息 for i, m in enumerate(messages):
system_role_index = [] if m.role == "system":
for i, m in enumerate(req_messages): system_role_message = m
if m["role"] == "system":
system_role_index.append(i)
m["role"] = "user"
if system_role_index: messages.pop(i)
for i in system_role_index[::-1]: break
req_messages.insert(i + 1, {"role": "assistant", "content": "Okay, I'll follow."})
# 忽略掉空消息,用户可能发送空消息,而上层未过滤 if isinstance(system_role_message, llm_entities.Message) \
req_messages = [ and isinstance(system_role_message.content, str):
m for m in req_messages if m["content"].strip() != "" args['system'] = system_role_message.content
]
# 其他消息
# req_messages = [
# m.dict(exclude_none=True) for m in messages \
# if (isinstance(m.content, str) and m.content.strip() != "") \
# or (isinstance(m.content, list) and )
# ]
# 暂时不支持vision仅保留纯文字的content
req_messages = []
for m in messages:
if isinstance(m.content, str) and m.content.strip() != "":
req_messages.append(m.dict(exclude_none=True))
elif isinstance(m.content, list):
# 删除m.content中的type!=text的元素
m.content = [
c for c in m.content if c.get("type") == "text"
]
if len(m.content) > 0:
req_messages.append(m.dict(exclude_none=True))
args["messages"] = req_messages args["messages"] = req_messages
try: try:
resp = await self.client.messages.create(**args) resp = await self.client.messages.create(**args)
yield llm_entities.Message( return llm_entities.Message(
content=resp.content[0].text, content=resp.content[0].text,
role=resp.role role=resp.role
) )
@ -79,4 +92,4 @@ class AnthropicMessages(api.LLMAPIRequester):
if 'model: ' in str(e): if 'model: ' in str(e):
raise errors.RequesterError(f'模型无效: {e.message}') raise errors.RequesterError(f'模型无效: {e.message}')
else: else:
raise errors.RequesterError(f'请求地址无效: {e.message}') raise errors.RequesterError(f'请求地址无效: {e.message}')

View File

@ -3,16 +3,20 @@ from __future__ import annotations
import asyncio import asyncio
import typing import typing
import json import json
import base64
from typing import AsyncGenerator from typing import AsyncGenerator
import openai import openai
import openai.types.chat.chat_completion as chat_completion import openai.types.chat.chat_completion as chat_completion
import httpx import httpx
import aiohttp
import async_lru
from .. import api, entities, errors from .. import api, entities, errors
from ....core import entities as core_entities, app from ....core import entities as core_entities, app
from ... import entities as llm_entities from ... import entities as llm_entities
from ...tools import entities as tools_entities from ...tools import entities as tools_entities
from ....utils import image
@api.requester_class("openai-chat-completions") @api.requester_class("openai-chat-completions")
@ -43,7 +47,6 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
self, self,
args: dict, args: dict,
) -> chat_completion.ChatCompletion: ) -> chat_completion.ChatCompletion:
self.ap.logger.debug(f"req chat_completion with args {args}")
return await self.client.chat.completions.create(**args) return await self.client.chat.completions.create(**args)
async def _make_msg( async def _make_msg(
@ -67,14 +70,22 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
args = self.requester_cfg['args'].copy() args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
if use_model.tool_call_supported: if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs) tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
if tools: if tools:
args["tools"] = tools args["tools"] = tools
# 设置此次请求中的messages # 设置此次请求中的messages
messages = req_messages messages = req_messages.copy()
# 检查vision
for msg in messages:
if 'content' in msg and isinstance(msg["content"], list):
for me in msg["content"]:
if me["type"] == "image_url":
me["image_url"]['url'] = await self.get_base64_str(me["image_url"]['url'])
args["messages"] = messages args["messages"] = messages
# 发送请求 # 发送请求
@ -84,73 +95,19 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
message = await self._make_msg(resp) message = await self._make_msg(resp)
return message return message
async def _request( async def call(
self, query: core_entities.Query self,
) -> typing.AsyncGenerator[llm_entities.Message, None]: model: entities.LLMModelInfo,
"""请求""" messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
pending_tool_calls = [] ) -> llm_entities.Message:
req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行 req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行
m.dict(exclude_none=True) for m in query.prompt.messages if m.content.strip() != "" m.dict(exclude_none=True) for m in messages
] + [m.dict(exclude_none=True) for m in query.messages] ]
# req_messages.append({"role": "user", "content": str(query.message_chain)})
# 首次请求
msg = await self._closure(req_messages, query.use_model, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg.dict(exclude_none=True))
# 持续请求,只要还有待处理的工具调用就继续处理调用
while pending_tool_calls:
for tool_call in pending_tool_calls:
try:
func = tool_call.function
parameters = json.loads(func.arguments)
func_ret = await self.ap.tool_mgr.execute_func_call(
query, func.name, parameters
)
msg = llm_entities.Message(
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
)
yield msg
req_messages.append(msg.dict(exclude_none=True))
except Exception as e:
# 出错,添加一个报错信息到 req_messages
err_msg = llm_entities.Message(
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
)
yield err_msg
req_messages.append(
err_msg.dict(exclude_none=True)
)
# 处理完所有调用,继续请求
msg = await self._closure(req_messages, query.use_model, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg.dict(exclude_none=True))
async def request(self, query: core_entities.Query) -> AsyncGenerator[llm_entities.Message, None]:
try: try:
async for msg in self._request(query): return await self._closure(req_messages, model, funcs)
yield msg
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise errors.RequesterError('请求超时') raise errors.RequesterError('请求超时')
except openai.BadRequestError as e: except openai.BadRequestError as e:
@ -163,6 +120,16 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
except openai.NotFoundError as e: except openai.NotFoundError as e:
raise errors.RequesterError(f'请求路径错误: {e.message}') raise errors.RequesterError(f'请求路径错误: {e.message}')
except openai.RateLimitError as e: except openai.RateLimitError as e:
raise errors.RequesterError(f'请求过于频繁: {e.message}') raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
except openai.APIError as e: except openai.APIError as e:
raise errors.RequesterError(f'请求错误: {e.message}') raise errors.RequesterError(f'请求错误: {e.message}')
@async_lru.alru_cache(maxsize=128)
async def get_base64_str(
self,
original_url: str,
) -> str:
base64_image = await image.qq_image_url_to_base64(original_url)
return f"data:image/jpeg;base64,{base64_image}"

View File

@ -3,7 +3,10 @@ from __future__ import annotations
from ....core import app from ....core import app
from . import chatcmpl from . import chatcmpl
from .. import api from .. import api, entities, errors
from ....core import entities as core_entities, app
from ... import entities as llm_entities
from ...tools import entities as tools_entities
@api.requester_class("deepseek-chat-completions") @api.requester_class("deepseek-chat-completions")
@ -12,4 +15,39 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
self.requester_cfg = ap.provider_cfg.data['requester']['deepseek-chat-completions'] self.requester_cfg = ap.provider_cfg.data['requester']['deepseek-chat-completions']
self.ap = ap self.ap = ap
async def _closure(
self,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_funcs: list[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
if tools:
args["tools"] = tools
# 设置此次请求中的messages
messages = req_messages
# deepseek 不支持多模态把content都转换成纯文字
for m in messages:
if 'content' in m and isinstance(m["content"], list):
m["content"] = " ".join([c["text"] for c in m["content"]])
args["messages"] = messages
# 发送请求
resp = await self._req(args)
# 处理请求结果
message = await self._make_msg(resp)
return message

View File

@ -3,7 +3,10 @@ from __future__ import annotations
from ....core import app from ....core import app
from . import chatcmpl from . import chatcmpl
from .. import api from .. import api, entities, errors
from ....core import entities as core_entities, app
from ... import entities as llm_entities
from ...tools import entities as tools_entities
@api.requester_class("moonshot-chat-completions") @api.requester_class("moonshot-chat-completions")
@ -13,3 +16,41 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
self.requester_cfg = ap.provider_cfg.data['requester']['moonshot-chat-completions'] self.requester_cfg = ap.provider_cfg.data['requester']['moonshot-chat-completions']
self.ap = ap self.ap = ap
async def _closure(
self,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_funcs: list[tools_entities.LLMFunction] = None,
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.requester_cfg['args'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
if use_funcs:
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
if tools:
args["tools"] = tools
# 设置此次请求中的messages
messages = req_messages
# deepseek 不支持多模态把content都转换成纯文字
for m in messages:
if 'content' in m and isinstance(m["content"], list):
m["content"] = " ".join([c["text"] for c in m["content"]])
# 删除空的
messages = [m for m in messages if m["content"].strip() != ""]
args["messages"] = messages
# 发送请求
resp = await self._req(args)
# 处理请求结果
message = await self._make_msg(resp)
return message

View File

@ -21,5 +21,7 @@ class LLMModelInfo(pydantic.BaseModel):
tool_call_supported: typing.Optional[bool] = False tool_call_supported: typing.Optional[bool] = False
vision_supported: typing.Optional[bool] = False
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True

View File

@ -37,7 +37,7 @@ class ModelManager:
raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置") raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置")
async def initialize(self): async def initialize(self):
# 初始化token_mgr, requester # 初始化token_mgr, requester
for k, v in self.ap.provider_cfg.data['keys'].items(): for k, v in self.ap.provider_cfg.data['keys'].items():
self.token_mgrs[k] = token.TokenManager(k, v) self.token_mgrs[k] = token.TokenManager(k, v)
@ -83,7 +83,8 @@ class ModelManager:
model_name=None, model_name=None,
token_mgr=self.token_mgrs[model['token_mgr']], token_mgr=self.token_mgrs[model['token_mgr']],
requester=self.requesters[model['requester']], requester=self.requesters[model['requester']],
tool_call_supported=model['tool_call_supported'] tool_call_supported=model['tool_call_supported'],
vision_supported=model['vision_supported']
) )
break break
@ -95,13 +96,15 @@ class ModelManager:
token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr
requester = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester requester = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester
tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported) tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported)
vision_supported = model.get('vision_supported', default_model_info.vision_supported)
model_info = entities.LLMModelInfo( model_info = entities.LLMModelInfo(
name=model['name'], name=model['name'],
model_name=model_name, model_name=model_name,
token_mgr=token_mgr, token_mgr=token_mgr,
requester=requester, requester=requester,
tool_call_supported=tool_call_supported tool_call_supported=tool_call_supported,
vision_supported=vision_supported
) )
self.model_list.append(model_info) self.model_list.append(model_info)

41
pkg/utils/image.py Normal file
View File

@ -0,0 +1,41 @@
import base64
import typing
from urllib.parse import urlparse, parse_qs
import ssl
import aiohttp
async def qq_image_url_to_base64(
image_url: str
) -> str:
"""将QQ图片URL转为base64
Args:
image_url (str): QQ图片URL
Returns:
str: base64编码
"""
parsed = urlparse(image_url)
query = parse_qs(parsed.query)
# Flatten the query dictionary
query = {k: v[0] for k, v in query.items()}
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
async with aiohttp.ClientSession(trust_env=False) as session:
async with session.get(
f"http://{parsed.netloc}{parsed.path}",
params=query,
ssl=ssl_context
) as resp:
resp.raise_for_status() # 检查HTTP错误
file_bytes = await resp.read()
base64_str = base64.b64encode(file_bytes).decode()
return base64_str

View File

@ -13,4 +13,5 @@ aiohttp
pydantic pydantic
websockets websockets
urllib3 urllib3
psutil psutil
async-lru

View File

@ -4,23 +4,73 @@
"name": "default", "name": "default",
"requester": "openai-chat-completions", "requester": "openai-chat-completions",
"token_mgr": "openai", "token_mgr": "openai",
"tool_call_supported": false "tool_call_supported": false,
"vision_supported": false
},
{
"name": "gpt-3.5-turbo-0125",
"tool_call_supported": true,
"vision_supported": false
}, },
{ {
"name": "gpt-3.5-turbo", "name": "gpt-3.5-turbo",
"tool_call_supported": true "tool_call_supported": true,
"vision_supported": false
}, },
{ {
"name": "gpt-4", "name": "gpt-3.5-turbo-1106",
"tool_call_supported": true "tool_call_supported": true,
"vision_supported": false
},
{
"name": "gpt-4-turbo",
"tool_call_supported": true,
"vision_supported": true
},
{
"name": "gpt-4-turbo-2024-04-09",
"tool_call_supported": true,
"vision_supported": true
}, },
{ {
"name": "gpt-4-turbo-preview", "name": "gpt-4-turbo-preview",
"tool_call_supported": true "tool_call_supported": true,
"vision_supported": true
},
{
"name": "gpt-4-0125-preview",
"tool_call_supported": true,
"vision_supported": true
},
{
"name": "gpt-4-1106-preview",
"tool_call_supported": true,
"vision_supported": true
},
{
"name": "gpt-4",
"tool_call_supported": true,
"vision_supported": true
},
{
"name": "gpt-4o",
"tool_call_supported": true,
"vision_supported": true
},
{
"name": "gpt-4-0613",
"tool_call_supported": true,
"vision_supported": true
}, },
{ {
"name": "gpt-4-32k", "name": "gpt-4-32k",
"tool_call_supported": true "tool_call_supported": true,
"vision_supported": true
},
{
"name": "gpt-4-32k-0613",
"tool_call_supported": true,
"vision_supported": true
}, },
{ {
"model_name": "SparkDesk", "model_name": "SparkDesk",

View File

@ -1,5 +1,6 @@
{ {
"enable-chat": true, "enable-chat": true,
"enable-vision": true,
"keys": { "keys": {
"openai": [ "openai": [
"sk-1234567890" "sk-1234567890"