refactor: restructure the base component architecture of the openai package

RockChinQ 2024-01-27 00:06:38 +08:00
parent 411034902a
commit 850a4eeb7c
35 changed files with 779 additions and 59 deletions

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
from . import model as file_model
from ..utils import context
from .impls import pymodule, json as json_file

View File

@@ -5,6 +5,9 @@ import asyncio
from ..qqbot import manager as qqbot_mgr
from ..openai import manager as openai_mgr
from ..openai.session import sessionmgr as llm_session_mgr
from ..openai.requester import modelmgr as llm_model_mgr
from ..openai.sysprompt import sysprompt as llm_prompt_mgr
from ..config import manager as config_mgr
from ..database import manager as database_mgr
from ..utils.center import v2 as center_mgr
@@ -18,6 +21,12 @@ class Application:
    llm_mgr: openai_mgr.OpenAIInteract = None

    sess_mgr: llm_session_mgr.SessionManager = None

    model_mgr: llm_model_mgr.ModelManager = None

    prompt_mgr: llm_prompt_mgr.PromptManager = None

    cfg_mgr: config_mgr.ConfigManager = None

    tips_mgr: config_mgr.ConfigManager = None

View File

@@ -15,7 +15,9 @@ from ..pipeline import stagemgr
from ..audit import identifier
from ..database import manager as db_mgr
from ..openai import manager as llm_mgr
from ..openai import session as llm_session
from ..openai.session import sessionmgr as llm_session_mgr
from ..openai.requester import modelmgr as llm_model_mgr
from ..openai.sysprompt import sysprompt as llm_prompt_mgr
from ..openai import dprompt as llm_dprompt
from ..qqbot import manager as im_mgr
from ..qqbot.cmds import aamgr as im_cmd_aamgr
@@ -112,8 +114,18 @@ async def make_app() -> app.Application:
    llm_mgr_inst = llm_mgr.OpenAIInteract(ap)
    ap.llm_mgr = llm_mgr_inst
    # TODO make it async
    llm_session.load_sessions()

    llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
    await llm_model_mgr_inst.initialize()
    ap.model_mgr = llm_model_mgr_inst

    llm_session_mgr_inst = llm_session_mgr.SessionManager(ap)
    await llm_session_mgr_inst.initialize()
    ap.sess_mgr = llm_session_mgr_inst

    llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap)
    await llm_prompt_mgr_inst.initialize()
    ap.prompt_mgr = llm_prompt_mgr_inst

    im_mgr_inst = im_mgr.QQBotManager(first_time_init=True, ap=ap)
    await im_mgr_inst.initialize()

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import json
from ...config import manager as config_mgr

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import typing
import traceback
from . import app, entities
@@ -24,25 +25,115 @@ class Controller:
    async def consumer(self):
        """Event processing loop"""
        while True:
            selected_query: entities.Query = None

            # fetch a request
            async with self.ap.query_pool:
                queries: list[entities.Query] = self.ap.query_pool.queries
                if queries:
                    selected_query = queries.pop(0)  # FCFS
                else:
                    await self.ap.query_pool.condition.wait()
                    continue

            if selected_query:
                async def _process_query(selected_query):
                    async with self.semaphore:
                        await self.process_query(selected_query)
                asyncio.create_task(_process_query(selected_query))
        try:
            while True:
                selected_query: entities.Query = None

                # fetch a request
                async with self.ap.query_pool:
                    queries: list[entities.Query] = self.ap.query_pool.queries

                    for query in queries:
                        session = await self.ap.sess_mgr.get_session(query)
                        self.ap.logger.debug(f"Checking query {query} session {session}")

                        if not session.semaphore.locked():
                            selected_query = query
                            await session.semaphore.acquire()
                            break

                    if selected_query:  # found one
                        queries.remove(selected_query)
                    else:  # none found: there are no requests, or every query's session is at its concurrency limit
                        await self.ap.query_pool.condition.wait()
                        continue

                if selected_query:
                    async def _process_query(selected_query):
                        async with self.semaphore:  # overall concurrency limit
                            await self.process_query(selected_query)

                        async with self.ap.query_pool:
                            (await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
                            # notify the other coroutines that a new request can be handled
                            self.ap.query_pool.condition.notify_all()

                    asyncio.create_task(_process_query(selected_query))
        except Exception as e:
            self.ap.logger.error(f"Error in event processing loop: {e}")
            traceback.print_exc()

    async def _check_output(self, result: pipeline_entities.StageProcessResult):
        """Check a stage's output and send any notices"""
        if result.user_notice:
            await self.ap.im_mgr.send(
                result.user_notice
            )
        if result.debug_notice:
            self.ap.logger.debug(result.debug_notice)
        if result.console_notice:
            self.ap.logger.info(result.console_notice)

    async def _execute_from_stage(
        self,
        stage_index: int,
        query: entities.Query,
    ):
        """Execute the pipeline starting from the given stage.

        To understand why this is written the way it is, ask GPT-4:
        Q1: There is a chain of responsibility with multiple stages, and a query object is passed along it. stage.process may return a Result or a typing.AsyncGenerator[Result, None].
            If a generator is returned, results must be produced one at a time; for each result, check whether it asks to continue, and if so run the next stage. Once the current result has been fully handled, produce the next one
            and run the subsequent stages with it, until the generator is exhausted. Multiple stages in the chain may return generators.
        Q2: Not quite; you may have misunderstood. Suppose the chain contains these stages:
            A B C D E F G
            If every stage returns a Result and every Result asks to continue, the execution order is:
            A B C D E F G
            Now suppose C returns an AsyncGenerator; the execution order becomes:
            A B C D E F G C D E F G C D E F G ...
        Q3: But what if more than one stage returns a generator?
        """
        i = stage_index

        while i < len(self.ap.stage_mgr.stage_containers):
            stage_container = self.ap.stage_mgr.stage_containers[i]

            result = await stage_container.inst.process(query, stage_container.inst_name)

            if isinstance(result, pipeline_entities.StageProcessResult):  # a single result was returned
                self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}")
                await self._check_output(result)

                if result.result_type == pipeline_entities.ResultType.INTERRUPT:
                    self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
                    break
                elif result.result_type == pipeline_entities.ResultType.CONTINUE:
                    query = result.new_query
            elif isinstance(result, typing.AsyncGenerator):  # a generator
                self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen")

                async for sub_result in result:
                    self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}")
                    await self._check_output(sub_result)

                    if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
                        self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
                        break
                    elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
                        query = sub_result.new_query

                    await self._execute_from_stage(i + 1, query)
                break

            i += 1
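
The recursion above is easier to see in isolation. Below is a minimal standalone sketch of the same pattern (hypothetical stage functions, not part of this commit): a stage may return a value or an async generator, and every item a generator yields re-runs all downstream stages.

import asyncio
import typing

async def run_from(stages: list, i: int, query: str):
    while i < len(stages):
        result = stages[i](query)
        if isinstance(result, typing.AsyncGenerator):
            async for sub in result:
                await run_from(stages, i + 1, sub)  # downstream stages re-run per yielded item
            break  # the generator stage drove the rest of the chain itself
        query = result
        i += 1

def stage_a(q):  # ordinary stage: transform and pass on
    return q + ">A"

async def stage_b(q):  # generator stage: fans out into several results
    yield q + ">B1"
    yield q + ">B2"

def stage_c(q):  # final stage
    print(q + ">C")
    return q

asyncio.run(run_from([stage_a, stage_b, stage_c], 0, "q"))
# prints q>A>B1>C, then q>A>B2>C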

    async def process_query(self, query: entities.Query):
        """Process a request
@@ -50,28 +141,7 @@ class Controller:
        self.ap.logger.debug(f"Processing query {query}")
        try:
            for stage_container in self.ap.stage_mgr.stage_containers:
                res = await stage_container.inst.process(query, stage_container.inst_name)
                self.ap.logger.debug(f"Stage {stage_container.inst_name} res {res}")
                if res.user_notice:
                    await self.ap.im_mgr.send(
                        query.message_event,
                        res.user_notice
                    )
                if res.debug_notice:
                    self.ap.logger.debug(res.debug_notice)
                if res.console_notice:
                    self.ap.logger.info(res.console_notice)
                if res.result_type == pipeline_entities.ResultType.INTERRUPT:
                    self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
                    break
                elif res.result_type == pipeline_entities.ResultType.CONTINUE:
                    query = res.new_query
                    continue
            await self._execute_from_stage(0, query)
        except Exception as e:
            self.ap.logger.error(f"Error while processing query {query}: {e}")
            traceback.print_exc()
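
The consumer loop above combines two limits: a per-session semaphore that serializes queries within one session, and a controller-wide semaphore that caps total concurrency. A standalone sketch of the idea (simplified; in the real code the session semaphore is acquired while selecting a query and released after processing):

import asyncio

async def worker(name: str, session_sem: asyncio.Semaphore, global_sem: asyncio.Semaphore):
    async with session_sem:      # serializes queries within one session
        async with global_sem:   # caps total concurrency across all sessions
            print(f"{name} running")
            await asyncio.sleep(0.1)

async def main():
    global_sem = asyncio.Semaphore(4)
    session_a = asyncio.Semaphore(1)   # wait_last_done: one in-flight query per session
    session_b = asyncio.Semaphore(1)
    await asyncio.gather(
        worker("a1", session_a, global_sem),
        worker("a2", session_a, global_sem),  # waits until a1 finishes
        worker("b1", session_b, global_sem),  # runs alongside a1
    )

asyncio.run(main())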

pkg/openai/entities.py Normal file
View File

@@ -0,0 +1,31 @@
from __future__ import annotations

import typing
import enum

import pydantic


class MessageRole(enum.Enum):

    SYSTEM = 'system'

    USER = 'user'

    ASSISTANT = 'assistant'

    FUNCTION = 'function'


class FunctionCall(pydantic.BaseModel):
    name: str

    args: dict[str, typing.Any]


class Message(pydantic.BaseModel):
    role: MessageRole

    content: typing.Optional[str] = None

    function_call: typing.Optional[FunctionCall] = None
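
For orientation, a quick sketch of how these models are built and serialized (hypothetical values; assumes pydantic v1's .dict()):

msg = Message(
    role=MessageRole.ASSISTANT,
    function_call=FunctionCall(name="get_weather", args={"city": "Beijing"}),
)
print(msg.dict())
# {'role': <MessageRole.ASSISTANT: 'assistant'>, 'content': None,
#  'function_call': {'name': 'get_weather', 'args': {'city': 'Beijing'}}}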

View File

View File

@@ -0,0 +1,31 @@
from __future__ import annotations

import abc
import typing

from ...core import app
from ...core import entities as core_entities
from .. import entities as llm_entities
from ..session import entities as session_entities


class LLMAPIRequester(metaclass=abc.ABCMeta):
    """LLM API requester"""

    ap: app.Application

    def __init__(self, ap: app.Application):
        self.ap = ap

    async def initialize(self):
        pass

    @abc.abstractmethod
    async def request(
        self,
        query: core_entities.Query,
        conversation: session_entities.Conversation,
    ) -> typing.AsyncGenerator[llm_entities.Message, None]:
        """Make a request"""
        raise NotImplementedError
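
A sketch of how a caller consumes this interface, assuming requester, query, and conversation are already constructed:

async def collect(requester: LLMAPIRequester, query, conversation) -> list:
    replies = []
    async for message in requester.request(query, conversation):
        replies.append(message)  # each yielded Message is one model response
    return replies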

View File

View File

@@ -0,0 +1,32 @@
from __future__ import annotations

import asyncio
import typing

import openai

from .. import api
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...session import entities as session_entities


class OpenAIChatCompletion(api.LLMAPIRequester):
    client: openai.Client

    async def initialize(self):
        self.client = openai.Client(
            base_url=self.ap.cfg_mgr.data['openai_config']['reverse_proxy'],
            timeout=self.ap.cfg_mgr.data['process_message_timeout']
        )

    async def request(self, query: core_entities.Query, conversation: session_entities.Conversation) -> typing.AsyncGenerator[llm_entities.Message, None]:
        """Make a request"""
        # placeholder response; the real chat completion call is not implemented in this commit
        await asyncio.sleep(10)

        yield llm_entities.Message(
            role=llm_entities.MessageRole.ASSISTANT,
            content="hello"
        )

View File

@@ -0,0 +1,23 @@
import typing

import pydantic

from . import api
from . import token


class LLMModelInfo(pydantic.BaseModel):
    """Model info"""

    name: str

    provider: str

    token_mgr: token.TokenManager

    requester: api.LLMAPIRequester

    function_call_supported: typing.Optional[bool] = False

    class Config:
        arbitrary_types_allowed = True

View File

@@ -0,0 +1,40 @@
from __future__ import annotations

from . import entities
from ...core import app
from .apis import chatcmpl
from . import token


class ModelManager:
    ap: app.Application

    model_list: list[entities.LLMModelInfo]

    def __init__(self, ap: app.Application):
        self.ap = ap
        self.model_list = []

    async def initialize(self):
        openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap)

        # TokenManager takes the provider name and the list of API keys
        openai_token_mgr = token.TokenManager("openai", list(self.ap.cfg_mgr.data['openai_config']['api_key'].values()))

        self.model_list.append(
            entities.LLMModelInfo(
                name="gpt-3.5-turbo",
                provider="openai",
                token_mgr=openai_token_mgr,
                requester=openai_chat_completion,
                function_call_supported=True
            )
        )

    async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
        """Get a model by name"""
        for model in self.model_list:
            if model.name == name:
                return model
        raise ValueError(f"Model {name} not found")
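
Usage sketch (hypothetical call site, after initialize() has registered the models):

async def pick_key(model_mgr: ModelManager) -> str:
    # an unknown name raises ValueError("Model ... not found")
    model = await model_mgr.get_model_by_name("gpt-3.5-turbo")
    return model.token_mgr.get_token()  # the currently selected API key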

View File

@@ -0,0 +1,25 @@
from __future__ import annotations

import typing

import pydantic


class TokenManager():

    provider: str

    tokens: list[str]

    using_token_index: typing.Optional[int] = 0

    def __init__(self, provider: str, tokens: list[str]):
        self.provider = provider
        self.tokens = tokens
        self.using_token_index = 0

    def get_token(self) -> str:
        return self.tokens[self.using_token_index]

    def next_token(self):
        self.using_token_index = (self.using_token_index + 1) % len(self.tokens)
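
The rotation is easiest to see directly (hypothetical keys):

mgr = TokenManager("openai", ["sk-key-1", "sk-key-2"])
assert mgr.get_token() == "sk-key-1"
mgr.next_token()                       # advance, e.g. after a rate-limit error
assert mgr.get_token() == "sk-key-2"
mgr.next_token()                       # wraps around to the first key
assert mgr.get_token() == "sk-key-1"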

View File

View File

@@ -0,0 +1,50 @@
from __future__ import annotations

import datetime
import asyncio
import typing

import pydantic

from ..sysprompt import entities as sysprompt_entities
from .. import entities as llm_entities
from ..requester import entities
from ...core import entities as core_entities


class Conversation(pydantic.BaseModel):
    """Conversation"""

    prompt: sysprompt_entities.Prompt

    messages: list[llm_entities.Message]

    create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)

    update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)

    use_model: entities.LLMModelInfo


class Session(pydantic.BaseModel):
    """Session"""

    launcher_type: core_entities.LauncherTypes

    launcher_id: int

    sender_id: typing.Optional[int] = 0

    use_prompt_name: typing.Optional[str] = 'default'

    using_conversation: typing.Optional[Conversation] = None

    conversations: typing.Optional[list[Conversation]] = []

    create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)

    update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)

    semaphore: typing.Optional[asyncio.Semaphore] = None

    class Config:
        arbitrary_types_allowed = True

View File

@@ -0,0 +1,50 @@
from __future__ import annotations

import asyncio

from ...core import app, entities as core_entities
from . import entities


class SessionManager:
    ap: app.Application

    session_list: list[entities.Session]

    def __init__(self, ap: app.Application):
        self.ap = ap
        self.session_list = []

    async def initialize(self):
        pass

    async def get_session(self, query: core_entities.Query) -> entities.Session:
        """Get the session for a query, creating it on first use"""
        for session in self.session_list:
            if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id:
                return session

        session = entities.Session(
            launcher_type=query.launcher_type,
            launcher_id=query.launcher_id,
            semaphore=asyncio.Semaphore(1) if self.ap.cfg_mgr.data['wait_last_done'] else asyncio.Semaphore(10000)
        )
        self.session_list.append(session)
        return session

    async def get_conversation(self, session: entities.Session) -> entities.Conversation:
        if not session.conversations:
            session.conversations = []

        if session.using_conversation is None:
            conversation = entities.Conversation(
                prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
                messages=[],
                use_model=await self.ap.model_mgr.get_model_by_name(self.ap.cfg_mgr.data['completion_api_params']['model']),
            )
            session.conversations.append(conversation)
            session.using_conversation = conversation

        return session.using_conversation
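
Sketch of the lookup flow for one incoming query (names assumed from the classes above; error handling omitted):

async def answer(ap, query):
    session = await ap.sess_mgr.get_session(query)              # one per launcher
    conversation = await ap.sess_mgr.get_conversation(session)  # created lazily
    async for message in conversation.use_model.requester.request(query, conversation):
        ...  # hand each Message to the rest of the pipeline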

View File

View File

@@ -0,0 +1,14 @@
from __future__ import annotations

import typing

import pydantic

from ...openai import entities


class Prompt(pydantic.BaseModel):
    """A prompt for the AI to use"""

    name: str

    messages: list[entities.Message]

View File

@@ -0,0 +1,32 @@
from __future__ import annotations

import abc

from ...core import app
from . import entities


class PromptLoader(metaclass=abc.ABCMeta):
    """Abstract prompt loader"""

    ap: app.Application

    prompts: list[entities.Prompt]

    def __init__(self, ap: app.Application):
        self.ap = ap
        self.prompts = []

    async def initialize(self):
        pass

    @abc.abstractmethod
    async def load(self):
        """Load prompts"""
        raise NotImplementedError

    def get_prompts(self) -> list[entities.Prompt]:
        """Get the loaded prompt list"""
        return self.prompts

View File

View File

@@ -0,0 +1,43 @@
from __future__ import annotations

import json
import os

from .. import loader
from .. import entities
from ....openai import entities as llm_entities


class ScenarioPromptLoader(loader.PromptLoader):
    """Load the JSON files in the scenarios directory"""

    async def load(self):
        """Load prompts"""
        for file in os.listdir("scenarios"):
            with open("scenarios/{}".format(file), "r", encoding="utf-8") as f:
                file_str = f.read()
                file_name = file.split(".")[0]
                file_json = json.loads(file_str)

                messages = []

                for msg in file_json["prompt"]:
                    role = llm_entities.MessageRole.SYSTEM
                    if "role" in msg:
                        if msg["role"] == "user":
                            role = llm_entities.MessageRole.USER
                        elif msg["role"] == "system":
                            role = llm_entities.MessageRole.SYSTEM
                        elif msg["role"] == "function":
                            role = llm_entities.MessageRole.FUNCTION
                    messages.append(
                        llm_entities.Message(
                            role=role,
                            content=msg['content'],
                        )
                    )

                prompt = entities.Prompt(
                    name=file_name,
                    messages=messages
                )

                self.prompts.append(prompt)
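
A hypothetical scenarios/translator.json this loader would accept (not part of this commit); "role" defaults to system when omitted:

{
    "prompt": [
        {"role": "system", "content": "You are a translation assistant."},
        {"content": "All replies must be in English."}
    ]
}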

View File

@@ -0,0 +1,42 @@
from __future__ import annotations

import os

from .. import loader
from .. import entities
from ....openai import entities as llm_entities


class SingleSystemPromptLoader(loader.PromptLoader):
    """Prompt loader for single system prompts taken from the config file"""

    async def load(self):
        """Load prompts"""
        for name, cnt in self.ap.cfg_mgr.data['default_prompt'].items():
            prompt = entities.Prompt(
                name=name,
                messages=[
                    llm_entities.Message(
                        role=llm_entities.MessageRole.SYSTEM,
                        content=cnt
                    )
                ]
            )
            self.prompts.append(prompt)

        for file in os.listdir("prompts"):
            with open("prompts/{}".format(file), "r", encoding="utf-8") as f:
                file_str = f.read()
                file_name = file.split(".")[0]

                prompt = entities.Prompt(
                    name=file_name,
                    messages=[
                        llm_entities.Message(
                            role=llm_entities.MessageRole.SYSTEM,
                            content=file_str
                        )
                    ]
                )

                self.prompts.append(prompt)

View File

@@ -0,0 +1,43 @@
from __future__ import annotations

from ...core import app
from . import loader
from .loaders import single, scenario


class PromptManager:
    ap: app.Application

    loader_inst: loader.PromptLoader

    default_prompt: str = 'default'

    def __init__(self, ap: app.Application):
        self.ap = ap

    async def initialize(self):
        loader_map = {
            "normal": single.SingleSystemPromptLoader,
            "full_scenario": scenario.ScenarioPromptLoader
        }

        loader_cls = loader_map[self.ap.cfg_mgr.data['preset_mode']]

        self.loader_inst: loader.PromptLoader = loader_cls(self.ap)
        await self.loader_inst.initialize()
        await self.loader_inst.load()

    def get_all_prompts(self) -> list[loader.entities.Prompt]:
        """Get all prompts"""
        return self.loader_inst.get_prompts()

    async def get_prompt(self, name: str) -> loader.entities.Prompt:
        """Get a prompt by name"""
        for prompt in self.get_all_prompts():
            if prompt.name == name:
                return prompt
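
Sketch of how the preset_mode config value selects a loader (assumes an initialized Application as ap):

async def setup_prompts(ap):
    # preset_mode "normal"        -> SingleSystemPromptLoader (config + prompts/ dir)
    # preset_mode "full_scenario" -> ScenarioPromptLoader (scenarios/ dir)
    mgr = PromptManager(ap)
    await mgr.initialize()  # instantiates and runs the configured loader
    return await mgr.get_prompt('default')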

View File

View File

@@ -0,0 +1,25 @@
from __future__ import annotations

import abc

from ...core import app
from ...core import entities as core_entities
from .. import entities


class MessageHandler(metaclass=abc.ABCMeta):

    ap: app.Application

    def __init__(self, ap: app.Application):
        self.ap = ap

    async def initialize(self):
        pass

    @abc.abstractmethod
    async def handle(
        self,
        query: core_entities.Query,
    ) -> entities.StageProcessResult:
        raise NotImplementedError

View File

@@ -0,0 +1,38 @@
from __future__ import annotations

import typing

import mirai

from .. import handler
from ... import entities
from ....core import entities as core_entities


class ChatMessageHandler(handler.MessageHandler):

    async def handle(
        self,
        query: core_entities.Query,
    ) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
        """Handle a chat message"""
        # get the session
        # get the conversation
        # call the API
        # as a generator
        session = await self.ap.sess_mgr.get_session(query)

        conversation = await self.ap.sess_mgr.get_conversation(session)

        async for result in conversation.use_model.requester.request(query, conversation):
            query.resp_message_chain = mirai.MessageChain([mirai.Plain(str(result))])

            yield entities.StageProcessResult(
                result_type=entities.ResultType.CONTINUE,
                new_query=query
            )

View File

@@ -0,0 +1,35 @@
from __future__ import annotations

import typing

import mirai

from .. import handler
from ... import entities
from ....core import entities as core_entities


class CommandHandler(handler.MessageHandler):

    async def handle(
        self,
        query: core_entities.Query,
    ) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
        """Handle a command"""
        query.resp_message_chain = mirai.MessageChain([
            mirai.Plain('CommandHandler')
        ])

        yield entities.StageProcessResult(
            result_type=entities.ResultType.CONTINUE,
            new_query=query
        )

        query.resp_message_chain = mirai.MessageChain([
            mirai.Plain('The Second Message')
        ])

        yield entities.StageProcessResult(
            result_type=entities.ResultType.CONTINUE,
            new_query=query
        )

View File

@@ -0,0 +1,38 @@
from __future__ import annotations

from ...core import app, entities as core_entities
from . import handler
from .handlers import chat, command
from .. import stage, entities, stagemgr
from ...config import manager as cfg_mgr


@stage.stage_class("MessageProcessor")
class Processor(stage.PipelineStage):

    cmd_handler: handler.MessageHandler

    chat_handler: handler.MessageHandler

    async def initialize(self):
        self.cmd_handler = command.CommandHandler(self.ap)
        self.chat_handler = chat.ChatMessageHandler(self.ap)

        await self.cmd_handler.initialize()
        await self.chat_handler.initialize()

    async def process(
        self,
        query: core_entities.Query,
        stage_inst_name: str,
    ) -> entities.StageProcessResult:
        """Dispatch to the command or chat handler"""
        message_text = str(query.message_chain).strip()

        if message_text.startswith('!') or message_text.startswith('！'):  # half- or full-width exclamation mark
            return self.cmd_handler.handle(query)
        else:
            return self.chat_handler.handle(query)

View File

View File

@@ -0,0 +1,29 @@
from __future__ import annotations

import mirai

from ...core import app
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr


@stage.stage_class("SendResponseBackStage")
class SendResponseBackStage(stage.PipelineStage):
    """Send the response message back"""

    async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
        """Send resp_message_chain back via the IM manager"""
        await self.ap.im_mgr.send(
            query.message_event,
            query.resp_message_chain
        )

        return entities.StageProcessResult(
            result_type=entities.ResultType.CONTINUE,
            new_query=query
        )

View File

@@ -1,6 +1,7 @@
from __future__ import annotations

import abc
import typing

from ..core import app, entities as core_entities
from . import entities
@@ -37,7 +38,10 @@ class PipelineStage(metaclass=abc.ABCMeta):
        self,
        query: core_entities.Query,
        stage_inst_name: str,
    ) -> entities.StageProcessResult:
    ) -> typing.Union[
        entities.StageProcessResult,
        typing.AsyncGenerator[entities.StageProcessResult, None],
    ]:
        """Process a query; may return a single result or an async generator of results"""
        raise NotImplementedError
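
A minimal sketch of the two return shapes now allowed (hypothetical stage, not part of this commit):

@stage.stage_class("EchoStage")
class EchoStage(stage.PipelineStage):

    async def process(self, query, stage_inst_name):
        if query.resp_message_chain is None:  # simple case: a single result
            return entities.StageProcessResult(
                result_type=entities.ResultType.CONTINUE,
                new_query=query,
            )
        return self._stream(query)  # generator case: several results

    async def _stream(self, query):
        for _ in range(2):  # downstream stages re-run once per yield
            yield entities.StageProcessResult(
                result_type=entities.ResultType.CONTINUE,
                new_query=query,
            )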

View File

@@ -7,7 +7,20 @@ from . import stage
from .resprule import resprule
from .bansess import bansess
from .cntfilter import cntfilter
from .process import process
from .longtext import longtext
from .respback import respback


stage_order = [
    "GroupRespondRuleCheckStage",
    "BanSessionCheckStage",
    "PreContentFilterStage",
    "MessageProcessor",
    "PostContentFilterStage",
    "LongTextProcessStage",
    "SendResponseBackStage",
]


class StageInstContainer():
@@ -45,3 +58,6 @@ class StageManager:
        for stage_containers in self.stage_containers:
            await stage_containers.inst.initialize()

        # sort by stage_order
        self.stage_containers.sort(key=lambda x: stage_order.index(x.inst_name))

View File

@@ -18,10 +18,6 @@ from ..plugin import host as plugin_host
from ..plugin import models as plugin_models
import tips as tips_custom
from ..qqbot import adapter as msadapter
from .resprule import resprule
from .bansess import bansess
from .cntfilter import cntfilter
from .longtext import longtext
from .ratelim import ratelim
from ..core import app, entities as core_entities
@@ -41,30 +37,18 @@ class QQBotManager:
    # modern
    ap: app.Application = None

    bansess_mgr: bansess.SessionBanManager = None
    cntfilter_mgr: cntfilter.ContentFilterManager = None
    longtext_pcs: longtext.LongTextProcessor = None
    resprule_chkr: resprule.GroupRespondRuleChecker = None
    ratelimiter: ratelim.RateLimiter = None

    def __init__(self, first_time_init=True, ap: app.Application = None):
        config = context.get_config_manager().data

        self.ap = ap
        self.bansess_mgr = bansess.SessionBanManager(ap)
        self.cntfilter_mgr = cntfilter.ContentFilterManager(ap)
        self.longtext_pcs = longtext.LongTextProcessor(ap)
        self.resprule_chkr = resprule.GroupRespondRuleChecker(ap)
        self.ratelimiter = ratelim.RateLimiter(ap)

        self.timeout = config['process_message_timeout']
        self.retry = config['retry_times']

    async def initialize(self):
        await self.bansess_mgr.initialize()
        await self.cntfilter_mgr.initialize()
        await self.longtext_pcs.initialize()
        await self.resprule_chkr.initialize()
        await self.ratelimiter.initialize()

        config = context.get_config_manager().data

View File

@@ -15,7 +15,7 @@ from ..plugin import host as plugin_host
from ..plugin import models as plugin_models
import tips as tips_custom
from ..core import app
from .cntfilter import entities
# from .cntfilter import entities

processing = []