From 527ad81d38aa11d869cb37687ee832c865fd3bb3 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Tue, 14 May 2024 22:20:31 +0800 Subject: [PATCH 01/11] =?UTF-8?q?feat:=20=E8=A7=A3=E8=97=95chat=E7=9A=84?= =?UTF-8?q?=E5=A4=84=E7=90=86=E5=99=A8=E5=92=8C=E8=AF=B7=E6=B1=82=E5=99=A8?= =?UTF-8?q?=20(#772)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/command/operators/default.py | 32 ++++----- pkg/command/operators/last.py | 2 +- pkg/pipeline/bansess/bansess.py | 5 +- pkg/pipeline/cntfilter/cntfilter.py | 13 +++- pkg/pipeline/cntfilter/filters/banwords.py | 2 +- pkg/pipeline/longtext/longtext.py | 3 + pkg/pipeline/preproc/preproc.py | 15 +++- pkg/pipeline/process/handlers/chat.py | 70 +++++++++++++++++-- pkg/pipeline/process/process.py | 8 ++- pkg/pipeline/ratelimit/ratelimit.py | 5 +- pkg/pipeline/resprule/resprule.py | 5 +- pkg/pipeline/stagemgr.py | 22 +++--- pkg/pipeline/wrapper/wrapper.py | 9 ++- pkg/provider/entities.py | 17 +++-- pkg/provider/modelmgr/api.py | 33 ++++++--- pkg/provider/modelmgr/apis/anthropicmsgs.py | 25 +++---- pkg/provider/modelmgr/apis/chatcmpl.py | 76 +++------------------ 17 files changed, 205 insertions(+), 137 deletions(-) diff --git a/pkg/command/operators/default.py b/pkg/command/operators/default.py index ca7e404..ee46c7d 100644 --- a/pkg/command/operators/default.py +++ b/pkg/command/operators/default.py @@ -24,7 +24,7 @@ class DefaultOperator(operator.CommandOperator): content = "" for msg in prompt.messages: - content += f" {msg.role}: {msg.content}" + content += f" {msg.readable_str()}\n" reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n" @@ -45,18 +45,18 @@ class DefaultSetOperator(operator.CommandOperator): context: entities.ExecuteContext ) -> typing.AsyncGenerator[entities.CommandReturn, None]: - if len(context.crt_params) == 0: - yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称')) - else: - prompt_name = context.crt_params[0] - - try: - prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name) - if prompt is None: - yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name))) - else: - context.session.use_prompt_name = prompt.name - yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效") - except Exception as e: - traceback.print_exc() - yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e))) + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称')) + else: + prompt_name = context.crt_params[0] + + try: + prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name) + if prompt is None: + yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name))) + else: + context.session.use_prompt_name = prompt.name + yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效") + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e))) diff --git a/pkg/command/operators/last.py b/pkg/command/operators/last.py index 8e3a523..e7a14c8 100644 --- a/pkg/command/operators/last.py +++ b/pkg/command/operators/last.py @@ -30,7 +30,7 @@ class LastOperator(operator.CommandOperator): context.session.using_conversation = context.session.conversations[index-1] time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S") - yield 
entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}") + yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}") return else: yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) \ No newline at end of file diff --git a/pkg/pipeline/bansess/bansess.py b/pkg/pipeline/bansess/bansess.py index 95a7cff..9c04138 100644 --- a/pkg/pipeline/bansess/bansess.py +++ b/pkg/pipeline/bansess/bansess.py @@ -8,7 +8,10 @@ from ...config import manager as cfg_mgr @stage.stage_class('BanSessionCheckStage') class BanSessionCheckStage(stage.PipelineStage): - """访问控制处理阶段""" + """访问控制处理阶段 + + 仅检查query中群号或个人号是否在访问控制列表中。 + """ async def initialize(self): pass diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index 21b6c25..2c6a5ab 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -14,7 +14,18 @@ from .filters import cntignore, banwords, baiduexamine @stage.stage_class('PostContentFilterStage') @stage.stage_class('PreContentFilterStage') class ContentFilterStage(stage.PipelineStage): - """内容过滤阶段""" + """内容过滤阶段 + + 前置: + 检查消息是否符合规则,不符合则拦截。 + 改写: + message_chain + + 后置: + 检查AI回复消息是否符合规则,可能进行改写,不符合则拦截。 + 改写: + query.resp_messages + """ filter_chain: list[filter_model.ContentFilter] diff --git a/pkg/pipeline/cntfilter/filters/banwords.py b/pkg/pipeline/cntfilter/filters/banwords.py index 5cd7dcf..1430c2e 100644 --- a/pkg/pipeline/cntfilter/filters/banwords.py +++ b/pkg/pipeline/cntfilter/filters/banwords.py @@ -8,7 +8,7 @@ from ....config import manager as cfg_mgr @filter_model.filter_class("ban-word-filter") class BanWordFilter(filter_model.ContentFilter): - """根据内容禁言""" + """根据内容过滤""" async def initialize(self): pass diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py index 28c2814..ec0e66e 100644 --- a/pkg/pipeline/longtext/longtext.py +++ b/pkg/pipeline/longtext/longtext.py @@ -16,6 +16,9 @@ from ...config import manager as cfg_mgr @stage.stage_class("LongTextProcessStage") class LongTextProcessStage(stage.PipelineStage): """长消息处理阶段 + + 改写: + - resp_message_chain """ strategy_impl: strategy.LongTextStrategy diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index cedc030..164f78c 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -9,6 +9,16 @@ from ...plugin import events @stage.stage_class("PreProcessor") class PreProcessor(stage.PipelineStage): """请求预处理阶段 + + 签出会话、prompt、上文、模型、内容函数。 + + 改写: + - session + - prompt + - messages + - user_message + - use_model + - use_funcs """ async def process( @@ -27,7 +37,7 @@ class PreProcessor(stage.PipelineStage): query.prompt = conversation.prompt.copy() query.messages = conversation.messages.copy() - query.user_message = llm_entities.Message( + query.user_message = llm_entities.Message( # TODO 适配多模态输入 role='user', content=str(query.message_chain).strip() ) @@ -37,11 +47,10 @@ class PreProcessor(stage.PipelineStage): query.use_funcs = conversation.use_funcs # =========== 触发事件 PromptPreProcessing - session = query.session event_ctx = await self.ap.plugin_mgr.emit_event( event=events.PromptPreProcessing( - session_name=f'{session.launcher_type.value}_{session.launcher_id}', + session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}', default_prompt=query.prompt.messages, prompt=query.messages, query=query diff --git 
a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index f38ee34..2f0616f 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -3,6 +3,7 @@ from __future__ import annotations import typing import time import traceback +import json import mirai @@ -70,17 +71,13 @@ class ChatMessageHandler(handler.MessageHandler): mirai.Plain(event_ctx.event.alter) ]) - query.messages.append( - query.user_message - ) - text_length = 0 start_time = time.time() try: - async for result in query.use_model.requester.request(query): + async for result in self.runner(query): query.resp_messages.append(result) self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}') @@ -115,4 +112,65 @@ class ChatMessageHandler(handler.MessageHandler): model_name=query.use_model.name, response_seconds=int(time.time() - start_time), retry_times=-1, - ) \ No newline at end of file + ) + + async def runner( + self, + query: core_entities.Query, + ) -> typing.AsyncGenerator[llm_entities.Message, None]: + """执行一个请求处理过程中的LLM接口请求、函数调用的循环 + + 这是临时处理方案,后续可能改为使用LangChain或者自研的工作流处理器 + """ + await query.use_model.requester.preprocess(query) + + pending_tool_calls = [] + + req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message] + + # 首次请求 + msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs) + + yield msg + + pending_tool_calls = msg.tool_calls + + req_messages.append(msg) + + # 持续请求,只要还有待处理的工具调用就继续处理调用 + while pending_tool_calls: + for tool_call in pending_tool_calls: + try: + func = tool_call.function + + parameters = json.loads(func.arguments) + + func_ret = await self.ap.tool_mgr.execute_func_call( + query, func.name, parameters + ) + + msg = llm_entities.Message( + role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id + ) + + yield msg + + req_messages.append(msg) + except Exception as e: + # 工具调用出错,添加一个报错信息到 req_messages + err_msg = llm_entities.Message( + role="tool", content=f"err: {e}", tool_call_id=tool_call.id + ) + + yield err_msg + + req_messages.append(err_msg) + + # 处理完所有调用,再次请求 + msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs) + + yield msg + + pending_tool_calls = msg.tool_calls + + req_messages.append(msg) diff --git a/pkg/pipeline/process/process.py b/pkg/pipeline/process/process.py index ddf8809..e58d15e 100644 --- a/pkg/pipeline/process/process.py +++ b/pkg/pipeline/process/process.py @@ -11,7 +11,13 @@ from ...config import manager as cfg_mgr @stage.stage_class("MessageProcessor") class Processor(stage.PipelineStage): - """请求实际处理阶段""" + """请求实际处理阶段 + + 通过命令处理器和聊天处理器处理消息。 + + 改写: + - resp_messages + """ cmd_handler: handler.MessageHandler diff --git a/pkg/pipeline/ratelimit/ratelimit.py b/pkg/pipeline/ratelimit/ratelimit.py index 2622247..cd39b85 100644 --- a/pkg/pipeline/ratelimit/ratelimit.py +++ b/pkg/pipeline/ratelimit/ratelimit.py @@ -11,7 +11,10 @@ from ...core import entities as core_entities @stage.stage_class("RequireRateLimitOccupancy") @stage.stage_class("ReleaseRateLimitOccupancy") class RateLimit(stage.PipelineStage): - """限速器控制阶段""" + """限速器控制阶段 + + 不改写query,只检查是否需要限速。 + """ algo: algo.ReteLimitAlgo diff --git a/pkg/pipeline/resprule/resprule.py b/pkg/pipeline/resprule/resprule.py index d795d05..fce0c4e 100644 --- a/pkg/pipeline/resprule/resprule.py +++ b/pkg/pipeline/resprule/resprule.py @@ -14,9 +14,12 @@ from ...config import manager as cfg_mgr 
@stage.stage_class("GroupRespondRuleCheckStage") class GroupRespondRuleCheckStage(stage.PipelineStage): """群组响应规则检查器 + + 仅检查群消息是否符合规则。 """ rule_matchers: list[rule.GroupRespondRule] + """检查器实例""" async def initialize(self): """初始化检查器 @@ -31,7 +34,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage): async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: - if query.launcher_type.value != 'group': + if query.launcher_type.value != 'group': # 只处理群消息 return entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query diff --git a/pkg/pipeline/stagemgr.py b/pkg/pipeline/stagemgr.py index 23c7897..46957aa 100644 --- a/pkg/pipeline/stagemgr.py +++ b/pkg/pipeline/stagemgr.py @@ -17,17 +17,17 @@ from .ratelimit import ratelimit # 请求处理阶段顺序 stage_order = [ - "GroupRespondRuleCheckStage", - "BanSessionCheckStage", - "PreContentFilterStage", - "PreProcessor", - "RequireRateLimitOccupancy", - "MessageProcessor", - "ReleaseRateLimitOccupancy", - "PostContentFilterStage", - "ResponseWrapper", - "LongTextProcessStage", - "SendResponseBackStage", + "GroupRespondRuleCheckStage", # 群响应规则检查 + "BanSessionCheckStage", # 封禁会话检查 + "PreContentFilterStage", # 内容过滤前置阶段 + "PreProcessor", # 预处理器 + "RequireRateLimitOccupancy", # 请求速率限制占用 + "MessageProcessor", # 处理器 + "ReleaseRateLimitOccupancy", # 释放速率限制占用 + "PostContentFilterStage", # 内容过滤后置阶段 + "ResponseWrapper", # 响应包装器 + "LongTextProcessStage", # 长文本处理 + "SendResponseBackStage", # 发送响应 ] diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py index a500d7c..78705e6 100644 --- a/pkg/pipeline/wrapper/wrapper.py +++ b/pkg/pipeline/wrapper/wrapper.py @@ -14,6 +14,13 @@ from ...plugin import events @stage.stage_class("ResponseWrapper") class ResponseWrapper(stage.PipelineStage): + """回复包装阶段 + + 把回复的 message 包装成人类识读的形式。 + + 改写: + - resp_message_chain + """ async def initialize(self): pass @@ -128,4 +135,4 @@ class ResponseWrapper(stage.PipelineStage): yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query - ) \ No newline at end of file + ) diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index a30d4e3..3281a93 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -21,6 +21,16 @@ class ToolCall(pydantic.BaseModel): function: FunctionCall +class Content(pydantic.BaseModel): + + type: str + """内容类型""" + + text: typing.Optional[str] = None + + image_url: typing.Optional[str] = None + + class Message(pydantic.BaseModel): """消息""" @@ -33,9 +43,6 @@ class Message(pydantic.BaseModel): content: typing.Optional[str] | typing.Optional[mirai.MessageChain] = None """内容""" - function_call: typing.Optional[FunctionCall] = None - """函数调用,不再受支持,请使用tool_calls""" - tool_calls: typing.Optional[list[ToolCall]] = None """工具调用""" @@ -43,9 +50,7 @@ class Message(pydantic.BaseModel): def readable_str(self) -> str: if self.content is not None: - return str(self.content) - elif self.function_call is not None: - return f'{self.function_call.name}({self.function_call.arguments})' + return str(self.role) + ": " + str(self.content) elif self.tool_calls is not None: return f'调用工具: {self.tool_calls[0].id}' else: diff --git a/pkg/provider/modelmgr/api.py b/pkg/provider/modelmgr/api.py index 63021be..930cf9e 100644 --- a/pkg/provider/modelmgr/api.py +++ b/pkg/provider/modelmgr/api.py @@ -6,6 +6,8 @@ import typing from ...core import app from ...core import entities as core_entities from .. import entities as llm_entities +from . 
import entities as modelmgr_entities +from ..tools import entities as tools_entities preregistered_requesters: list[typing.Type[LLMAPIRequester]] = [] @@ -33,20 +35,31 @@ class LLMAPIRequester(metaclass=abc.ABCMeta): async def initialize(self): pass - @abc.abstractmethod - async def request( + async def preprocess( self, query: core_entities.Query, - ) -> typing.AsyncGenerator[llm_entities.Message, None]: - """请求API + ): + """预处理 + + 在这里处理特定API对Query对象的兼容性问题。 + """ + pass - 对话前文可以从 query 对象中获取。 - 可以多次yield消息对象。 + @abc.abstractmethod + async def call( + self, + model: modelmgr_entities.LLMModelInfo, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + ) -> llm_entities.Message: + """调用API Args: - query (core_entities.Query): 本次请求的上下文对象 + model (modelmgr_entities.LLMModelInfo): 使用的模型信息 + messages (typing.List[llm_entities.Message]): 消息对象列表 + funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None. - Yields: - pkg.provider.entities.Message: 返回消息对象 + Returns: + llm_entities.Message: 返回消息对象 """ - raise NotImplementedError + pass diff --git a/pkg/provider/modelmgr/apis/anthropicmsgs.py b/pkg/provider/modelmgr/apis/anthropicmsgs.py index 42bd385..923e1ce 100644 --- a/pkg/provider/modelmgr/apis/anthropicmsgs.py +++ b/pkg/provider/modelmgr/apis/anthropicmsgs.py @@ -27,20 +27,22 @@ class AnthropicMessages(api.LLMAPIRequester): proxies=self.ap.proxy_mgr.get_forward_proxies() ) - async def request( + async def call( self, - query: core_entities.Query, - ) -> typing.AsyncGenerator[llm_entities.Message, None]: - self.client.api_key = query.use_model.token_mgr.get_token() + model: entities.LLMModelInfo, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + ) -> llm_entities.Message: + self.client.api_key = model.token_mgr.get_token() args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy() - args["model"] = query.use_model.name if query.use_model.model_name is None else query.use_model.model_name + args["model"] = model.name if model.model_name is None else model.model_name - req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行 - m.dict(exclude_none=True) for m in query.prompt.messages if m.content.strip() != "" - ] + [m.dict(exclude_none=True) for m in query.messages] + req_messages = [ + m.dict(exclude_none=True) for m in messages if m.content.strip() != "" + ] - # 删除所有 role=system & content='' 的消息 + # 删除所有 role=system & content='' 的消息 req_messages = [ m for m in req_messages if not (m["role"] == "system" and m["content"].strip() == "") ] @@ -64,10 +66,9 @@ class AnthropicMessages(api.LLMAPIRequester): args["messages"] = req_messages try: - resp = await self.client.messages.create(**args) - yield llm_entities.Message( + return llm_entities.Message( content=resp.content[0].text, role=resp.role ) @@ -79,4 +80,4 @@ class AnthropicMessages(api.LLMAPIRequester): if 'model: ' in str(e): raise errors.RequesterError(f'模型无效: {e.message}') else: - raise errors.RequesterError(f'请求地址无效: {e.message}') \ No newline at end of file + raise errors.RequesterError(f'请求地址无效: {e.message}') diff --git a/pkg/provider/modelmgr/apis/chatcmpl.py b/pkg/provider/modelmgr/apis/chatcmpl.py index e3901de..7984dd8 100644 --- a/pkg/provider/modelmgr/apis/chatcmpl.py +++ b/pkg/provider/modelmgr/apis/chatcmpl.py @@ -84,73 +84,19 @@ class OpenAIChatCompletions(api.LLMAPIRequester): message = await self._make_msg(resp) return message - - async def _request( - self, 
query: core_entities.Query - ) -> typing.AsyncGenerator[llm_entities.Message, None]: - """请求""" - - pending_tool_calls = [] - + + async def call( + self, + model: entities.LLMModelInfo, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + ) -> llm_entities.Message: req_messages = [ # req_messages 仅用于类内,外部同步由 query.messages 进行 - m.dict(exclude_none=True) for m in query.prompt.messages if m.content.strip() != "" - ] + [m.dict(exclude_none=True) for m in query.messages] + m.dict(exclude_none=True) for m in messages + ] - # req_messages.append({"role": "user", "content": str(query.message_chain)}) - - # 首次请求 - msg = await self._closure(req_messages, query.use_model, query.use_funcs) - - yield msg - - pending_tool_calls = msg.tool_calls - - req_messages.append(msg.dict(exclude_none=True)) - - # 持续请求,只要还有待处理的工具调用就继续处理调用 - while pending_tool_calls: - for tool_call in pending_tool_calls: - try: - func = tool_call.function - - parameters = json.loads(func.arguments) - - func_ret = await self.ap.tool_mgr.execute_func_call( - query, func.name, parameters - ) - - msg = llm_entities.Message( - role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id - ) - - yield msg - - req_messages.append(msg.dict(exclude_none=True)) - except Exception as e: - # 出错,添加一个报错信息到 req_messages - err_msg = llm_entities.Message( - role="tool", content=f"err: {e}", tool_call_id=tool_call.id - ) - - yield err_msg - - req_messages.append( - err_msg.dict(exclude_none=True) - ) - - # 处理完所有调用,继续请求 - msg = await self._closure(req_messages, query.use_model, query.use_funcs) - - yield msg - - pending_tool_calls = msg.tool_calls - - req_messages.append(msg.dict(exclude_none=True)) - - async def request(self, query: core_entities.Query) -> AsyncGenerator[llm_entities.Message, None]: try: - async for msg in self._request(query): - yield msg + return await self._closure(req_messages, model, funcs) except asyncio.TimeoutError: raise errors.RequesterError('请求超时') except openai.BadRequestError as e: @@ -163,6 +109,6 @@ class OpenAIChatCompletions(api.LLMAPIRequester): except openai.NotFoundError as e: raise errors.RequesterError(f'请求路径错误: {e.message}') except openai.RateLimitError as e: - raise errors.RequesterError(f'请求过于频繁: {e.message}') + raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') except openai.APIError as e: raise errors.RequesterError(f'请求错误: {e.message}') From 269e561497dce1cb730de128b89beb65fb531bc4 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Tue, 14 May 2024 22:41:39 +0800 Subject: [PATCH 02/11] =?UTF-8?q?perf:=20messages=20=E5=AD=98=E5=9B=9E=20c?= =?UTF-8?q?onversation=20=E5=BA=94=E8=AF=A5=E4=BB=85=E5=9C=A8=E6=88=90?= =?UTF-8?q?=E5=8A=9F=E6=89=A7=E8=A1=8C=E6=9C=AC=E6=AC=A1=E8=AF=B7=E6=B1=82?= =?UTF-8?q?=E6=97=B6=E6=89=A7=E8=A1=8C=20(#769)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/pipeline/process/handlers/chat.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 2f0616f..26f73b6 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -89,6 +89,9 @@ class ChatMessageHandler(handler.MessageHandler): result_type=entities.ResultType.CONTINUE, new_query=query ) + + query.session.using_conversation.messages.append(query.user_message) + query.session.using_conversation.messages.extend(query.resp_messages) except 
Exception as e: self.ap.logger.error(f'对话({query.query_id})请求失败: {str(e)}') @@ -101,8 +104,6 @@ class ChatMessageHandler(handler.MessageHandler): debug_notice=traceback.format_exc() ) finally: - query.session.using_conversation.messages.append(query.user_message) - query.session.using_conversation.messages.extend(query.resp_messages) await self.ap.ctr_mgr.usage.post_query_record( session_type=query.session.launcher_type.value, From 8807f02f36561e55212f6fe4140737ff6fe4a8cd Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Tue, 14 May 2024 23:08:49 +0800 Subject: [PATCH 03/11] =?UTF-8?q?perf:=20resp=5Fmessage=5Fchain=20?= =?UTF-8?q?=E6=94=B9=E4=B8=BA=20list=20=E7=B1=BB=E5=9E=8B=20(#770)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/core/entities.py | 2 +- pkg/pipeline/longtext/longtext.py | 6 +++--- pkg/pipeline/pool.py | 2 +- pkg/pipeline/process/handlers/command.py | 6 ------ pkg/pipeline/respback/respback.py | 2 +- pkg/pipeline/wrapper/wrapper.py | 16 ++++++++-------- 6 files changed, 14 insertions(+), 20 deletions(-) diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 2e7d0b1..30b983a 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -70,7 +70,7 @@ class Query(pydantic.BaseModel): resp_messages: typing.Optional[list[llm_entities.Message]] = [] """由Process阶段生成的回复消息对象列表""" - resp_message_chain: typing.Optional[mirai.MessageChain] = None + resp_message_chain: typing.Optional[list[mirai.MessageChain]] = None """回复消息链,从resp_messages包装而得""" class Config: diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py index ec0e66e..756df44 100644 --- a/pkg/pipeline/longtext/longtext.py +++ b/pkg/pipeline/longtext/longtext.py @@ -62,15 +62,15 @@ class LongTextProcessStage(stage.PipelineStage): # 检查是否包含非 Plain 组件 contains_non_plain = False - for msg in query.resp_message_chain: + for msg in query.resp_message_chain[-1]: if not isinstance(msg, Plain): contains_non_plain = True break if contains_non_plain: self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。") - elif len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']: - query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query)) + elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']: + query.resp_message_chain[-1] = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query)) return entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, diff --git a/pkg/pipeline/pool.py b/pkg/pipeline/pool.py index bd48c48..ba7f999 100644 --- a/pkg/pipeline/pool.py +++ b/pkg/pipeline/pool.py @@ -43,7 +43,7 @@ class QueryPool: message_event=message_event, message_chain=message_chain, resp_messages=[], - resp_message_chain=None, + resp_message_chain=[], adapter=adapter ) self.queries.append(query) diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py index 7179fd3..75d9222 100644 --- a/pkg/pipeline/process/handlers/command.py +++ b/pkg/pipeline/process/handlers/command.py @@ -80,9 +80,6 @@ class CommandHandler(handler.MessageHandler): session=session ): if ret.error is not None: - # query.resp_message_chain = mirai.MessageChain([ - # mirai.Plain(str(ret.error)) - # ]) query.resp_messages.append( llm_entities.Message( role='command', @@ -97,9 +94,6 @@ class CommandHandler(handler.MessageHandler): new_query=query ) elif 
ret.text is not None: - # query.resp_message_chain = mirai.MessageChain([ - # mirai.Plain(ret.text) - # ]) query.resp_messages.append( llm_entities.Message( role='command', diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py index 36a7329..d3af14e 100644 --- a/pkg/pipeline/respback/respback.py +++ b/pkg/pipeline/respback/respback.py @@ -31,7 +31,7 @@ class SendResponseBackStage(stage.PipelineStage): await self.ap.platform_mgr.send( query.message_event, - query.resp_message_chain, + query.resp_message_chain[-1], adapter=query.adapter ) diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py index 78705e6..345addb 100644 --- a/pkg/pipeline/wrapper/wrapper.py +++ b/pkg/pipeline/wrapper/wrapper.py @@ -34,7 +34,7 @@ class ResponseWrapper(stage.PipelineStage): """ if query.resp_messages[-1].role == 'command': - query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content) + query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content)) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, @@ -42,9 +42,9 @@ class ResponseWrapper(stage.PipelineStage): ) elif query.resp_messages[-1].role == 'plugin': if not isinstance(query.resp_messages[-1].content, mirai.MessageChain): - query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content) + query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content)) else: - query.resp_message_chain = query.resp_messages[-1].content + query.resp_message_chain.append(query.resp_messages[-1].content) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, @@ -83,11 +83,11 @@ class ResponseWrapper(stage.PipelineStage): else: if event_ctx.event.reply is not None: - query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) + query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply)) else: - query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) + query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)])) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, @@ -100,7 +100,7 @@ class ResponseWrapper(stage.PipelineStage): reply_text = f'调用函数 {".".join(function_names)}...' 
- query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) + query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)])) if self.ap.platform_cfg.data['track-function-calls']: @@ -126,11 +126,11 @@ class ResponseWrapper(stage.PipelineStage): else: if event_ctx.event.reply is not None: - query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply) + query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply)) else: - query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)]) + query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)])) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, From d5b5d667a578c23ac744a153ac6493cbb74026b4 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Wed, 15 May 2024 21:40:18 +0800 Subject: [PATCH 04/11] =?UTF-8?q?feat:=20=E6=A8=A1=E5=9E=8B=E8=A7=86?= =?UTF-8?q?=E8=A7=89=E5=A4=9A=E6=A8=A1=E6=80=81=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/command/entities.py | 9 +- pkg/config/impls/json.py | 10 ++- pkg/config/impls/pymodule.py | 21 ++--- pkg/config/manager.py | 12 +-- .../migrations/m006_vision_and_oss_config.py | 35 ++++++++ pkg/config/model.py | 2 +- pkg/core/app.py | 3 + pkg/core/bootutils/deps.py | 1 + pkg/core/stages/build_app.py | 7 +- pkg/core/stages/load_config.py | 10 +-- pkg/core/stages/migrate.py | 2 +- pkg/oss/__init__.py | 0 pkg/oss/oss.py | 85 +++++++++++++++++++ pkg/oss/service.py | 67 +++++++++++++++ pkg/oss/services/__init__.py | 0 pkg/oss/services/aliyun.py | 48 +++++++++++ pkg/pipeline/cntfilter/cntfilter.py | 16 ++++ pkg/pipeline/cntfilter/entities.py | 11 ++- pkg/pipeline/cntfilter/filter.py | 14 ++- pkg/pipeline/preproc/preproc.py | 26 +++++- pkg/pipeline/process/handlers/command.py | 19 ++++- pkg/pipeline/wrapper/wrapper.py | 16 ++-- pkg/provider/entities.py | 58 ++++++++++++- pkg/provider/modelmgr/apis/anthropicmsgs.py | 52 +++++++----- pkg/provider/modelmgr/apis/chatcmpl.py | 34 +++++++- .../modelmgr/apis/deepseekchatcmpl.py | 42 ++++++++- .../modelmgr/apis/moonshotchatcmpl.py | 43 +++++++++- pkg/provider/modelmgr/modelmgr.py | 4 + requirements.txt | 4 +- templates/metadata/llm-models.json | 4 + templates/provider.json | 1 + templates/system.json | 12 +++ 32 files changed, 596 insertions(+), 72 deletions(-) create mode 100644 pkg/config/migrations/m006_vision_and_oss_config.py create mode 100644 pkg/oss/__init__.py create mode 100644 pkg/oss/oss.py create mode 100644 pkg/oss/service.py create mode 100644 pkg/oss/services/__init__.py create mode 100644 pkg/oss/services/aliyun.py diff --git a/pkg/command/entities.py b/pkg/command/entities.py index 27cb596..8697551 100644 --- a/pkg/command/entities.py +++ b/pkg/command/entities.py @@ -13,11 +13,16 @@ class CommandReturn(pydantic.BaseModel): """命令返回值 """ - text: typing.Optional[str] + text: typing.Optional[str] = None """文本 """ - image: typing.Optional[mirai.Image] + image: typing.Optional[mirai.Image] = None + """弃用""" + + image_url: typing.Optional[str] = None + """图片链接 + """ error: typing.Optional[errors.CommandError]= None """错误 diff --git a/pkg/config/impls/json.py b/pkg/config/impls/json.py index 754bfa5..362bc78 100644 --- a/pkg/config/impls/json.py +++ b/pkg/config/impls/json.py @@ -27,7 +27,7 @@ class JSONConfigFile(file_model.ConfigFile): else: raise ValueError("template_file_name or template_data must be provided") - async def load(self) -> dict: + async def load(self, 
completion: bool=True) -> dict: if not self.exists(): await self.create() @@ -39,9 +39,11 @@ class JSONConfigFile(file_model.ConfigFile): with open(self.config_file_name, "r", encoding="utf-8") as f: cfg = json.load(f) - for key in self.template_data: - if key not in cfg: - cfg[key] = self.template_data[key] + if completion: + + for key in self.template_data: + if key not in cfg: + cfg[key] = self.template_data[key] return cfg diff --git a/pkg/config/impls/pymodule.py b/pkg/config/impls/pymodule.py index ceeebad..67e5867 100644 --- a/pkg/config/impls/pymodule.py +++ b/pkg/config/impls/pymodule.py @@ -25,7 +25,7 @@ class PythonModuleConfigFile(file_model.ConfigFile): async def create(self): shutil.copyfile(self.template_file_name, self.config_file_name) - async def load(self) -> dict: + async def load(self, completion: bool=True) -> dict: module_name = os.path.splitext(os.path.basename(self.config_file_name))[0] module = importlib.import_module(module_name) @@ -43,18 +43,19 @@ class PythonModuleConfigFile(file_model.ConfigFile): cfg[key] = getattr(module, key) # 从模板模块文件中进行补全 - module_name = os.path.splitext(os.path.basename(self.template_file_name))[0] - module = importlib.import_module(module_name) + if completion: + module_name = os.path.splitext(os.path.basename(self.template_file_name))[0] + module = importlib.import_module(module_name) - for key in dir(module): - if key.startswith('__'): - continue + for key in dir(module): + if key.startswith('__'): + continue - if not isinstance(getattr(module, key), allowed_types): - continue + if not isinstance(getattr(module, key), allowed_types): + continue - if key not in cfg: - cfg[key] = getattr(module, key) + if key not in cfg: + cfg[key] = getattr(module, key) return cfg diff --git a/pkg/config/manager.py b/pkg/config/manager.py index f9e93c8..7983407 100644 --- a/pkg/config/manager.py +++ b/pkg/config/manager.py @@ -20,8 +20,8 @@ class ConfigManager: self.file = cfg_file self.data = {} - async def load_config(self): - self.data = await self.file.load() + async def load_config(self, completion: bool=True): + self.data = await self.file.load(completion=completion) async def dump_config(self): await self.file.save(self.data) @@ -30,7 +30,7 @@ class ConfigManager: self.file.save_sync(self.data) -async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager: +async def load_python_module_config(config_name: str, template_name: str, completion: bool=True) -> ConfigManager: """加载Python模块配置文件""" cfg_inst = pymodule.PythonModuleConfigFile( config_name, @@ -38,12 +38,12 @@ async def load_python_module_config(config_name: str, template_name: str) -> Con ) cfg_mgr = ConfigManager(cfg_inst) - await cfg_mgr.load_config() + await cfg_mgr.load_config(completion=completion) return cfg_mgr -async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None) -> ConfigManager: +async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager: """加载JSON配置文件""" cfg_inst = json_file.JSONConfigFile( config_name, @@ -52,6 +52,6 @@ async def load_json_config(config_name: str, template_name: str=None, template_d ) cfg_mgr = ConfigManager(cfg_inst) - await cfg_mgr.load_config() + await cfg_mgr.load_config(completion=completion) return cfg_mgr \ No newline at end of file diff --git a/pkg/config/migrations/m006_vision_and_oss_config.py b/pkg/config/migrations/m006_vision_and_oss_config.py new file mode 100644 index 0000000..2d05b97 --- 
/dev/null +++ b/pkg/config/migrations/m006_vision_and_oss_config.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from .. import migration + + +@migration.migration_class("vision-and-oss-config", 6) +class VisionAndOSSConfigMigration(migration.Migration): + """迁移""" + + async def need_migrate(self) -> bool: + """判断当前环境是否需要运行此迁移""" + return "enable-vision" not in self.ap.provider_cfg.data \ + or "oss" not in self.ap.system_cfg.data + + async def run(self): + """执行迁移""" + if "enable-vision" not in self.ap.provider_cfg.data: + self.ap.provider_cfg.data["enable-vision"] = False + + if "oss" not in self.ap.system_cfg.data: + self.ap.system_cfg.data["oss"] = [ + { + "type": "aliyun", + "endpoint": "https://oss-cn-hangzhou.aliyuncs.com", + "public-read-base-url": "https://qchatgpt.oss-cn-hangzhou.aliyuncs.com", + "access-key-id": "LTAI5tJ5Q5J8J6J5J5J5J5J5", + "access-key-secret": "xxxxxx", + "bucket": "qchatgpt", + "prefix": "qchatgpt", + "enable": False, + } + ] + + await self.ap.provider_cfg.dump_config() + await self.ap.system_cfg.dump_config() diff --git a/pkg/config/model.py b/pkg/config/model.py index d209093..153123e 100644 --- a/pkg/config/model.py +++ b/pkg/config/model.py @@ -22,7 +22,7 @@ class ConfigFile(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def load(self) -> dict: + async def load(self, completion: bool=True) -> dict: pass @abc.abstractmethod diff --git a/pkg/core/app.py b/pkg/core/app.py index 1ed5304..1705e29 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -15,6 +15,7 @@ from ..command import cmdmgr from ..plugin import manager as plugin_mgr from ..pipeline import pool from ..pipeline import controller, stagemgr +from ..oss import oss from ..utils import version as version_mgr, proxy as proxy_mgr @@ -71,6 +72,8 @@ class Application: proxy_mgr: proxy_mgr.ProxyManager = None + oss_mgr: oss.OSSServiceManager = None + logger: logging.Logger = None def __init__(self): diff --git a/pkg/core/bootutils/deps.py b/pkg/core/bootutils/deps.py index 4adf132..ab40704 100644 --- a/pkg/core/bootutils/deps.py +++ b/pkg/core/bootutils/deps.py @@ -14,6 +14,7 @@ required_deps = { "yaml": "pyyaml", "aiohttp": "aiohttp", "psutil": "psutil", + "oss2": "oss2", } diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index d536582..ff8c1e7 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -14,6 +14,7 @@ from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.sysprompt import sysprompt as llm_prompt_mgr from ...provider.tools import toolmgr as llm_tool_mgr from ...platform import manager as im_mgr +from ...oss import oss as oss_mgr @stage.stage_class("BuildAppStage") @@ -68,6 +69,10 @@ class BuildAppStage(stage.BootingStage): await cmd_mgr_inst.initialize() ap.cmd_mgr = cmd_mgr_inst + oss_mgr_inst = oss_mgr.OSSServiceManager(ap) + await oss_mgr_inst.initialize() + ap.oss_mgr = oss_mgr_inst + llm_model_mgr_inst = llm_model_mgr.ModelManager(ap) await llm_model_mgr_inst.initialize() ap.model_mgr = llm_model_mgr_inst @@ -83,7 +88,6 @@ class BuildAppStage(stage.BootingStage): llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap) await llm_tool_mgr_inst.initialize() ap.tool_mgr = llm_tool_mgr_inst - im_mgr_inst = im_mgr.PlatformManager(ap=ap) await im_mgr_inst.initialize() ap.platform_mgr = im_mgr_inst @@ -92,5 +96,6 @@ class BuildAppStage(stage.BootingStage): await stage_mgr.initialize() ap.stage_mgr = stage_mgr + ctrl = controller.Controller(ap) ap.ctrl = ctrl diff --git a/pkg/core/stages/load_config.py 
b/pkg/core/stages/load_config.py index 9e61c1c..cb6e1ed 100644 --- a/pkg/core/stages/load_config.py +++ b/pkg/core/stages/load_config.py @@ -12,11 +12,11 @@ class LoadConfigStage(stage.BootingStage): async def run(self, ap: app.Application): """启动 """ - ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json") - ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json") - ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json") - ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json") - ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json") + ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json", completion=False) + ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json", completion=False) + ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json", completion=False) + ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json", completion=False) + ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json", completion=False) ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json") await ap.plugin_setting_meta.dump_config() diff --git a/pkg/core/stages/migrate.py b/pkg/core/stages/migrate.py index cef3b42..44102fe 100644 --- a/pkg/core/stages/migrate.py +++ b/pkg/core/stages/migrate.py @@ -5,7 +5,7 @@ import importlib from .. import stage, app from ...config import migration from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion -from ...config.migrations import m005_deepseek_cfg_completion +from ...config.migrations import m005_deepseek_cfg_completion, m006_vision_and_oss_config @stage.stage_class("MigrationStage") diff --git a/pkg/oss/__init__.py b/pkg/oss/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkg/oss/oss.py b/pkg/oss/oss.py new file mode 100644 index 0000000..5474ed3 --- /dev/null +++ b/pkg/oss/oss.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +import aiohttp +import typing +from urllib.parse import urlparse, parse_qs +import ssl + +from . 
import service as osssv +from ..core import app +from .services import aliyun + + +class OSSServiceManager: + + ap: app.Application + + service: osssv.OSSService = None + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + """初始化 + """ + + mapping = {} + + for svcls in osssv.preregistered_services: + mapping[svcls.name] = svcls + + for sv in self.ap.system_cfg.data['oss']: + if sv['enable']: + + if sv['type'] not in mapping: + raise Exception(f"未知的OSS服务类型: {sv['type']}") + + self.service = mapping[sv['type']](self.ap, sv) + await self.service.initialize() + break + + def available(self) -> bool: + """是否可用 + + Returns: + bool: 是否可用 + """ + return self.service is not None + + async def fetch_image(self, image_url: str) -> bytes: + parsed = urlparse(image_url) + query = parse_qs(parsed.query) + + # Flatten the query dictionary + query = {k: v[0] for k, v in query.items()} + + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + + async with aiohttp.ClientSession(trust_env=False) as session: + async with session.get( + f"http://{parsed.netloc}{parsed.path}", + params=query, + ssl=ssl_context + ) as resp: + resp.raise_for_status() # 检查HTTP错误 + file_bytes = await resp.read() + return file_bytes + + async def upload_url_image( + self, + image_url: str, + ) -> str: + """上传URL图片 + + Args: + image_url (str): 图片URL + + Returns: + str: 文件URL + """ + + file_bytes = await self.fetch_image(image_url) + + return await self.service.upload(file_bytes=file_bytes, ext=".jpg") \ No newline at end of file diff --git a/pkg/oss/service.py b/pkg/oss/service.py new file mode 100644 index 0000000..a822844 --- /dev/null +++ b/pkg/oss/service.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import typing +import abc + +from ..core import app + + +preregistered_services: list[typing.Type[OSSService]] = [] + +def service_class( + name: str +) -> typing.Callable[[typing.Type[OSSService]], typing.Type[OSSService]]: + """OSS服务类装饰器 + + Args: + name (str): 服务名称 + + Returns: + typing.Callable[[typing.Type[OSSService]], typing.Type[OSSService]]: 装饰器 + """ + def decorator(cls: typing.Type[OSSService]) -> typing.Type[OSSService]: + assert issubclass(cls, OSSService) + + cls.name = name + + preregistered_services.append(cls) + + return cls + + return decorator + + +class OSSService(metaclass=abc.ABCMeta): + """OSS抽象类""" + + name: str + + ap: app.Application + + cfg: dict + + def __init__(self, ap: app.Application, cfg: dict) -> None: + self.ap = ap + self.cfg = cfg + + async def initialize(self): + pass + + @abc.abstractmethod + async def upload( + self, + local_file: str=None, + file_bytes: bytes=None, + ext: str=None, + ) -> str: + """上传文件 + + Args: + local_file (str, optional): 本地文件路径. Defaults to None. + file_bytes (bytes, optional): 文件字节. Defaults to None. + ext (str, optional): 文件扩展名. Defaults to None. + + Returns: + str: 文件URL + """ + pass diff --git a/pkg/oss/services/__init__.py b/pkg/oss/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkg/oss/services/aliyun.py b/pkg/oss/services/aliyun.py new file mode 100644 index 0000000..d30ac89 --- /dev/null +++ b/pkg/oss/services/aliyun.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import uuid + +import oss2 + +from .. 
import service as osssv + + +@osssv.service_class('aliyun') +class AliyunOSSService(osssv.OSSService): + """阿里云OSS服务""" + + auth: oss2.Auth + + bucket: oss2.Bucket + + async def initialize(self): + self.auth = oss2.Auth( + self.cfg['access-key-id'], + self.cfg['access-key-secret'] + ) + + self.bucket = oss2.Bucket( + self.auth, + self.cfg['endpoint'], + self.cfg['bucket'] + ) + + async def upload( + self, + local_file: str=None, + file_bytes: bytes=None, + ext: str=None, + ) -> str: + if local_file is not None: + with open(local_file, 'rb') as f: + file_bytes = f.read() + + if file_bytes is None: + raise Exception("缺少文件内容") + + name = str(uuid.uuid1()) + + key = f"{self.cfg['prefix']}/{name}{ext}" + self.bucket.put_object(key, file_bytes) + + return f"{self.cfg['public-read-base-url']}/{key}" diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index 2c6a5ab..a669e31 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -9,6 +9,7 @@ from ...core import entities as core_entities from ...config import manager as cfg_mgr from . import filter as filter_model, entities as filter_entities from .filters import cntignore, banwords, baiduexamine +from ...provider import entities as llm_entities @stage.stage_class('PostContentFilterStage') @@ -141,6 +142,21 @@ class ContentFilterStage(stage.PipelineStage): """处理 """ if stage_inst_name == 'PreContentFilterStage': + + contain_non_text = False + + for me in query.message_chain: + if not isinstance(me, mirai.Plain): + contain_non_text = True + break + + if contain_non_text: + self.ap.logger.debug(f"消息中包含非文本消息,跳过内容过滤器检查。") + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + return await self._pre_process( str(query.message_chain).strip(), query diff --git a/pkg/pipeline/cntfilter/entities.py b/pkg/pipeline/cntfilter/entities.py index 8ff581f..83744c1 100644 --- a/pkg/pipeline/cntfilter/entities.py +++ b/pkg/pipeline/cntfilter/entities.py @@ -4,6 +4,8 @@ import enum import pydantic +from ...provider import entities as llm_entities + class ResultLevel(enum.Enum): """结果等级""" @@ -29,6 +31,13 @@ class EnableStage(enum.Enum): """后处理""" +class AcceptContent(enum.Enum): + """过滤器接受的内容模态""" + + TEXT = enum.auto() + + IMAGE_URL = enum.auto() + class FilterResult(pydantic.BaseModel): level: ResultLevel """结果等级 @@ -38,7 +47,7 @@ class FilterResult(pydantic.BaseModel): """ replacement: str - """替换后的消息 + """替换后的文本消息 内容过滤器可以进行一些遮掩处理,然后把遮掩后的消息返回。 若没有修改内容,也需要返回原消息。 diff --git a/pkg/pipeline/cntfilter/filter.py b/pkg/pipeline/cntfilter/filter.py index 8b34e0c..5fd55e4 100644 --- a/pkg/pipeline/cntfilter/filter.py +++ b/pkg/pipeline/cntfilter/filter.py @@ -5,6 +5,7 @@ import typing from ...core import app from . 
import entities
+from ...provider import entities as llm_entities
 
 
 preregistered_filters: list[typing.Type[ContentFilter]] = []
@@ -56,6 +57,16 @@ class ContentFilter(metaclass=abc.ABCMeta):
             entities.EnableStage.PRE,
             entities.EnableStage.POST
         ]
+
+    @property
+    def accept_content(self):
+        """本过滤器接受的模态
+
+        默认仅接受纯文本
+        """
+        return [
+            entities.AcceptContent.TEXT
+        ]
 
     async def initialize(self):
         """初始化过滤器
@@ -63,7 +74,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
         pass
 
     @abc.abstractmethod
-    async def process(self, message: str) -> entities.FilterResult:
+    async def process(self, message: str=None, image_url=None) -> entities.FilterResult:
         """处理消息
 
        分为前后阶段,具体取决于 enable_stages 的值。
@@ -71,6 +82,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
 
        Args:
            message (str): 需要检查的内容
+           image_url (str): 要检查的图片的 URL
 
        Returns:
            entities.FilterResult: 过滤结果,具体内容请查看 entities.FilterResult 类的文档
diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py
index 164f78c..0470a60 100644
--- a/pkg/pipeline/preproc/preproc.py
+++ b/pkg/pipeline/preproc/preproc.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import mirai
+
 from .. import stage, entities, stagemgr
 from ...core import entities as core_entities
 from ...provider import entities as llm_entities
@@ -37,9 +39,31 @@ class PreProcessor(stage.PipelineStage):
         query.prompt = conversation.prompt.copy()
         query.messages = conversation.messages.copy()
 
+        # 检查vision是否启用,没启用就删除所有图片
+        if not self.ap.provider_cfg.data['enable-vision']:
+            for msg in query.messages:
+                if isinstance(msg.content, list):
+                    for me in msg.content.copy():  # 迭代副本,避免边遍历边删除导致漏删
+                        if me.type == 'image_url':
+                            msg.content.remove(me)
+
+        content_list = []
+
+        for me in query.message_chain:
+            if isinstance(me, mirai.Plain):
+                content_list.append(
+                    llm_entities.ContentElement.from_text(me.text)
+                )
+            elif isinstance(me, mirai.Image):
+                if self.ap.provider_cfg.data['enable-vision']:
+                    if me.url is not None:
+                        content_list.append(
+                            llm_entities.ContentElement.from_image_url(str(me.url))
+                        )
+
         query.user_message = llm_entities.Message(  # TODO 适配多模态输入
             role='user',
-            content=str(query.message_chain).strip()
+            content=content_list
         )
 
         query.use_model = conversation.use_model
diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py
index 75d9222..02ff269 100644
--- a/pkg/pipeline/process/handlers/command.py
+++ b/pkg/pipeline/process/handlers/command.py
@@ -93,15 +93,28 @@ class CommandHandler(handler.MessageHandler):
                             result_type=entities.ResultType.CONTINUE,
                             new_query=query
                         )
-                elif ret.text is not None:
+                elif ret.text is not None or ret.image_url is not None:
+
+                    content: list[llm_entities.ContentElement] = []
+
+                    if ret.text is not None:
+                        content.append(
+                            llm_entities.ContentElement.from_text(ret.text)
+                        )
+
+                    if ret.image_url is not None:
+                        content.append(
+                            llm_entities.ContentElement.from_image_url(ret.image_url)
+                        )
+
                     query.resp_messages.append(
                         llm_entities.Message(
                             role='command',
-                            content=ret.text,
+                            content=content,
                         )
                     )
 
-                    self.ap.logger.info(f'命令返回: {self.cut_str(ret.text)}')
+                    self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}')
 
                     yield entities.StageProcessResult(
                         result_type=entities.ResultType.CONTINUE,
diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py
index 345addb..acf0549 100644
--- a/pkg/pipeline/wrapper/wrapper.py
+++ b/pkg/pipeline/wrapper/wrapper.py
@@ -34,17 +34,19 @@ class ResponseWrapper(stage.PipelineStage):
         """
 
         if query.resp_messages[-1].role == 'command':
-            query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content))
+            # query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content))
+            query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] '))
 
             yield entities.StageProcessResult(
                 result_type=entities.ResultType.CONTINUE,
                 new_query=query
             )
         elif query.resp_messages[-1].role == 'plugin':
-            if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
-                query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
-            else:
-                query.resp_message_chain.append(query.resp_messages[-1].content)
+            # if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
+            #     query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
+            # else:
+            #     query.resp_message_chain.append(query.resp_messages[-1].content)
+            query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain())
 
             yield entities.StageProcessResult(
                 result_type=entities.ResultType.CONTINUE,
@@ -59,7 +61,7 @@ class ResponseWrapper(stage.PipelineStage):
                 reply_text = ''
 
                 if result.content is not None:  # 有内容
-                    reply_text = result.content
+                    reply_text = str(result.get_content_mirai_message_chain())
 
                     # ============= 触发插件事件 ===============
                     event_ctx = await self.ap.plugin_mgr.emit_event(
@@ -87,7 +89,7 @@ class ResponseWrapper(stage.PipelineStage):
 
                     else:
 
-                        query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
+                        query.resp_message_chain.append(result.get_content_mirai_message_chain())
 
             yield entities.StageProcessResult(
                 result_type=entities.ResultType.CONTINUE,
diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py
index 3281a93..b892b89 100644
--- a/pkg/provider/entities.py
+++ b/pkg/provider/entities.py
@@ -21,14 +21,34 @@ class ToolCall(pydantic.BaseModel):
     function: FunctionCall
 
 
-class Content(pydantic.BaseModel):
+class ImageURLContentObject(pydantic.BaseModel):
+
+    url: str
+
+
+class ContentElement(pydantic.BaseModel):
 
     type: str
     """内容类型"""
 
     text: typing.Optional[str] = None
 
-    image_url: typing.Optional[str] = None
+    image_url: typing.Optional[ImageURLContentObject] = None
+
+    def __str__(self):
+        if self.type == 'text':
+            return self.text
+        elif self.type == 'image_url':
+            return f'[图片]({self.image_url.url})'
+        else:
+            return '未知内容'
+
+    @classmethod
+    def from_text(cls, text: str):
+        return cls(type='text', text=text)
+
+    @classmethod
+    def from_image_url(cls, image_url: str):
+        return cls(type='image_url', image_url=ImageURLContentObject(url=image_url))
 
 
 class Message(pydantic.BaseModel):
@@ -40,7 +60,7 @@ class Message(pydantic.BaseModel):
     name: typing.Optional[str] = None
     """名称,仅函数调用返回时设置"""
 
-    content: typing.Optional[str] | typing.Optional[mirai.MessageChain] = None
+    content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None
     """内容"""
 
     tool_calls: typing.Optional[list[ToolCall]] = None
@@ -50,8 +70,38 @@ class Message(pydantic.BaseModel):
 
     def readable_str(self) -> str:
         if self.content is not None:
-            return str(self.role) + ": " + str(self.content)
+            return str(self.role) + ": " + str(self.get_content_mirai_message_chain())
         elif self.tool_calls is not None:
             return f'调用工具: {self.tool_calls[0].id}'
         else:
             return '未知消息'
+
+    def get_content_mirai_message_chain(self, prefix_text: str="") -> mirai.MessageChain | None:
+        """将内容转换为 Mirai MessageChain 对象
+
+        Args:
+            prefix_text (str): 首个文字组件的前缀文本
+        """
+
+        if self.content is None:
+            return None
+        elif isinstance(self.content, str):
+            return mirai.MessageChain([mirai.Plain(prefix_text+self.content)])
+        elif isinstance(self.content, list):
+            mc = []
+            for ce in self.content:
+                if ce.type == 'text':
+                    mc.append(mirai.Plain(ce.text))
+                elif ce.type == 'image_url':
+                    mc.append(mirai.Image(url=ce.image_url.url))
+
+            # 找第一个文字组件
+            if prefix_text:
+                for i, c in enumerate(mc):
+                    if isinstance(c, mirai.Plain):
+                        mc[i] = mirai.Plain(prefix_text+c.text)
+                        break
+                else:
+                    mc.insert(0, mirai.Plain(prefix_text))
+
+            return mirai.MessageChain(mc)
diff --git a/pkg/provider/modelmgr/apis/anthropicmsgs.py b/pkg/provider/modelmgr/apis/anthropicmsgs.py
index 923e1ce..ee2c51a 100644
--- a/pkg/provider/modelmgr/apis/anthropicmsgs.py
+++ b/pkg/provider/modelmgr/apis/anthropicmsgs.py
@@ -38,30 +38,42 @@ class AnthropicMessages(api.LLMAPIRequester):
         args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
         args["model"] = model.name if model.model_name is None else model.model_name
 
-        req_messages = [
-            m.dict(exclude_none=True) for m in messages if m.content.strip() != ""
-        ]
+        # 处理消息
 
-        # 删除所有 role=system & content='' 的消息
-        req_messages = [
-            m for m in req_messages if not (m["role"] == "system" and m["content"].strip() == "")
-        ]
+        # system
+        system_role_message = None
 
-        # 检查是否有 role=system 的消息,若有,改为 role=user,并在后面加一个 role=assistant 的消息
-        system_role_index = []
-        for i, m in enumerate(req_messages):
-            if m["role"] == "system":
-                system_role_index.append(i)
-                m["role"] = "user"
+        for i, m in enumerate(messages):
+            if m.role == "system":
+                system_role_message = m
 
-        if system_role_index:
-            for i in system_role_index[::-1]:
-                req_messages.insert(i + 1, {"role": "assistant", "content": "Okay, I'll follow."})
+                messages.pop(i)
+                break
 
-        # 忽略掉空消息,用户可能发送空消息,而上层未过滤
-        req_messages = [
-            m for m in req_messages if m["content"].strip() != ""
-        ]
+        if isinstance(system_role_message, llm_entities.Message) \
+            and isinstance(system_role_message.content, str):
+            args['system'] = system_role_message.content
+
+        # 其他消息
+        # 暂时不支持vision,仅保留纯文字的content
+        req_messages = []
+
+        for m in messages:
+            if isinstance(m.content, str) and m.content.strip() != "":
+                req_messages.append(m.dict(exclude_none=True))
+            elif isinstance(m.content, list):
+                # 删除m.content中的type!=text的元素(ContentElement 为 pydantic 对象,按属性访问)
+                m.content = [
+                    c for c in m.content if c.type == "text"
+                ]
+
+                if len(m.content) > 0:
+                    req_messages.append(m.dict(exclude_none=True))
 
         args["messages"] = req_messages
 
diff --git a/pkg/provider/modelmgr/apis/chatcmpl.py b/pkg/provider/modelmgr/apis/chatcmpl.py
index 7984dd8..fb6e057 100644
--- a/pkg/provider/modelmgr/apis/chatcmpl.py
+++ b/pkg/provider/modelmgr/apis/chatcmpl.py
@@ -23,9 +23,18 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
 
     requester_cfg: dict
 
+    cached_image_oss_url: dict[str, str] = {}
+    """缓存的OSS服务的图片URL
+
+    key: 前文message中的原图片URL(QQ图片)
+    value: OSS服务的图片URL
+    """
+
     def __init__(self, ap: app.Application):
         self.ap = ap
 
+        self.cached_image_oss_url = {}
+
         self.requester_cfg = self.ap.provider_cfg.data['requester']['openai-chat-completions']
 
     async def initialize(self):
@@ -74,7 +83,16 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
             args["tools"] = tools
 
         # 设置此次请求中的messages
-        messages = req_messages
+        messages = req_messages.copy()
+
+        # 检查vision
+        if self.ap.oss_mgr.available():
+            for msg in messages:
+                if isinstance(msg["content"], list):
+                    for me in msg["content"]:
+                        if 
diff --git a/pkg/provider/modelmgr/apis/chatcmpl.py b/pkg/provider/modelmgr/apis/chatcmpl.py
index 7984dd8..fb6e057 100644
--- a/pkg/provider/modelmgr/apis/chatcmpl.py
+++ b/pkg/provider/modelmgr/apis/chatcmpl.py
@@ -23,9 +23,18 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
 requester_cfg: dict
+ cached_image_oss_url: dict[str, str] = {}
+ """缓存的OSS服务的图片URL
+
+ key: 前文message中的原图片URL(QQ图片)
+ value: OSS服务的图片URL
+ """
+
 def __init__(self, ap: app.Application):
 self.ap = ap
+ self.cached_image_oss_url = {}
+
 self.requester_cfg = self.ap.provider_cfg.data['requester']['openai-chat-completions']
 async def initialize(self):
@@ -74,7 +83,16 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
 args["tools"] = tools
 # 设置此次请求中的messages
- messages = req_messages
+ messages = req_messages.copy()
+
+ # 检查vision
+ if self.ap.oss_mgr.available():
+ for msg in messages:
+ if isinstance(msg["content"], list):
+ for me in msg["content"]:
+ if me["type"] == "image_url":
+ me["image_url"]['url'] = await self.get_oss_url(me["image_url"]['url'])
+
 args["messages"] = messages
 # 发送请求
@@ -112,3 +130,17 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
 raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
 except openai.APIError as e:
 raise errors.RequesterError(f'请求错误: {e.message}')
+
+ async def get_oss_url(
+ self,
+ original_url: str,
+ ) -> str:
+
+ if original_url in self.cached_image_oss_url:
+ return self.cached_image_oss_url[original_url]
+
+ oss_url = await self.ap.oss_mgr.upload_url_image(original_url)
+
+ self.cached_image_oss_url[original_url] = oss_url
+
+ return oss_url
diff --git a/pkg/provider/modelmgr/apis/deepseekchatcmpl.py b/pkg/provider/modelmgr/apis/deepseekchatcmpl.py
index dd8ddc6..9cb667b 100644
--- a/pkg/provider/modelmgr/apis/deepseekchatcmpl.py
+++ b/pkg/provider/modelmgr/apis/deepseekchatcmpl.py
@@ -3,7 +3,10 @@ from __future__ import annotations
 from ....core import app
 from . import chatcmpl
-from .. import api
+from .. import api, entities, errors
+from ....core import entities as core_entities
+from ... import entities as llm_entities
+from ...tools import entities as tools_entities
@@ -12,4 +15,39 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
 def __init__(self, ap: app.Application):
 self.requester_cfg = ap.provider_cfg.data['requester']['deepseek-chat-completions']
- self.ap = ap
\ No newline at end of file
+ self.ap = ap
+
+ async def _closure(
+ self,
+ req_messages: list[dict],
+ use_model: entities.LLMModelInfo,
+ use_funcs: list[tools_entities.LLMFunction] = None,
+ ) -> llm_entities.Message:
+ self.client.api_key = use_model.token_mgr.get_token()
+
+ args = self.requester_cfg['args'].copy()
+ args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
+
+ if use_model.tool_call_supported:
+ tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
+
+ if tools:
+ args["tools"] = tools
+
+ # 设置此次请求中的messages
+ messages = req_messages
+
+ # deepseek 不支持多模态,把content都转换成纯文字
+ for m in messages:
+ if isinstance(m["content"], list):
+ m["content"] = " ".join([c["text"] for c in m["content"] if c["type"] == "text"])
+
+ args["messages"] = messages
+
+ # 发送请求
+ resp = await self._req(args)
+
+ # 处理请求结果
+ message = await self._make_msg(resp)
+
+ return message
\ No newline at end of file
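The Moonshot requester that follows applies the same text-only constraint. The flattening both closures perform can be sketched in isolation (dict-based, illustrative only, not the patch's own code):

```python
# Sketch only: mirrors the content-flattening loop used by the DeepSeek and
# Moonshot closures for endpoints without multimodal support.
def flatten_multimodal(messages: list[dict]) -> list[dict]:
    for m in messages:
        if isinstance(m.get("content"), list):
            # keep only text elements; image_url elements carry no usable text
            m["content"] = " ".join(
                c["text"] for c in m["content"] if c.get("type") == "text"
            )
    # drop messages that became empty after flattening
    return [m for m in messages if str(m.get("content", "")).strip() != ""]

msgs = flatten_multimodal([
    {"role": "user", "content": [
        {"type": "text", "text": "describe this"},
        {"type": "image_url", "image_url": {"url": "https://example.com/a.jpg"}},
    ]},
])
assert msgs == [{"role": "user", "content": "describe this"}]
```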
diff --git a/pkg/provider/modelmgr/apis/moonshotchatcmpl.py b/pkg/provider/modelmgr/apis/moonshotchatcmpl.py
index cb9fd93..f50ca62 100644
--- a/pkg/provider/modelmgr/apis/moonshotchatcmpl.py
+++ b/pkg/provider/modelmgr/apis/moonshotchatcmpl.py
@@ -3,7 +3,10 @@ from __future__ import annotations
 from ....core import app
 from . import chatcmpl
-from .. import api
+from .. import api, entities, errors
+from ....core import entities as core_entities
+from ... import entities as llm_entities
+from ...tools import entities as tools_entities
@@ -13,3 +16,41 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
 def __init__(self, ap: app.Application):
 self.requester_cfg = ap.provider_cfg.data['requester']['moonshot-chat-completions']
 self.ap = ap
+
+ async def _closure(
+ self,
+ req_messages: list[dict],
+ use_model: entities.LLMModelInfo,
+ use_funcs: list[tools_entities.LLMFunction] = None,
+ ) -> llm_entities.Message:
+ self.client.api_key = use_model.token_mgr.get_token()
+
+ args = self.requester_cfg['args'].copy()
+ args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
+
+ if use_model.tool_call_supported:
+ tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
+
+ if tools:
+ args["tools"] = tools
+
+ # 设置此次请求中的messages
+ messages = req_messages
+
+ # moonshot 不支持多模态,把content都转换成纯文字
+ for m in messages:
+ if isinstance(m["content"], list):
+ m["content"] = " ".join([c["text"] for c in m["content"] if c["type"] == "text"])
+
+ # 删除空的
+ messages = [m for m in messages if m["content"].strip() != ""]
+
+ args["messages"] = messages
+
+ # 发送请求
+ resp = await self._req(args)
+
+ # 处理请求结果
+ message = await self._make_msg(resp)
+
+ return message
\ No newline at end of file
diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py
index 3fffd78..93cf54d 100644
--- a/pkg/provider/modelmgr/modelmgr.py
+++ b/pkg/provider/modelmgr/modelmgr.py
@@ -37,6 +37,10 @@ class ModelManager:
 raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置")
 async def initialize(self):
+
+ # 检查是否启用了vision但是没有配置oss
+ if self.ap.provider_cfg.data['enable-vision'] and not self.ap.oss_mgr.available():
+ self.ap.logger.warning("启用了视觉但是没有配置可用的oss服务,基于 URL 传递图片的视觉 API 将无法正常使用")
 # 初始化token_mgr, requester
 for k, v in self.ap.provider_cfg.data['keys'].items():
 self.token_mgrs[k] = token.TokenManager(k, v)
diff --git a/requirements.txt b/requirements.txt
index f04bdc9..63996e0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,4 +13,6 @@ aiohttp
 pydantic
 websockets
 urllib3
-psutil
\ No newline at end of file
+psutil
+
+oss2
\ No newline at end of file
diff --git a/templates/metadata/llm-models.json b/templates/metadata/llm-models.json
index 9787223..13cf93c 100644
--- a/templates/metadata/llm-models.json
+++ b/templates/metadata/llm-models.json
@@ -22,6 +22,10 @@
 "name": "gpt-4-32k",
 "tool_call_supported": true
 },
+ {
+ "name": "gpt-4o",
+ "tool_call_supported": true
+ },
 {
 "model_name": "SparkDesk",
 "name": "OneAPI/SparkDesk"
diff --git a/templates/provider.json b/templates/provider.json
index e537156..dadec8e 100644
--- a/templates/provider.json
+++ b/templates/provider.json
@@ -1,5 +1,6 @@
 {
 "enable-chat": true,
+ "enable-vision": false,
 "keys": {
 "openai": [
 "sk-1234567890"
diff --git a/templates/system.json b/templates/system.json
index 72d29b9..906640c 100644
--- a/templates/system.json
+++ b/templates/system.json
@@ -1,5 +1,17 @@
 {
 "admin-sessions": [],
+ "oss": [
+ {
+ "type": "aliyun",
+ "endpoint": "https://oss-cn-hangzhou.aliyuncs.com",
+ "public-read-base-url": "https://qchatgpt.oss-cn-hangzhou.aliyuncs.com",
+ "access-key-id": "LTAI5tJ5Q5J8J6J5J5J5J5J5",
+ "access-key-secret": "xxxxxx",
+ "bucket": "qchatgpt",
+ "prefix": "qchatgpt",
+ "enable": false
+ }
+ ],
 "network-proxies": {
 "http": null,
 "https": null
From 404e5492a3956a21a5fe8110e6b1a592e3916054 Mon Sep 17 00:00:00 2001
From: RockChinQ <1010553892@qq.com>
Date: Thu, 16 May 2024 18:29:23 +0800
Subject: [PATCH 05/11] 
=?UTF-8?q?chore:=20=E5=90=8C=E6=AD=A5=E7=8E=B0?= =?UTF-8?q?=E6=9C=89=E6=A8=A1=E5=9E=8B=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- templates/metadata/llm-models.json | 36 ++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/templates/metadata/llm-models.json b/templates/metadata/llm-models.json index 13cf93c..63ae031 100644 --- a/templates/metadata/llm-models.json +++ b/templates/metadata/llm-models.json @@ -6,12 +6,24 @@ "token_mgr": "openai", "tool_call_supported": false }, + { + "name": "gpt-3.5-turbo-0125", + "tool_call_supported": true + }, { "name": "gpt-3.5-turbo", "tool_call_supported": true }, { - "name": "gpt-4", + "name": "gpt-3.5-turbo-1106", + "tool_call_supported": true + }, + { + "name": "gpt-4-turbo", + "tool_call_supported": true + }, + { + "name": "gpt-4-turbo-2024-04-09", "tool_call_supported": true }, { @@ -19,13 +31,33 @@ "tool_call_supported": true }, { - "name": "gpt-4-32k", + "name": "gpt-4-0125-preview", + "tool_call_supported": true + }, + { + "name": "gpt-4-1106-preview", + "tool_call_supported": true + }, + { + "name": "gpt-4", "tool_call_supported": true }, { "name": "gpt-4o", "tool_call_supported": true }, + { + "name": "gpt-4-0613", + "tool_call_supported": true + }, + { + "name": "gpt-4-32k", + "tool_call_supported": true + }, + { + "name": "gpt-4-32k-0613", + "tool_call_supported": true + }, { "model_name": "SparkDesk", "name": "OneAPI/SparkDesk" From 2c478ccc2519d84c18e46b23b80b98ec604a0fac Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Thu, 16 May 2024 20:11:54 +0800 Subject: [PATCH 06/11] =?UTF-8?q?feat:=20=E6=A8=A1=E5=9E=8Bvision=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E6=80=A7=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/pipeline/preproc/preproc.py | 14 +++---- pkg/provider/modelmgr/apis/chatcmpl.py | 4 +- .../modelmgr/apis/deepseekchatcmpl.py | 4 +- .../modelmgr/apis/moonshotchatcmpl.py | 4 +- pkg/provider/modelmgr/entities.py | 2 + pkg/provider/modelmgr/modelmgr.py | 7 +++- templates/metadata/llm-models.json | 42 ++++++++++++------- 7 files changed, 48 insertions(+), 29 deletions(-) diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index 0470a60..ebe4d31 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -39,8 +39,13 @@ class PreProcessor(stage.PipelineStage): query.prompt = conversation.prompt.copy() query.messages = conversation.messages.copy() + query.use_model = conversation.use_model + + query.use_funcs = conversation.use_funcs if query.use_model.tool_call_supported else None + + # 检查vision是否启用,没启用就删除所有图片 - if not self.ap.provider_cfg.data['enable-vision']: + if not self.ap.provider_cfg.data['enable-vision'] or not query.use_model.vision_supported: for msg in query.messages: if isinstance(msg.content, list): for me in msg.content: @@ -55,7 +60,7 @@ class PreProcessor(stage.PipelineStage): llm_entities.ContentElement.from_text(me.text) ) elif isinstance(me, mirai.Image): - if self.ap.provider_cfg.data['enable-vision']: + if self.ap.provider_cfg.data['enable-vision'] and query.use_model.vision_supported: if me.url is not None: content_list.append( llm_entities.ContentElement.from_image_url(str(me.url)) @@ -65,11 +70,6 @@ class PreProcessor(stage.PipelineStage): role='user', content=content_list ) - - query.use_model = conversation.use_model - - query.use_funcs = conversation.use_funcs - # 
=========== 触发事件 PromptPreProcessing
 event_ctx = await self.ap.plugin_mgr.emit_event(
diff --git a/pkg/provider/modelmgr/apis/chatcmpl.py b/pkg/provider/modelmgr/apis/chatcmpl.py
index fb6e057..c2bdf8c 100644
--- a/pkg/provider/modelmgr/apis/chatcmpl.py
+++ b/pkg/provider/modelmgr/apis/chatcmpl.py
@@ -76,7 +76,7 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
 args = self.requester_cfg['args'].copy()
 args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
- if use_model.tool_call_supported:
+ if use_funcs:
 tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
 if tools:
@@ -88,7 +88,7 @@ class OpenAIChatCompletions(api.LLMAPIRequester):
 # 检查vision
 if self.ap.oss_mgr.available():
 for msg in messages:
- if isinstance(msg["content"], list):
+ if 'content' in msg and isinstance(msg["content"], list):
 for me in msg["content"]:
 if me["type"] == "image_url":
 me["image_url"]['url'] = await self.get_oss_url(me["image_url"]['url'])
diff --git a/pkg/provider/modelmgr/apis/deepseekchatcmpl.py b/pkg/provider/modelmgr/apis/deepseekchatcmpl.py
index 9cb667b..4edc2cb 100644
--- a/pkg/provider/modelmgr/apis/deepseekchatcmpl.py
+++ b/pkg/provider/modelmgr/apis/deepseekchatcmpl.py
@@ -28,7 +28,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
 args = self.requester_cfg['args'].copy()
 args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
- if use_model.tool_call_supported:
+ if use_funcs:
 tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
 if tools:
@@ -39,7 +39,7 @@
 # deepseek 不支持多模态,把content都转换成纯文字
 for m in messages:
- if isinstance(m["content"], list):
+ if 'content' in m and isinstance(m["content"], list):
 m["content"] = " ".join([c["text"] for c in m["content"] if c["type"] == "text"])
 args["messages"] = messages
diff --git a/pkg/provider/modelmgr/apis/moonshotchatcmpl.py b/pkg/provider/modelmgr/apis/moonshotchatcmpl.py
index f50ca62..2f299b8 100644
--- a/pkg/provider/modelmgr/apis/moonshotchatcmpl.py
+++ b/pkg/provider/modelmgr/apis/moonshotchatcmpl.py
@@ -28,7 +28,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
 args = self.requester_cfg['args'].copy()
 args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
- if use_model.tool_call_supported:
+ if use_funcs:
 tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
 if tools:
@@ -39,7 +39,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
 # moonshot 不支持多模态,把content都转换成纯文字
 for m in messages:
- if isinstance(m["content"], list):
+ if 'content' in m and isinstance(m["content"], list):
 m["content"] = " ".join([c["text"] for c in m["content"] if c["type"] == "text"])
 # 删除空的
diff --git a/pkg/provider/modelmgr/entities.py b/pkg/provider/modelmgr/entities.py
index 277f125..79cb544 100644
--- a/pkg/provider/modelmgr/entities.py
+++ b/pkg/provider/modelmgr/entities.py
@@ -21,5 +21,7 @@ class LLMModelInfo(pydantic.BaseModel):
 tool_call_supported: typing.Optional[bool] = False
+ vision_supported: typing.Optional[bool] = False
+
 class Config:
 arbitrary_types_allowed = True
diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py
index 93cf54d..85bf8e9 100644
--- a/pkg/provider/modelmgr/modelmgr.py
+++ b/pkg/provider/modelmgr/modelmgr.py
@@ -87,7 +87,8 @@ class ModelManager:
 model_name=None,
 token_mgr=self.token_mgrs[model['token_mgr']],
 requester=self.requesters[model['requester']],
- 
tool_call_supported=model['tool_call_supported'] + tool_call_supported=model['tool_call_supported'], + vision_supported=model['vision_supported'] ) break @@ -99,13 +100,15 @@ class ModelManager: token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr requester = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported) + vision_supported = model.get('vision_supported', default_model_info.vision_supported) model_info = entities.LLMModelInfo( name=model['name'], model_name=model_name, token_mgr=token_mgr, requester=requester, - tool_call_supported=tool_call_supported + tool_call_supported=tool_call_supported, + vision_supported=vision_supported ) self.model_list.append(model_info) diff --git a/templates/metadata/llm-models.json b/templates/metadata/llm-models.json index 63ae031..235eea7 100644 --- a/templates/metadata/llm-models.json +++ b/templates/metadata/llm-models.json @@ -4,59 +4,73 @@ "name": "default", "requester": "openai-chat-completions", "token_mgr": "openai", - "tool_call_supported": false + "tool_call_supported": false, + "vision_supported": false }, { "name": "gpt-3.5-turbo-0125", - "tool_call_supported": true + "tool_call_supported": true, + "vision_supported": false }, { "name": "gpt-3.5-turbo", - "tool_call_supported": true + "tool_call_supported": true, + "vision_supported": false }, { "name": "gpt-3.5-turbo-1106", - "tool_call_supported": true + "tool_call_supported": true, + "vision_supported": false }, { "name": "gpt-4-turbo", - "tool_call_supported": true + "tool_call_supported": true, + "vision_supported": true }, { "name": "gpt-4-turbo-2024-04-09", - "tool_call_supported": true + "tool_call_supported": true, + "vision_supported": true }, { "name": "gpt-4-turbo-preview", - "tool_call_supported": true + "tool_call_supported": true, + "vision_supported": true }, { "name": "gpt-4-0125-preview", - "tool_call_supported": true + "tool_call_supported": true, + "vision_supported": true }, { "name": "gpt-4-1106-preview", - "tool_call_supported": true + "tool_call_supported": true, + "vision_supported": true }, { "name": "gpt-4", - "tool_call_supported": true + "tool_call_supported": true, + "vision_supported": true }, { "name": "gpt-4o", - "tool_call_supported": true + "tool_call_supported": true, + "vision_supported": true }, { "name": "gpt-4-0613", - "tool_call_supported": true + "tool_call_supported": true, + "vision_supported": true }, { "name": "gpt-4-32k", - "tool_call_supported": true + "tool_call_supported": true, + "vision_supported": true }, { "name": "gpt-4-32k-0613", - "tool_call_supported": true + "tool_call_supported": true, + "vision_supported": true }, { "model_name": "SparkDesk", From 6bc6f77af1977ca35819b9a6ffea636203df2fbc Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Thu, 16 May 2024 20:25:51 +0800 Subject: [PATCH 07/11] =?UTF-8?q?feat:=20=E9=80=9A=E8=BF=87=20base64=20?= =?UTF-8?q?=E4=BC=A0=E8=BE=93=E5=9B=BE=E7=89=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/provider/modelmgr/apis/chatcmpl.py | 15 +++++++++- pkg/utils/constants.py | 2 +- pkg/utils/image.py | 41 ++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 2 deletions(-) create mode 100644 pkg/utils/image.py diff --git a/pkg/provider/modelmgr/apis/chatcmpl.py b/pkg/provider/modelmgr/apis/chatcmpl.py index c2bdf8c..c2242a9 100644 --- 
a/pkg/provider/modelmgr/apis/chatcmpl.py +++ b/pkg/provider/modelmgr/apis/chatcmpl.py @@ -3,16 +3,19 @@ from __future__ import annotations import asyncio import typing import json +import base64 from typing import AsyncGenerator import openai import openai.types.chat.chat_completion as chat_completion import httpx +import aiohttp from .. import api, entities, errors from ....core import entities as core_entities, app from ... import entities as llm_entities from ...tools import entities as tools_entities +from ....utils import image @api.requester_class("openai-chat-completions") @@ -91,7 +94,8 @@ class OpenAIChatCompletions(api.LLMAPIRequester): if 'content' in msg and isinstance(msg["content"], list): for me in msg["content"]: if me["type"] == "image_url": - me["image_url"]['url'] = await self.get_oss_url(me["image_url"]['url']) + # me["image_url"]['url'] = await self.get_oss_url(me["image_url"]['url']) + me["image_url"]['url'] = await self.get_base64_str(me["image_url"]['url']) args["messages"] = messages @@ -144,3 +148,12 @@ class OpenAIChatCompletions(api.LLMAPIRequester): self.cached_image_oss_url[original_url] = oss_url return oss_url + + async def get_base64_str( + self, + original_url: str, + ) -> str: + + base64_image = await image.qq_image_url_to_base64(original_url) + + return f"data:image/jpeg;base64,{base64_image}" diff --git a/pkg/utils/constants.py b/pkg/utils/constants.py index 0addd9f..81ca04a 100644 --- a/pkg/utils/constants.py +++ b/pkg/utils/constants.py @@ -1 +1 @@ -semantic_version = "v3.1.1" +semantic_version = "v3.2.0" diff --git a/pkg/utils/image.py b/pkg/utils/image.py new file mode 100644 index 0000000..34acc2f --- /dev/null +++ b/pkg/utils/image.py @@ -0,0 +1,41 @@ +import base64 +import typing +from urllib.parse import urlparse, parse_qs +import ssl + +import aiohttp + + +async def qq_image_url_to_base64( + image_url: str +) -> str: + """将QQ图片URL转为base64 + + Args: + image_url (str): QQ图片URL + + Returns: + str: base64编码 + """ + parsed = urlparse(image_url) + query = parse_qs(parsed.query) + + # Flatten the query dictionary + query = {k: v[0] for k, v in query.items()} + + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + + async with aiohttp.ClientSession(trust_env=False) as session: + async with session.get( + f"http://{parsed.netloc}{parsed.path}", + params=query, + ssl=ssl_context + ) as resp: + resp.raise_for_status() # 检查HTTP错误 + file_bytes = await resp.read() + + base64_str = base64.b64encode(file_bytes).decode() + + return base64_str From 37ef1c9fab1a390e2f89f2a2b340cff789367dfc Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Thu, 16 May 2024 20:32:30 +0800 Subject: [PATCH 08/11] =?UTF-8?q?feat:=20=E5=88=A0=E9=99=A4oss=E7=9B=B8?= =?UTF-8?q?=E5=85=B3=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../migrations/m006_vision_and_oss_config.py | 35 -------- pkg/config/migrations/m006_vision_config.py | 19 +++++ pkg/core/app.py | 3 - pkg/core/bootutils/deps.py | 1 - pkg/core/stages/build_app.py | 6 -- pkg/core/stages/migrate.py | 4 +- pkg/oss/__init__.py | 0 pkg/oss/oss.py | 85 ------------------- pkg/oss/service.py | 67 --------------- pkg/oss/services/__init__.py | 0 pkg/oss/services/aliyun.py | 48 ----------- pkg/provider/modelmgr/apis/chatcmpl.py | 35 ++------ pkg/provider/modelmgr/modelmgr.py | 4 - requirements.txt | 2 - templates/system.json | 12 --- 15 files changed, 26 insertions(+), 295 
deletions(-) delete mode 100644 pkg/config/migrations/m006_vision_and_oss_config.py create mode 100644 pkg/config/migrations/m006_vision_config.py delete mode 100644 pkg/oss/__init__.py delete mode 100644 pkg/oss/oss.py delete mode 100644 pkg/oss/service.py delete mode 100644 pkg/oss/services/__init__.py delete mode 100644 pkg/oss/services/aliyun.py diff --git a/pkg/config/migrations/m006_vision_and_oss_config.py b/pkg/config/migrations/m006_vision_and_oss_config.py deleted file mode 100644 index 2d05b97..0000000 --- a/pkg/config/migrations/m006_vision_and_oss_config.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import annotations - -from .. import migration - - -@migration.migration_class("vision-and-oss-config", 6) -class VisionAndOSSConfigMigration(migration.Migration): - """迁移""" - - async def need_migrate(self) -> bool: - """判断当前环境是否需要运行此迁移""" - return "enable-vision" not in self.ap.provider_cfg.data \ - or "oss" not in self.ap.system_cfg.data - - async def run(self): - """执行迁移""" - if "enable-vision" not in self.ap.provider_cfg.data: - self.ap.provider_cfg.data["enable-vision"] = False - - if "oss" not in self.ap.system_cfg.data: - self.ap.system_cfg.data["oss"] = [ - { - "type": "aliyun", - "endpoint": "https://oss-cn-hangzhou.aliyuncs.com", - "public-read-base-url": "https://qchatgpt.oss-cn-hangzhou.aliyuncs.com", - "access-key-id": "LTAI5tJ5Q5J8J6J5J5J5J5J5", - "access-key-secret": "xxxxxx", - "bucket": "qchatgpt", - "prefix": "qchatgpt", - "enable": False, - } - ] - - await self.ap.provider_cfg.dump_config() - await self.ap.system_cfg.dump_config() diff --git a/pkg/config/migrations/m006_vision_config.py b/pkg/config/migrations/m006_vision_config.py new file mode 100644 index 0000000..8084611 --- /dev/null +++ b/pkg/config/migrations/m006_vision_config.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from .. 
import migration + + +@migration.migration_class("vision-config", 6) +class VisionConfigMigration(migration.Migration): + """迁移""" + + async def need_migrate(self) -> bool: + """判断当前环境是否需要运行此迁移""" + return "enable-vision" not in self.ap.provider_cfg.data + + async def run(self): + """执行迁移""" + if "enable-vision" not in self.ap.provider_cfg.data: + self.ap.provider_cfg.data["enable-vision"] = False + + await self.ap.provider_cfg.dump_config() diff --git a/pkg/core/app.py b/pkg/core/app.py index 1705e29..1ed5304 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -15,7 +15,6 @@ from ..command import cmdmgr from ..plugin import manager as plugin_mgr from ..pipeline import pool from ..pipeline import controller, stagemgr -from ..oss import oss from ..utils import version as version_mgr, proxy as proxy_mgr @@ -72,8 +71,6 @@ class Application: proxy_mgr: proxy_mgr.ProxyManager = None - oss_mgr: oss.OSSServiceManager = None - logger: logging.Logger = None def __init__(self): diff --git a/pkg/core/bootutils/deps.py b/pkg/core/bootutils/deps.py index ab40704..4adf132 100644 --- a/pkg/core/bootutils/deps.py +++ b/pkg/core/bootutils/deps.py @@ -14,7 +14,6 @@ required_deps = { "yaml": "pyyaml", "aiohttp": "aiohttp", "psutil": "psutil", - "oss2": "oss2", } diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index ff8c1e7..39ecc02 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -14,8 +14,6 @@ from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.sysprompt import sysprompt as llm_prompt_mgr from ...provider.tools import toolmgr as llm_tool_mgr from ...platform import manager as im_mgr -from ...oss import oss as oss_mgr - @stage.stage_class("BuildAppStage") class BuildAppStage(stage.BootingStage): @@ -69,10 +67,6 @@ class BuildAppStage(stage.BootingStage): await cmd_mgr_inst.initialize() ap.cmd_mgr = cmd_mgr_inst - oss_mgr_inst = oss_mgr.OSSServiceManager(ap) - await oss_mgr_inst.initialize() - ap.oss_mgr = oss_mgr_inst - llm_model_mgr_inst = llm_model_mgr.ModelManager(ap) await llm_model_mgr_inst.initialize() ap.model_mgr = llm_model_mgr_inst diff --git a/pkg/core/stages/migrate.py b/pkg/core/stages/migrate.py index 44102fe..4d5b8d8 100644 --- a/pkg/core/stages/migrate.py +++ b/pkg/core/stages/migrate.py @@ -4,8 +4,8 @@ import importlib from .. import stage, app from ...config import migration -from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion -from ...config.migrations import m005_deepseek_cfg_completion, m006_vision_and_oss_config +from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion, m006_vision_config +from ...config.migrations import m005_deepseek_cfg_completion @stage.stage_class("MigrationStage") diff --git a/pkg/oss/__init__.py b/pkg/oss/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/pkg/oss/oss.py b/pkg/oss/oss.py deleted file mode 100644 index 5474ed3..0000000 --- a/pkg/oss/oss.py +++ /dev/null @@ -1,85 +0,0 @@ -from __future__ import annotations - -import aiohttp -import typing -from urllib.parse import urlparse, parse_qs -import ssl - -from . 
import service as osssv -from ..core import app -from .services import aliyun - - -class OSSServiceManager: - - ap: app.Application - - service: osssv.OSSService = None - - def __init__(self, ap: app.Application): - self.ap = ap - - async def initialize(self): - """初始化 - """ - - mapping = {} - - for svcls in osssv.preregistered_services: - mapping[svcls.name] = svcls - - for sv in self.ap.system_cfg.data['oss']: - if sv['enable']: - - if sv['type'] not in mapping: - raise Exception(f"未知的OSS服务类型: {sv['type']}") - - self.service = mapping[sv['type']](self.ap, sv) - await self.service.initialize() - break - - def available(self) -> bool: - """是否可用 - - Returns: - bool: 是否可用 - """ - return self.service is not None - - async def fetch_image(self, image_url: str) -> bytes: - parsed = urlparse(image_url) - query = parse_qs(parsed.query) - - # Flatten the query dictionary - query = {k: v[0] for k, v in query.items()} - - ssl_context = ssl.create_default_context() - ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE - - async with aiohttp.ClientSession(trust_env=False) as session: - async with session.get( - f"http://{parsed.netloc}{parsed.path}", - params=query, - ssl=ssl_context - ) as resp: - resp.raise_for_status() # 检查HTTP错误 - file_bytes = await resp.read() - return file_bytes - - async def upload_url_image( - self, - image_url: str, - ) -> str: - """上传URL图片 - - Args: - image_url (str): 图片URL - - Returns: - str: 文件URL - """ - - file_bytes = await self.fetch_image(image_url) - - return await self.service.upload(file_bytes=file_bytes, ext=".jpg") \ No newline at end of file diff --git a/pkg/oss/service.py b/pkg/oss/service.py deleted file mode 100644 index a822844..0000000 --- a/pkg/oss/service.py +++ /dev/null @@ -1,67 +0,0 @@ -from __future__ import annotations - -import typing -import abc - -from ..core import app - - -preregistered_services: list[typing.Type[OSSService]] = [] - -def service_class( - name: str -) -> typing.Callable[[typing.Type[OSSService]], typing.Type[OSSService]]: - """OSS服务类装饰器 - - Args: - name (str): 服务名称 - - Returns: - typing.Callable[[typing.Type[OSSService]], typing.Type[OSSService]]: 装饰器 - """ - def decorator(cls: typing.Type[OSSService]) -> typing.Type[OSSService]: - assert issubclass(cls, OSSService) - - cls.name = name - - preregistered_services.append(cls) - - return cls - - return decorator - - -class OSSService(metaclass=abc.ABCMeta): - """OSS抽象类""" - - name: str - - ap: app.Application - - cfg: dict - - def __init__(self, ap: app.Application, cfg: dict) -> None: - self.ap = ap - self.cfg = cfg - - async def initialize(self): - pass - - @abc.abstractmethod - async def upload( - self, - local_file: str=None, - file_bytes: bytes=None, - ext: str=None, - ) -> str: - """上传文件 - - Args: - local_file (str, optional): 本地文件路径. Defaults to None. - file_bytes (bytes, optional): 文件字节. Defaults to None. - ext (str, optional): 文件扩展名. Defaults to None. - - Returns: - str: 文件URL - """ - pass diff --git a/pkg/oss/services/__init__.py b/pkg/oss/services/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/pkg/oss/services/aliyun.py b/pkg/oss/services/aliyun.py deleted file mode 100644 index d30ac89..0000000 --- a/pkg/oss/services/aliyun.py +++ /dev/null @@ -1,48 +0,0 @@ -from __future__ import annotations - -import uuid - -import oss2 - -from .. 
import service as osssv - - -@osssv.service_class('aliyun') -class AliyunOSSService(osssv.OSSService): - """阿里云OSS服务""" - - auth: oss2.Auth - - bucket: oss2.Bucket - - async def initialize(self): - self.auth = oss2.Auth( - self.cfg['access-key-id'], - self.cfg['access-key-secret'] - ) - - self.bucket = oss2.Bucket( - self.auth, - self.cfg['endpoint'], - self.cfg['bucket'] - ) - - async def upload( - self, - local_file: str=None, - file_bytes: bytes=None, - ext: str=None, - ) -> str: - if local_file is not None: - with open(local_file, 'rb') as f: - file_bytes = f.read() - - if file_bytes is None: - raise Exception("缺少文件内容") - - name = str(uuid.uuid1()) - - key = f"{self.cfg['prefix']}/{name}{ext}" - self.bucket.put_object(key, file_bytes) - - return f"{self.cfg['public-read-base-url']}/{key}" diff --git a/pkg/provider/modelmgr/apis/chatcmpl.py b/pkg/provider/modelmgr/apis/chatcmpl.py index c2242a9..709a65a 100644 --- a/pkg/provider/modelmgr/apis/chatcmpl.py +++ b/pkg/provider/modelmgr/apis/chatcmpl.py @@ -26,18 +26,9 @@ class OpenAIChatCompletions(api.LLMAPIRequester): requester_cfg: dict - cached_image_oss_url: dict[str, str] = {} - """缓存的OSS服务的图片URL - - key: 前文message中的原图片URL(QQ图片) - value: OSS服务的图片URL - """ - def __init__(self, ap: app.Application): self.ap = ap - self.cached_image_oss_url = {} - self.requester_cfg = self.ap.provider_cfg.data['requester']['openai-chat-completions'] async def initialize(self): @@ -89,13 +80,11 @@ class OpenAIChatCompletions(api.LLMAPIRequester): messages = req_messages.copy() # 检查vision - if self.ap.oss_mgr.available(): - for msg in messages: - if 'content' in msg and isinstance(msg["content"], list): - for me in msg["content"]: - if me["type"] == "image_url": - # me["image_url"]['url'] = await self.get_oss_url(me["image_url"]['url']) - me["image_url"]['url'] = await self.get_base64_str(me["image_url"]['url']) + for msg in messages: + if 'content' in msg and isinstance(msg["content"], list): + for me in msg["content"]: + if me["type"] == "image_url": + me["image_url"]['url'] = await self.get_base64_str(me["image_url"]['url']) args["messages"] = messages @@ -135,20 +124,6 @@ class OpenAIChatCompletions(api.LLMAPIRequester): except openai.APIError as e: raise errors.RequesterError(f'请求错误: {e.message}') - async def get_oss_url( - self, - original_url: str, - ) -> str: - - if original_url in self.cached_image_oss_url: - return self.cached_image_oss_url[original_url] - - oss_url = await self.ap.oss_mgr.upload_url_image(original_url) - - self.cached_image_oss_url[original_url] = oss_url - - return oss_url - async def get_base64_str( self, original_url: str, diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index 85bf8e9..79e467a 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -38,10 +38,6 @@ class ModelManager: async def initialize(self): - # 检查是否启用了vision但是没有配置oss - if self.ap.provider_cfg.data['enable-vision'] and not self.ap.oss_mgr.available(): - self.ap.logger.warn("启用了视觉但是没有配置可用的oss服务,基于 URL 传递图片的视觉 API 将无法正常使用") - # 初始化token_mgr, requester for k, v in self.ap.provider_cfg.data['keys'].items(): self.token_mgrs[k] = token.TokenManager(k, v) diff --git a/requirements.txt b/requirements.txt index 63996e0..e43f9f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,5 +14,3 @@ pydantic websockets urllib3 psutil - -oss2 \ No newline at end of file diff --git a/templates/system.json b/templates/system.json index 906640c..72d29b9 100644 --- a/templates/system.json +++ 
b/templates/system.json @@ -1,17 +1,5 @@ { "admin-sessions": [], - "oss": [ - { - "type": "aliyun", - "endpoint": "https://oss-cn-hangzhou.aliyuncs.com", - "public-read-base-url": "https://qchatgpt.oss-cn-hangzhou.aliyuncs.com", - "access-key-id": "LTAI5tJ5Q5J8J6J5J5J5J5J5", - "access-key-secret": "xxxxxx", - "bucket": "qchatgpt", - "prefix": "qchatgpt", - "enable": false - } - ], "network-proxies": { "http": null, "https": null From 91e23b8c11a13ca89d7cb2ae01cbd4e0ca0919f2 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Thu, 16 May 2024 20:52:17 +0800 Subject: [PATCH 09/11] =?UTF-8?q?perf:=20=E4=B8=BA=E5=9B=BE=E7=89=87base64?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E6=B7=BB=E5=8A=A0lru?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/core/bootutils/deps.py | 1 + pkg/provider/entities.py | 3 +++ pkg/provider/modelmgr/apis/chatcmpl.py | 3 ++- requirements.txt | 1 + 4 files changed, 7 insertions(+), 1 deletion(-) diff --git a/pkg/core/bootutils/deps.py b/pkg/core/bootutils/deps.py index 4adf132..da0ae07 100644 --- a/pkg/core/bootutils/deps.py +++ b/pkg/core/bootutils/deps.py @@ -14,6 +14,7 @@ required_deps = { "yaml": "pyyaml", "aiohttp": "aiohttp", "psutil": "psutil", + "async_lru": "async-lru", } diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index b892b89..3b87f5c 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -24,6 +24,9 @@ class ToolCall(pydantic.BaseModel): class ImageURLContentObject(pydantic.BaseModel): url: str + def __str__(self): + return self.url[:128] + ('...' if len(self.url) > 128 else '') + class ContentElement(pydantic.BaseModel): diff --git a/pkg/provider/modelmgr/apis/chatcmpl.py b/pkg/provider/modelmgr/apis/chatcmpl.py index 709a65a..028b208 100644 --- a/pkg/provider/modelmgr/apis/chatcmpl.py +++ b/pkg/provider/modelmgr/apis/chatcmpl.py @@ -10,6 +10,7 @@ import openai import openai.types.chat.chat_completion as chat_completion import httpx import aiohttp +import async_lru from .. 
import api, entities, errors from ....core import entities as core_entities, app @@ -46,7 +47,6 @@ class OpenAIChatCompletions(api.LLMAPIRequester): self, args: dict, ) -> chat_completion.ChatCompletion: - self.ap.logger.debug(f"req chat_completion with args {args}") return await self.client.chat.completions.create(**args) async def _make_msg( @@ -124,6 +124,7 @@ class OpenAIChatCompletions(api.LLMAPIRequester): except openai.APIError as e: raise errors.RequesterError(f'请求错误: {e.message}') + @async_lru.alru_cache(maxsize=128) async def get_base64_str( self, original_url: str, diff --git a/requirements.txt b/requirements.txt index e43f9f0..44bc285 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ pydantic websockets urllib3 psutil +async-lru \ No newline at end of file From a3706bfe210f03048d1ecdbcd6b7a898f48e87a8 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Thu, 16 May 2024 21:02:59 +0800 Subject: [PATCH 10/11] =?UTF-8?q?perf:=20=E7=BB=86=E8=8A=82=E4=BC=98?= =?UTF-8?q?=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/command/operators/list.py | 4 ++-- pkg/pipeline/cntfilter/entities.py | 7 ------- pkg/pipeline/cntfilter/filter.py | 10 ---------- templates/provider.json | 2 +- 4 files changed, 3 insertions(+), 20 deletions(-) diff --git a/pkg/command/operators/list.py b/pkg/command/operators/list.py index 258e0ee..ff90d4d 100644 --- a/pkg/command/operators/list.py +++ b/pkg/command/operators/list.py @@ -42,7 +42,7 @@ class ListOperator(operator.CommandOperator): using_conv_index = index if index >= page * record_per_page and index < (page + 1) * record_per_page: - content += f"{index} {time_str}: {conv.messages[0].content if len(conv.messages) > 0 else '无内容'}\n" + content += f"{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else '无内容'}\n" index += 1 if content == '': @@ -51,6 +51,6 @@ class ListOperator(operator.CommandOperator): if context.session.using_conversation is None: content += "\n当前处于新会话" else: - content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].content if len(context.session.using_conversation.messages) > 0 else '无内容'}" + content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else '无内容'}" yield entities.CommandReturn(text=f"第 {page + 1} 页 (时间倒序):\n{content}") diff --git a/pkg/pipeline/cntfilter/entities.py b/pkg/pipeline/cntfilter/entities.py index 83744c1..af60a59 100644 --- a/pkg/pipeline/cntfilter/entities.py +++ b/pkg/pipeline/cntfilter/entities.py @@ -31,13 +31,6 @@ class EnableStage(enum.Enum): """后处理""" -class AcceptContent(enum.Enum): - """过滤器接受的内容模态""" - - TEXT = enum.auto() - - IMAGE_URL = enum.auto() - class FilterResult(pydantic.BaseModel): level: ResultLevel """结果等级 diff --git a/pkg/pipeline/cntfilter/filter.py b/pkg/pipeline/cntfilter/filter.py index 5fd55e4..8eceb87 100644 --- a/pkg/pipeline/cntfilter/filter.py +++ b/pkg/pipeline/cntfilter/filter.py @@ -57,16 +57,6 @@ class ContentFilter(metaclass=abc.ABCMeta): entities.EnableStage.PRE, entities.EnableStage.POST ] - - @property - def accept_content(self): - """本过滤器接受的模态 - - 默认仅接受纯文本 - """ - return [ - entities.AcceptContent.TEXT - ] async def initialize(self): """初始化过滤器 diff --git 
a/templates/provider.json b/templates/provider.json index dadec8e..309fb82 100644 --- a/templates/provider.json +++ b/templates/provider.json @@ -1,6 +1,6 @@ { "enable-chat": true, - "enable-vision": false, + "enable-vision": true, "keys": { "openai": [ "sk-1234567890" From bae86ac05cfcdb0e57d2149135eec8f900559125 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Thu, 16 May 2024 21:03:56 +0800 Subject: [PATCH 11/11] =?UTF-8?q?chore:=20=E6=81=A2=E5=A4=8D=E7=89=88?= =?UTF-8?q?=E6=9C=AC=E5=8F=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/utils/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/utils/constants.py b/pkg/utils/constants.py index 81ca04a..0addd9f 100644 --- a/pkg/utils/constants.py +++ b/pkg/utils/constants.py @@ -1 +1 @@ -semantic_version = "v3.2.0" +semantic_version = "v3.1.1"
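Taken together, the series leaves provider-level messages carrying either a plain string or a list of ContentElement objects, with get_content_mirai_message_chain() as the bridge back to platform message chains. A rough usage sketch (assuming the package imports as `pkg` from the repo root; the behavior comments are expectations, not a tested transcript):

```python
# Illustrative sketch of the multimodal Message model after this series.
import mirai

from pkg.provider import entities as llm_entities

msg = llm_entities.Message(
    role="user",
    content=[
        llm_entities.ContentElement.from_text("look at this"),
        llm_entities.ContentElement.from_image_url("https://example.com/cat.jpg"),
    ],
)

# prefix_text is spliced into the first Plain component, as in the wrapper stage:
#   MessageChain([Plain("[bot] look at this"), Image(url="https://example.com/cat.jpg")])
chain = msg.get_content_mirai_message_chain(prefix_text="[bot] ")

# readable_str() renders via the same conversion; exact output depends on
# how mirai stringifies the chain.
print(msg.readable_str())
```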