mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-15 19:22:24 +08:00
feat: runner 层抽象 (#839)
This commit is contained in:
parent
48cc3656bd
commit
8cad4089a7
|
@ -9,6 +9,7 @@ from ..provider.session import sessionmgr as llm_session_mgr
|
|||
from ..provider.modelmgr import modelmgr as llm_model_mgr
|
||||
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
|
||||
from ..provider.tools import toolmgr as llm_tool_mgr
|
||||
from ..provider import runnermgr
|
||||
from ..config import manager as config_mgr
|
||||
from ..audit.center import v2 as center_mgr
|
||||
from ..command import cmdmgr
|
||||
|
@ -33,6 +34,8 @@ class Application:
|
|||
|
||||
tool_mgr: llm_tool_mgr.ToolManager = None
|
||||
|
||||
runner_mgr: runnermgr.RunnerManager = None
|
||||
|
||||
# ======= 配置管理器 =======
|
||||
|
||||
command_cfg: config_mgr.ConfigManager = None
|
||||
|
|
19
pkg/core/migrations/m012_runner_config.py
Normal file
19
pkg/core/migrations/m012_runner_config.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class("runner-config", 12)
|
||||
class RunnerConfigMigration(migration.Migration):
|
||||
"""迁移"""
|
||||
|
||||
async def need_migrate(self) -> bool:
|
||||
"""判断当前环境是否需要运行此迁移"""
|
||||
return 'runner' not in self.ap.provider_cfg.data
|
||||
|
||||
async def run(self):
|
||||
"""执行迁移"""
|
||||
|
||||
self.ap.provider_cfg.data['runner'] = 'local-agent'
|
||||
|
||||
await self.ap.provider_cfg.dump_config()
|
|
@ -13,6 +13,7 @@ from ...provider.session import sessionmgr as llm_session_mgr
|
|||
from ...provider.modelmgr import modelmgr as llm_model_mgr
|
||||
from ...provider.sysprompt import sysprompt as llm_prompt_mgr
|
||||
from ...provider.tools import toolmgr as llm_tool_mgr
|
||||
from ...provider import runnermgr
|
||||
from ...platform import manager as im_mgr
|
||||
|
||||
@stage.stage_class("BuildAppStage")
|
||||
|
@ -81,6 +82,11 @@ class BuildAppStage(stage.BootingStage):
|
|||
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
|
||||
await llm_tool_mgr_inst.initialize()
|
||||
ap.tool_mgr = llm_tool_mgr_inst
|
||||
|
||||
runner_mgr_inst = runnermgr.RunnerManager(ap)
|
||||
await runner_mgr_inst.initialize()
|
||||
ap.runner_mgr = runner_mgr_inst
|
||||
|
||||
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
|
||||
await im_mgr_inst.initialize()
|
||||
ap.platform_mgr = im_mgr_inst
|
||||
|
|
|
@ -6,7 +6,7 @@ from .. import stage, app
|
|||
from .. import migration
|
||||
from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
|
||||
from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
|
||||
from ..migrations import m010_ollama_requester_config, m011_command_prefix_config
|
||||
from ..migrations import m010_ollama_requester_config, m011_command_prefix_config, m012_runner_config
|
||||
|
||||
|
||||
@stage.stage_class("MigrationStage")
|
||||
|
|
|
@ -10,7 +10,7 @@ import mirai
|
|||
from .. import handler
|
||||
from ... import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....provider import entities as llm_entities
|
||||
from ....provider import entities as llm_entities, runnermgr
|
||||
from ....plugin import events
|
||||
|
||||
|
||||
|
@ -71,7 +71,9 @@ class ChatMessageHandler(handler.MessageHandler):
|
|||
|
||||
try:
|
||||
|
||||
async for result in self.runner(query):
|
||||
runner = self.ap.runner_mgr.get_runner()
|
||||
|
||||
async for result in runner.run(query):
|
||||
query.resp_messages.append(result)
|
||||
|
||||
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
|
||||
|
@ -108,64 +110,3 @@ class ChatMessageHandler(handler.MessageHandler):
|
|||
response_seconds=int(time.time() - start_time),
|
||||
retry_times=-1,
|
||||
)
|
||||
|
||||
async def runner(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""执行一个请求处理过程中的LLM接口请求、函数调用的循环
|
||||
|
||||
这是临时处理方案,后续可能改为使用LangChain或者自研的工作流处理器
|
||||
"""
|
||||
await query.use_model.requester.preprocess(query)
|
||||
|
||||
pending_tool_calls = []
|
||||
|
||||
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
|
||||
|
||||
# 首次请求
|
||||
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
|
||||
|
||||
yield msg
|
||||
|
||||
pending_tool_calls = msg.tool_calls
|
||||
|
||||
req_messages.append(msg)
|
||||
|
||||
# 持续请求,只要还有待处理的工具调用就继续处理调用
|
||||
while pending_tool_calls:
|
||||
for tool_call in pending_tool_calls:
|
||||
try:
|
||||
func = tool_call.function
|
||||
|
||||
parameters = json.loads(func.arguments)
|
||||
|
||||
func_ret = await self.ap.tool_mgr.execute_func_call(
|
||||
query, func.name, parameters
|
||||
)
|
||||
|
||||
msg = llm_entities.Message(
|
||||
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
|
||||
)
|
||||
|
||||
yield msg
|
||||
|
||||
req_messages.append(msg)
|
||||
except Exception as e:
|
||||
# 工具调用出错,添加一个报错信息到 req_messages
|
||||
err_msg = llm_entities.Message(
|
||||
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
|
||||
)
|
||||
|
||||
yield err_msg
|
||||
|
||||
req_messages.append(err_msg)
|
||||
|
||||
# 处理完所有调用,再次请求
|
||||
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
|
||||
|
||||
yield msg
|
||||
|
||||
pending_tool_calls = msg.tool_calls
|
||||
|
||||
req_messages.append(msg)
|
||||
|
|
40
pkg/provider/runner.py
Normal file
40
pkg/provider/runner.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from . import entities as llm_entities
|
||||
|
||||
|
||||
preregistered_runners: list[typing.Type[RequestRunner]] = []
|
||||
|
||||
def runner_class(name: str):
|
||||
"""注册一个请求运行器
|
||||
"""
|
||||
def decorator(cls: typing.Type[RequestRunner]) -> typing.Type[RequestRunner]:
|
||||
cls.name = name
|
||||
preregistered_runners.append(cls)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class RequestRunner(abc.ABC):
|
||||
"""请求运行器
|
||||
"""
|
||||
name: str = None
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行请求
|
||||
"""
|
||||
pass
|
27
pkg/provider/runnermgr.py
Normal file
27
pkg/provider/runnermgr.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from . import runner
|
||||
from ..core import app
|
||||
|
||||
from .runners import localagent
|
||||
|
||||
|
||||
class RunnerManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
using_runner: runner.RequestRunner
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
for r in runner.preregistered_runners:
|
||||
if r.name == self.ap.provider_cfg.data['runner']:
|
||||
self.using_runner = r(self.ap)
|
||||
await self.using_runner.initialize()
|
||||
break
|
||||
|
||||
def get_runner(self) -> runner.RequestRunner:
|
||||
return self.using_runner
|
0
pkg/provider/runners/__init__.py
Normal file
0
pkg/provider/runners/__init__.py
Normal file
70
pkg/provider/runners/localagent.py
Normal file
70
pkg/provider/runners/localagent.py
Normal file
|
@ -0,0 +1,70 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import typing
|
||||
|
||||
from .. import runner
|
||||
from ...core import app, entities as core_entities
|
||||
from .. import entities as llm_entities
|
||||
|
||||
|
||||
@runner.runner_class("local-agent")
|
||||
class LocalAgentRunner(runner.RequestRunner):
|
||||
"""本地Agent请求运行器
|
||||
"""
|
||||
|
||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行请求
|
||||
"""
|
||||
await query.use_model.requester.preprocess(query)
|
||||
|
||||
pending_tool_calls = []
|
||||
|
||||
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
|
||||
|
||||
# 首次请求
|
||||
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
|
||||
|
||||
yield msg
|
||||
|
||||
pending_tool_calls = msg.tool_calls
|
||||
|
||||
req_messages.append(msg)
|
||||
|
||||
# 持续请求,只要还有待处理的工具调用就继续处理调用
|
||||
while pending_tool_calls:
|
||||
for tool_call in pending_tool_calls:
|
||||
try:
|
||||
func = tool_call.function
|
||||
|
||||
parameters = json.loads(func.arguments)
|
||||
|
||||
func_ret = await self.ap.tool_mgr.execute_func_call(
|
||||
query, func.name, parameters
|
||||
)
|
||||
|
||||
msg = llm_entities.Message(
|
||||
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
|
||||
)
|
||||
|
||||
yield msg
|
||||
|
||||
req_messages.append(msg)
|
||||
except Exception as e:
|
||||
# 工具调用出错,添加一个报错信息到 req_messages
|
||||
err_msg = llm_entities.Message(
|
||||
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
|
||||
)
|
||||
|
||||
yield err_msg
|
||||
|
||||
req_messages.append(err_msg)
|
||||
|
||||
# 处理完所有调用,再次请求
|
||||
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
|
||||
|
||||
yield msg
|
||||
|
||||
pending_tool_calls = msg.tool_calls
|
||||
|
||||
req_messages.append(msg)
|
|
@ -48,5 +48,6 @@
|
|||
"prompt-mode": "normal",
|
||||
"prompt": {
|
||||
"default": ""
|
||||
}
|
||||
},
|
||||
"runner": "local-agent"
|
||||
}
|
Loading…
Reference in New Issue
Block a user