mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-15 19:22:24 +08:00
commit
78d98c40b1
|
@ -8,7 +8,7 @@ from . import entities, operator, errors
|
|||
from ..config import manager as cfg_mgr
|
||||
|
||||
# 引入所有算子以便注册
|
||||
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update
|
||||
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model
|
||||
|
||||
|
||||
class CommandManager:
|
||||
|
|
86
pkg/command/operators/model.py
Normal file
86
pkg/command/operators/model.py
Normal file
|
@ -0,0 +1,86 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities, cmdmgr, errors
|
||||
|
||||
@operator.operator_class(
|
||||
name="model",
|
||||
help='显示和切换模型列表',
|
||||
usage='!model\n!model show <模型名>\n!model set <模型名>',
|
||||
privilege=2
|
||||
)
|
||||
class ModelOperator(operator.CommandOperator):
|
||||
"""Model命令"""
|
||||
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
content = '模型列表:\n'
|
||||
|
||||
model_list = self.ap.model_mgr.model_list
|
||||
|
||||
for model in model_list:
|
||||
content += f"\n名称: {model.name}\n"
|
||||
content += f"请求器: {model.requester.name}\n"
|
||||
|
||||
content += f"\n当前对话使用模型: {context.query.use_model.name}\n"
|
||||
content += f"新对话默认使用模型: {self.ap.provider_cfg.data.get('model')}\n"
|
||||
|
||||
yield entities.CommandReturn(text=content.strip())
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="show",
|
||||
help='显示模型详情',
|
||||
privilege=2,
|
||||
parent_class=ModelOperator
|
||||
)
|
||||
class ModelShowOperator(operator.CommandOperator):
|
||||
"""Model Show命令"""
|
||||
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
model_name = context.crt_params[0]
|
||||
|
||||
model = None
|
||||
for _model in self.ap.model_mgr.model_list:
|
||||
if model_name == _model.name:
|
||||
model = _model
|
||||
break
|
||||
|
||||
if model is None:
|
||||
yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}"))
|
||||
else:
|
||||
content = f"模型详情\n"
|
||||
content += f"名称: {model.name}\n"
|
||||
if model.model_name is not None:
|
||||
content += f"请求模型名称: {model.model_name}\n"
|
||||
content += f"请求器: {model.requester.name}\n"
|
||||
content += f"密钥组: {model.token_mgr.provider}\n"
|
||||
content += f"支持视觉: {model.vision_supported}\n"
|
||||
content += f"支持工具: {model.tool_call_supported}\n"
|
||||
|
||||
yield entities.CommandReturn(text=content.strip())
|
||||
|
||||
@operator.operator_class(
|
||||
name="set",
|
||||
help='设置默认使用模型',
|
||||
privilege=2,
|
||||
parent_class=ModelOperator
|
||||
)
|
||||
class ModelSetOperator(operator.CommandOperator):
|
||||
"""Model Set命令"""
|
||||
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
model_name = context.crt_params[0]
|
||||
|
||||
model = None
|
||||
for _model in self.ap.model_mgr.model_list:
|
||||
if model_name == _model.name:
|
||||
model = _model
|
||||
break
|
||||
|
||||
if model is None:
|
||||
yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}"))
|
||||
else:
|
||||
self.ap.provider_cfg.data['model'] = model_name
|
||||
await self.ap.provider_cfg.dump_config()
|
||||
yield entities.CommandReturn(text=f"已设置当前使用模型为 {model_name},重置会话以生效")
|
121
pkg/command/operators/ollama.py
Normal file
121
pkg/command/operators/ollama.py
Normal file
|
@ -0,0 +1,121 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
import ollama
|
||||
from .. import operator, entities, errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="ollama",
|
||||
help="ollama平台操作",
|
||||
usage="!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>"
|
||||
)
|
||||
class OllamaOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
try:
|
||||
content: str = '模型列表:\n'
|
||||
model_list: list = ollama.list().get('models', [])
|
||||
for model in model_list:
|
||||
content += f"名称: {model['name']}\n"
|
||||
content += f"修改时间: {model['modified_at']}\n"
|
||||
content += f"大小: {bytes_to_mb(model['size'])}MB\n\n"
|
||||
yield entities.CommandReturn(text=f"{content.strip()}")
|
||||
except ollama.ResponseError as e:
|
||||
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常"))
|
||||
|
||||
|
||||
def bytes_to_mb(num_bytes):
|
||||
mb: float = num_bytes / 1024 / 1024
|
||||
return format(mb, '.2f')
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="show",
|
||||
help="ollama模型详情",
|
||||
privilege=2,
|
||||
parent_class=OllamaOperator
|
||||
)
|
||||
class OllamaShowOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
content: str = '模型详情:\n'
|
||||
try:
|
||||
show: dict = ollama.show(model=context.crt_params[0])
|
||||
model_info: dict = show.get('model_info', {})
|
||||
ignore_show: str = 'too long to show...'
|
||||
|
||||
for key in ['license', 'modelfile']:
|
||||
show[key] = ignore_show
|
||||
|
||||
for key in ['tokenizer.chat_template.rag', 'tokenizer.chat_template.tool_use']:
|
||||
model_info[key] = ignore_show
|
||||
|
||||
content += json.dumps(show, indent=4)
|
||||
yield entities.CommandReturn(text=content.strip())
|
||||
except ollama.ResponseError as e:
|
||||
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型详情,请确认 Ollama 服务正常"))
|
||||
|
||||
@operator.operator_class(
|
||||
name="pull",
|
||||
help="ollama模型拉取",
|
||||
privilege=2,
|
||||
parent_class=OllamaOperator
|
||||
)
|
||||
class OllamaPullOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
try:
|
||||
model_list: list = ollama.list().get('models', [])
|
||||
if context.crt_params[0] in [model['name'] for model in model_list]:
|
||||
yield entities.CommandReturn(text="模型已存在")
|
||||
return
|
||||
except ollama.ResponseError as e:
|
||||
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常"))
|
||||
return
|
||||
|
||||
on_progress: bool = False
|
||||
progress_count: int = 0
|
||||
try:
|
||||
for resp in ollama.pull(model=context.crt_params[0], stream=True):
|
||||
total: typing.Any = resp.get('total')
|
||||
if not on_progress:
|
||||
if total is not None:
|
||||
on_progress = True
|
||||
yield entities.CommandReturn(text=resp.get('status'))
|
||||
else:
|
||||
if total is None:
|
||||
on_progress = False
|
||||
|
||||
completed: typing.Any = resp.get('completed')
|
||||
if isinstance(completed, int) and isinstance(total, int):
|
||||
percentage_completed = (completed / total) * 100
|
||||
if percentage_completed > progress_count:
|
||||
progress_count += 10
|
||||
yield entities.CommandReturn(
|
||||
text=f"下载进度: {completed}/{total} ({percentage_completed:.2f}%)")
|
||||
except ollama.ResponseError as e:
|
||||
yield entities.CommandReturn(text=f"拉取失败: {e.error}")
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
name="del",
|
||||
help="ollama模型删除",
|
||||
privilege=2,
|
||||
parent_class=OllamaOperator
|
||||
)
|
||||
class OllamaDelOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self, context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
try:
|
||||
ret: str = ollama.delete(model=context.crt_params[0])['status']
|
||||
except ollama.ResponseError as e:
|
||||
ret = f"{e.error}"
|
||||
yield entities.CommandReturn(text=ret)
|
|
@ -20,7 +20,7 @@ class VersionCommand(operator.CommandOperator):
|
|||
|
||||
try:
|
||||
if await self.ap.ver_mgr.is_new_version_available():
|
||||
reply_str += "\n\n有新版本可用, 使用 !update 更新"
|
||||
reply_str += "\n\n有新版本可用。"
|
||||
except:
|
||||
pass
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ from ..provider.session import sessionmgr as llm_session_mgr
|
|||
from ..provider.modelmgr import modelmgr as llm_model_mgr
|
||||
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
|
||||
from ..provider.tools import toolmgr as llm_tool_mgr
|
||||
from ..provider import runnermgr
|
||||
from ..config import manager as config_mgr
|
||||
from ..audit.center import v2 as center_mgr
|
||||
from ..command import cmdmgr
|
||||
|
@ -33,6 +34,8 @@ class Application:
|
|||
|
||||
tool_mgr: llm_tool_mgr.ToolManager = None
|
||||
|
||||
runner_mgr: runnermgr.RunnerManager = None
|
||||
|
||||
# ======= 配置管理器 =======
|
||||
|
||||
command_cfg: config_mgr.ConfigManager = None
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
import abc
|
||||
import typing
|
||||
|
||||
from ..core import app
|
||||
from . import app
|
||||
|
||||
|
||||
preregistered_migrations: list[typing.Type[Migration]] = []
|
23
pkg/core/migrations/m010_ollama_requester_config.py
Normal file
23
pkg/core/migrations/m010_ollama_requester_config.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class("ollama-requester-config", 10)
|
||||
class MsgTruncatorConfigMigration(migration.Migration):
|
||||
"""迁移"""
|
||||
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移"""
|
||||
return 'ollama-chat' not in self.ap.provider_cfg.data['requester']
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移"""
|
||||
|
||||
self.ap.provider_cfg.data['requester']['ollama-chat'] = {
|
||||
"base-url": "http://127.0.0.1:11434",
|
||||
"args": {},
|
||||
"timeout": 600
|
||||
}
|
||||
|
||||
await self.ap.provider_cfg.dump_config()
|
21
pkg/core/migrations/m011_command_prefix_config.py
Normal file
21
pkg/core/migrations/m011_command_prefix_config.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class("command-prefix-config", 11)
|
||||
class CommandPrefixConfigMigration(migration.Migration):
|
||||
"""迁移"""
|
||||
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移"""
|
||||
return 'command-prefix' not in self.ap.command_cfg.data
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移"""
|
||||
|
||||
self.ap.command_cfg.data['command-prefix'] = [
|
||||
"!", "!"
|
||||
]
|
||||
|
||||
await self.ap.command_cfg.dump_config()
|
19
pkg/core/migrations/m012_runner_config.py
Normal file
19
pkg/core/migrations/m012_runner_config.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class("runner-config", 12)
|
||||
class RunnerConfigMigration(migration.Migration):
|
||||
"""迁移"""
|
||||
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移"""
|
||||
return 'runner' not in self.ap.provider_cfg.data
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移"""
|
||||
|
||||
self.ap.provider_cfg.data['runner'] = 'local-agent'
|
||||
|
||||
await self.ap.provider_cfg.dump_config()
|
|
@ -13,6 +13,7 @@ from ...provider.session import sessionmgr as llm_session_mgr
|
|||
from ...provider.modelmgr import modelmgr as llm_model_mgr
|
||||
from ...provider.sysprompt import sysprompt as llm_prompt_mgr
|
||||
from ...provider.tools import toolmgr as llm_tool_mgr
|
||||
from ...provider import runnermgr
|
||||
from ...platform import manager as im_mgr
|
||||
|
||||
@stage.stage_class("BuildAppStage")
|
||||
|
@ -81,6 +82,11 @@ class BuildAppStage(stage.BootingStage):
|
|||
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
|
||||
await llm_tool_mgr_inst.initialize()
|
||||
ap.tool_mgr = llm_tool_mgr_inst
|
||||
|
||||
runner_mgr_inst = runnermgr.RunnerManager(ap)
|
||||
await runner_mgr_inst.initialize()
|
||||
ap.runner_mgr = runner_mgr_inst
|
||||
|
||||
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
|
||||
await im_mgr_inst.initialize()
|
||||
ap.platform_mgr = im_mgr_inst
|
||||
|
|
|
@ -3,9 +3,10 @@ from __future__ import annotations
|
|||
import importlib
|
||||
|
||||
from .. import stage, app
|
||||
from ...config import migration
|
||||
from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
|
||||
from ...config.migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
|
||||
from .. import migration
|
||||
from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
|
||||
from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
|
||||
from ..migrations import m010_ollama_requester_config, m011_command_prefix_config, m012_runner_config
|
||||
|
||||
|
||||
@stage.stage_class("MigrationStage")
|
||||
|
|
|
@ -10,7 +10,7 @@ import mirai
|
|||
from .. import handler
|
||||
from ... import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....provider import entities as llm_entities
|
||||
from ....provider import entities as llm_entities, runnermgr
|
||||
from ....plugin import events
|
||||
|
||||
|
||||
|
@ -71,7 +71,9 @@ class ChatMessageHandler(handler.MessageHandler):
|
|||
|
||||
try:
|
||||
|
||||
async for result in self.runner(query):
|
||||
runner = self.ap.runner_mgr.get_runner()
|
||||
|
||||
async for result in runner.run(query):
|
||||
query.resp_messages.append(result)
|
||||
|
||||
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
|
||||
|
@ -108,64 +110,3 @@ class ChatMessageHandler(handler.MessageHandler):
|
|||
response_seconds=int(time.time() - start_time),
|
||||
retry_times=-1,
|
||||
)
|
||||
|
||||
async def runner(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""执行一个请求处理过程中的LLM接口请求、函数调用的循环
|
||||
|
||||
这是临时处理方案,后续可能改为使用LangChain或者自研的工作流处理器
|
||||
"""
|
||||
await query.use_model.requester.preprocess(query)
|
||||
|
||||
pending_tool_calls = []
|
||||
|
||||
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
|
||||
|
||||
# 首次请求
|
||||
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
|
||||
|
||||
yield msg
|
||||
|
||||
pending_tool_calls = msg.tool_calls
|
||||
|
||||
req_messages.append(msg)
|
||||
|
||||
# 持续请求,只要还有待处理的工具调用就继续处理调用
|
||||
while pending_tool_calls:
|
||||
for tool_call in pending_tool_calls:
|
||||
try:
|
||||
func = tool_call.function
|
||||
|
||||
parameters = json.loads(func.arguments)
|
||||
|
||||
func_ret = await self.ap.tool_mgr.execute_func_call(
|
||||
query, func.name, parameters
|
||||
)
|
||||
|
||||
msg = llm_entities.Message(
|
||||
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
|
||||
)
|
||||
|
||||
yield msg
|
||||
|
||||
req_messages.append(msg)
|
||||
except Exception as e:
|
||||
# 工具调用出错,添加一个报错信息到 req_messages
|
||||
err_msg = llm_entities.Message(
|
||||
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
|
||||
)
|
||||
|
||||
yield err_msg
|
||||
|
||||
req_messages.append(err_msg)
|
||||
|
||||
# 处理完所有调用,再次请求
|
||||
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
|
||||
|
||||
yield msg
|
||||
|
||||
pending_tool_calls = msg.tool_calls
|
||||
|
||||
req_messages.append(msg)
|
||||
|
|
|
@ -42,7 +42,9 @@ class Processor(stage.PipelineStage):
|
|||
self.ap.logger.info(f"处理 {query.launcher_type.value}_{query.launcher_id} 的请求({query.query_id}): {message_text}")
|
||||
|
||||
async def generator():
|
||||
if message_text.startswith('!') or message_text.startswith('!'):
|
||||
cmd_prefix = self.ap.command_cfg.data['command-prefix']
|
||||
|
||||
if any(message_text.startswith(prefix) for prefix in cmd_prefix):
|
||||
async for result in self.cmd_handler.handle(query):
|
||||
yield result
|
||||
else:
|
||||
|
|
|
@ -146,9 +146,9 @@ class PlatformManager:
|
|||
if len(self.adapters) == 0:
|
||||
self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。')
|
||||
|
||||
async def send(self, event: mirai.MessageEvent, msg: mirai.MessageChain, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True):
|
||||
async def send(self, event: mirai.MessageEvent, msg: mirai.MessageChain, adapter: msadapter.MessageSourceAdapter):
|
||||
|
||||
if check_at_sender and self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage):
|
||||
if self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage):
|
||||
|
||||
msg.insert(
|
||||
0,
|
||||
|
@ -160,7 +160,7 @@ class PlatformManager:
|
|||
await adapter.reply_message(
|
||||
event,
|
||||
msg,
|
||||
quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False
|
||||
quote_origin=True if self.ap.platform_cfg.data['quote-origin'] else False
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
|
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||
import typing
|
||||
import abc
|
||||
import pydantic
|
||||
import mirai
|
||||
|
||||
from . import events
|
||||
from ..provider.tools import entities as tools_entities
|
||||
|
@ -165,11 +166,54 @@ class EventContext:
|
|||
}
|
||||
"""
|
||||
|
||||
# ========== 插件可调用的 API ==========
|
||||
|
||||
def add_return(self, key: str, ret):
|
||||
"""添加返回值"""
|
||||
if key not in self.__return_value__:
|
||||
self.__return_value__[key] = []
|
||||
self.__return_value__[key].append(ret)
|
||||
|
||||
async def reply(self, message_chain: mirai.MessageChain):
|
||||
"""回复此次消息请求
|
||||
|
||||
Args:
|
||||
message_chain (mirai.MessageChain): YiriMirai库的消息链,若用户使用的不是 YiriMirai 适配器,程序也能自动转换为目标消息链
|
||||
"""
|
||||
await self.host.ap.platform_mgr.send(
|
||||
event=self.event.query.message_event,
|
||||
msg=message_chain,
|
||||
adapter=self.event.query.adapter,
|
||||
)
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
message: mirai.MessageChain
|
||||
):
|
||||
"""主动发送消息
|
||||
|
||||
Args:
|
||||
target_type (str): 目标类型,`person`或`group`
|
||||
target_id (str): 目标ID
|
||||
message (mirai.MessageChain): YiriMirai库的消息链,若用户使用的不是 YiriMirai 适配器,程序也能自动转换为目标消息链
|
||||
"""
|
||||
await self.event.query.adapter.send_message(
|
||||
target_type=target_type,
|
||||
target_id=target_id,
|
||||
message=message
|
||||
)
|
||||
|
||||
def prevent_postorder(self):
|
||||
"""阻止后续插件执行"""
|
||||
self.__prevent_postorder__ = True
|
||||
|
||||
def prevent_default(self):
|
||||
"""阻止默认行为"""
|
||||
self.__prevent_default__ = True
|
||||
|
||||
# ========== 以下是内部保留方法,插件不应调用 ==========
|
||||
|
||||
def get_return(self, key: str) -> list:
|
||||
"""获取key的所有返回值"""
|
||||
|
@ -183,14 +227,6 @@ class EventContext:
|
|||
return self.__return_value__[key][0]
|
||||
return None
|
||||
|
||||
def prevent_default(self):
|
||||
"""阻止默认行为"""
|
||||
self.__prevent_default__ = True
|
||||
|
||||
def prevent_postorder(self):
|
||||
"""阻止后续插件执行"""
|
||||
self.__prevent_postorder__ = True
|
||||
|
||||
def is_prevented_default(self):
|
||||
"""是否阻止默认行为"""
|
||||
return self.__prevent_default__
|
||||
|
@ -198,6 +234,7 @@ class EventContext:
|
|||
def is_prevented_postorder(self):
|
||||
"""是否阻止后序插件执行"""
|
||||
return self.__prevent_postorder__
|
||||
|
||||
|
||||
def __init__(self, host: APIHost, event: events.BaseEventModel):
|
||||
|
||||
|
|
|
@ -95,7 +95,7 @@ class Message(pydantic.BaseModel):
|
|||
for ce in self.content:
|
||||
if ce.type == 'text':
|
||||
mc.append(mirai.Plain(ce.text))
|
||||
elif ce.type == 'image':
|
||||
elif ce.type == 'image_url':
|
||||
if ce.image_url.url.startswith("http"):
|
||||
mc.append(mirai.Image(url=ce.image_url.url))
|
||||
else: # base64
|
||||
|
|
105
pkg/provider/modelmgr/apis/ollamachat.py
Normal file
105
pkg/provider/modelmgr/apis/ollamachat.py
Normal file
|
@ -0,0 +1,105 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import typing
|
||||
from typing import Union, Mapping, Any, AsyncIterator
|
||||
|
||||
import async_lru
|
||||
import ollama
|
||||
|
||||
from .. import api, entities, errors
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
from ....core import app
|
||||
from ....utils import image
|
||||
|
||||
REQUESTER_NAME: str = "ollama-chat"
|
||||
|
||||
|
||||
@api.requester_class(REQUESTER_NAME)
|
||||
class OllamaChatCompletions(api.LLMAPIRequester):
|
||||
"""Ollama平台 ChatCompletion API请求器"""
|
||||
client: ollama.AsyncClient
|
||||
request_cfg: dict
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
super().__init__(ap)
|
||||
self.ap = ap
|
||||
self.request_cfg = self.ap.provider_cfg.data['requester'][REQUESTER_NAME]
|
||||
|
||||
async def initialize(self):
|
||||
os.environ['OLLAMA_HOST'] = self.request_cfg['base-url']
|
||||
self.client = ollama.AsyncClient(
|
||||
timeout=self.request_cfg['timeout']
|
||||
)
|
||||
|
||||
async def _req(self,
|
||||
args: dict,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
return await self.client.chat(
|
||||
**args
|
||||
)
|
||||
|
||||
async def _closure(self, req_messages: list[dict], use_model: entities.LLMModelInfo,
|
||||
user_funcs: list[tools_entities.LLMFunction] = None) -> (
|
||||
llm_entities.Message):
|
||||
args: Any = self.request_cfg['args'].copy()
|
||||
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
|
||||
messages: list[dict] = req_messages.copy()
|
||||
for msg in messages:
|
||||
if 'content' in msg and isinstance(msg["content"], list):
|
||||
text_content: list = []
|
||||
image_urls: list = []
|
||||
for me in msg["content"]:
|
||||
if me["type"] == "text":
|
||||
text_content.append(me["text"])
|
||||
elif me["type"] == "image_url":
|
||||
image_url = await self.get_base64_str(me["image_url"]['url'])
|
||||
image_urls.append(image_url)
|
||||
msg["content"] = "\n".join(text_content)
|
||||
msg["images"] = [url.split(',')[1] for url in image_urls]
|
||||
args["messages"] = messages
|
||||
|
||||
resp: Mapping[str, Any] | AsyncIterator[Mapping[str, Any]] = await self._req(args)
|
||||
message: llm_entities.Message = await self._make_msg(resp)
|
||||
return message
|
||||
|
||||
async def _make_msg(
|
||||
self,
|
||||
chat_completions: Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]) -> llm_entities.Message:
|
||||
message: Any = chat_completions.pop('message', None)
|
||||
if message is None:
|
||||
raise ValueError("chat_completions must contain a 'message' field")
|
||||
|
||||
message.update(chat_completions)
|
||||
ret_msg: llm_entities.Message = llm_entities.Message(**message)
|
||||
return ret_msg
|
||||
|
||||
async def call(
|
||||
self,
|
||||
model: entities.LLMModelInfo,
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
) -> llm_entities.Message:
|
||||
req_messages: list = []
|
||||
for m in messages:
|
||||
msg_dict: dict = m.dict(exclude_none=True)
|
||||
content: Any = msg_dict.get("content")
|
||||
if isinstance(content, list):
|
||||
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
|
||||
msg_dict["content"] = "\n".join(part["text"] for part in content)
|
||||
req_messages.append(msg_dict)
|
||||
try:
|
||||
return await self._closure(req_messages, model)
|
||||
except asyncio.TimeoutError:
|
||||
raise errors.RequesterError('请求超时')
|
||||
|
||||
@async_lru.alru_cache(maxsize=128)
|
||||
async def get_base64_str(
|
||||
self,
|
||||
original_url: str,
|
||||
) -> str:
|
||||
base64_image: str = await image.qq_image_url_to_base64(original_url)
|
||||
return f"data:image/jpeg;base64,{base64_image}"
|
|
@ -6,7 +6,7 @@ from . import entities
|
|||
from ...core import app
|
||||
|
||||
from . import token, api
|
||||
from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl
|
||||
from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat
|
||||
|
||||
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"
|
||||
|
||||
|
|
40
pkg/provider/runner.py
Normal file
40
pkg/provider/runner.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from . import entities as llm_entities
|
||||
|
||||
|
||||
preregistered_runners: list[typing.Type[RequestRunner]] = []
|
||||
|
||||
def runner_class(name: str):
|
||||
"""注册一个请求运行器
|
||||
"""
|
||||
def decorator(cls: typing.Type[RequestRunner]) -> typing.Type[RequestRunner]:
|
||||
cls.name = name
|
||||
preregistered_runners.append(cls)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class RequestRunner(abc.ABC):
|
||||
"""请求运行器
|
||||
"""
|
||||
name: str = None
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行请求
|
||||
"""
|
||||
pass
|
27
pkg/provider/runnermgr.py
Normal file
27
pkg/provider/runnermgr.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from . import runner
|
||||
from ..core import app
|
||||
|
||||
from .runners import localagent
|
||||
|
||||
|
||||
class RunnerManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
using_runner: runner.RequestRunner
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
for r in runner.preregistered_runners:
|
||||
if r.name == self.ap.provider_cfg.data['runner']:
|
||||
self.using_runner = r(self.ap)
|
||||
await self.using_runner.initialize()
|
||||
break
|
||||
|
||||
def get_runner(self) -> runner.RequestRunner:
|
||||
return self.using_runner
|
0
pkg/provider/runners/__init__.py
Normal file
0
pkg/provider/runners/__init__.py
Normal file
70
pkg/provider/runners/localagent.py
Normal file
70
pkg/provider/runners/localagent.py
Normal file
|
@ -0,0 +1,70 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import typing
|
||||
|
||||
from .. import runner
|
||||
from ...core import app, entities as core_entities
|
||||
from .. import entities as llm_entities
|
||||
|
||||
|
||||
@runner.runner_class("local-agent")
|
||||
class LocalAgentRunner(runner.RequestRunner):
|
||||
"""本地Agent请求运行器
|
||||
"""
|
||||
|
||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行请求
|
||||
"""
|
||||
await query.use_model.requester.preprocess(query)
|
||||
|
||||
pending_tool_calls = []
|
||||
|
||||
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
|
||||
|
||||
# 首次请求
|
||||
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
|
||||
|
||||
yield msg
|
||||
|
||||
pending_tool_calls = msg.tool_calls
|
||||
|
||||
req_messages.append(msg)
|
||||
|
||||
# 持续请求,只要还有待处理的工具调用就继续处理调用
|
||||
while pending_tool_calls:
|
||||
for tool_call in pending_tool_calls:
|
||||
try:
|
||||
func = tool_call.function
|
||||
|
||||
parameters = json.loads(func.arguments)
|
||||
|
||||
func_ret = await self.ap.tool_mgr.execute_func_call(
|
||||
query, func.name, parameters
|
||||
)
|
||||
|
||||
msg = llm_entities.Message(
|
||||
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
|
||||
)
|
||||
|
||||
yield msg
|
||||
|
||||
req_messages.append(msg)
|
||||
except Exception as e:
|
||||
# 工具调用出错,添加一个报错信息到 req_messages
|
||||
err_msg = llm_entities.Message(
|
||||
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
|
||||
)
|
||||
|
||||
yield err_msg
|
||||
|
||||
req_messages.append(err_msg)
|
||||
|
||||
# 处理完所有调用,再次请求
|
||||
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
|
||||
|
||||
yield msg
|
||||
|
||||
pending_tool_calls = msg.tool_calls
|
||||
|
||||
req_messages.append(msg)
|
|
@ -14,4 +14,5 @@ pydantic
|
|||
websockets
|
||||
urllib3
|
||||
psutil
|
||||
async-lru
|
||||
async-lru
|
||||
ollama
|
|
@ -1,3 +1,7 @@
|
|||
{
|
||||
"privilege": {}
|
||||
"privilege": {},
|
||||
"command-prefix": [
|
||||
"!",
|
||||
"!"
|
||||
]
|
||||
}
|
|
@ -37,11 +37,17 @@
|
|||
"base-url": "https://api.deepseek.com",
|
||||
"args": {},
|
||||
"timeout": 120
|
||||
},
|
||||
"ollama-chat": {
|
||||
"base-url": "http://127.0.0.1:11434",
|
||||
"args": {},
|
||||
"timeout": 600
|
||||
}
|
||||
},
|
||||
"model": "gpt-3.5-turbo",
|
||||
"prompt-mode": "normal",
|
||||
"prompt": {
|
||||
"default": ""
|
||||
}
|
||||
},
|
||||
"runner": "local-agent"
|
||||
}
|
Loading…
Reference in New Issue
Block a user