feat: runner 层抽象 (#839)

This commit is contained in:
RockChinQ 2024-07-28 18:45:27 +08:00
parent 48cc3656bd
commit 8cad4089a7
10 changed files with 172 additions and 65 deletions

View File

@ -9,6 +9,7 @@ from ..provider.session import sessionmgr as llm_session_mgr
from ..provider.modelmgr import modelmgr as llm_model_mgr from ..provider.modelmgr import modelmgr as llm_model_mgr
from ..provider.sysprompt import sysprompt as llm_prompt_mgr from ..provider.sysprompt import sysprompt as llm_prompt_mgr
from ..provider.tools import toolmgr as llm_tool_mgr from ..provider.tools import toolmgr as llm_tool_mgr
from ..provider import runnermgr
from ..config import manager as config_mgr from ..config import manager as config_mgr
from ..audit.center import v2 as center_mgr from ..audit.center import v2 as center_mgr
from ..command import cmdmgr from ..command import cmdmgr
@ -33,6 +34,8 @@ class Application:
tool_mgr: llm_tool_mgr.ToolManager = None tool_mgr: llm_tool_mgr.ToolManager = None
runner_mgr: runnermgr.RunnerManager = None
# ======= 配置管理器 ======= # ======= 配置管理器 =======
command_cfg: config_mgr.ConfigManager = None command_cfg: config_mgr.ConfigManager = None

View File

@ -0,0 +1,19 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("runner-config", 12)
class RunnerConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'runner' not in self.ap.provider_cfg.data
async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['runner'] = 'local-agent'
await self.ap.provider_cfg.dump_config()

View File

@ -13,6 +13,7 @@ from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr
from ...provider.sysprompt import sysprompt as llm_prompt_mgr from ...provider.sysprompt import sysprompt as llm_prompt_mgr
from ...provider.tools import toolmgr as llm_tool_mgr from ...provider.tools import toolmgr as llm_tool_mgr
from ...provider import runnermgr
from ...platform import manager as im_mgr from ...platform import manager as im_mgr
@stage.stage_class("BuildAppStage") @stage.stage_class("BuildAppStage")
@ -81,6 +82,11 @@ class BuildAppStage(stage.BootingStage):
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap) llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
await llm_tool_mgr_inst.initialize() await llm_tool_mgr_inst.initialize()
ap.tool_mgr = llm_tool_mgr_inst ap.tool_mgr = llm_tool_mgr_inst
runner_mgr_inst = runnermgr.RunnerManager(ap)
await runner_mgr_inst.initialize()
ap.runner_mgr = runner_mgr_inst
im_mgr_inst = im_mgr.PlatformManager(ap=ap) im_mgr_inst = im_mgr.PlatformManager(ap=ap)
await im_mgr_inst.initialize() await im_mgr_inst.initialize()
ap.platform_mgr = im_mgr_inst ap.platform_mgr = im_mgr_inst

View File

@ -6,7 +6,7 @@ from .. import stage, app
from .. import migration from .. import migration
from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
from ..migrations import m010_ollama_requester_config, m011_command_prefix_config from ..migrations import m010_ollama_requester_config, m011_command_prefix_config, m012_runner_config
@stage.stage_class("MigrationStage") @stage.stage_class("MigrationStage")

View File

@ -10,7 +10,7 @@ import mirai
from .. import handler from .. import handler
from ... import entities from ... import entities
from ....core import entities as core_entities from ....core import entities as core_entities
from ....provider import entities as llm_entities from ....provider import entities as llm_entities, runnermgr
from ....plugin import events from ....plugin import events
@ -71,7 +71,9 @@ class ChatMessageHandler(handler.MessageHandler):
try: try:
async for result in self.runner(query): runner = self.ap.runner_mgr.get_runner()
async for result in runner.run(query):
query.resp_messages.append(result) query.resp_messages.append(result)
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}') self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
@ -108,64 +110,3 @@ class ChatMessageHandler(handler.MessageHandler):
response_seconds=int(time.time() - start_time), response_seconds=int(time.time() - start_time),
retry_times=-1, retry_times=-1,
) )
async def runner(
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""执行一个请求处理过程中的LLM接口请求、函数调用的循环
这是临时处理方案后续可能改为使用LangChain或者自研的工作流处理器
"""
await query.use_model.requester.preprocess(query)
pending_tool_calls = []
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
# 首次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg)
# 持续请求,只要还有待处理的工具调用就继续处理调用
while pending_tool_calls:
for tool_call in pending_tool_calls:
try:
func = tool_call.function
parameters = json.loads(func.arguments)
func_ret = await self.ap.tool_mgr.execute_func_call(
query, func.name, parameters
)
msg = llm_entities.Message(
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
)
yield msg
req_messages.append(msg)
except Exception as e:
# 工具调用出错,添加一个报错信息到 req_messages
err_msg = llm_entities.Message(
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
)
yield err_msg
req_messages.append(err_msg)
# 处理完所有调用,再次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg)

40
pkg/provider/runner.py Normal file
View File

@ -0,0 +1,40 @@
from __future__ import annotations
import abc
import typing
from ..core import app, entities as core_entities
from . import entities as llm_entities
preregistered_runners: list[typing.Type[RequestRunner]] = []
def runner_class(name: str):
"""注册一个请求运行器
"""
def decorator(cls: typing.Type[RequestRunner]) -> typing.Type[RequestRunner]:
cls.name = name
preregistered_runners.append(cls)
return cls
return decorator
class RequestRunner(abc.ABC):
"""请求运行器
"""
name: str = None
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求
"""
pass

27
pkg/provider/runnermgr.py Normal file
View File

@ -0,0 +1,27 @@
from __future__ import annotations
from . import runner
from ..core import app
from .runners import localagent
class RunnerManager:
ap: app.Application
using_runner: runner.RequestRunner
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
for r in runner.preregistered_runners:
if r.name == self.ap.provider_cfg.data['runner']:
self.using_runner = r(self.ap)
await self.using_runner.initialize()
break
def get_runner(self) -> runner.RequestRunner:
return self.using_runner

View File

View File

@ -0,0 +1,70 @@
from __future__ import annotations
import json
import typing
from .. import runner
from ...core import app, entities as core_entities
from .. import entities as llm_entities
@runner.runner_class("local-agent")
class LocalAgentRunner(runner.RequestRunner):
"""本地Agent请求运行器
"""
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求
"""
await query.use_model.requester.preprocess(query)
pending_tool_calls = []
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
# 首次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg)
# 持续请求,只要还有待处理的工具调用就继续处理调用
while pending_tool_calls:
for tool_call in pending_tool_calls:
try:
func = tool_call.function
parameters = json.loads(func.arguments)
func_ret = await self.ap.tool_mgr.execute_func_call(
query, func.name, parameters
)
msg = llm_entities.Message(
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
)
yield msg
req_messages.append(msg)
except Exception as e:
# 工具调用出错,添加一个报错信息到 req_messages
err_msg = llm_entities.Message(
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
)
yield err_msg
req_messages.append(err_msg)
# 处理完所有调用,再次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
yield msg
pending_tool_calls = msg.tool_calls
req_messages.append(msg)

View File

@ -48,5 +48,6 @@
"prompt-mode": "normal", "prompt-mode": "normal",
"prompt": { "prompt": {
"default": "" "default": ""
} },
"runner": "local-agent"
} }