mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 03:32:33 +08:00
refactor: 重构openai包基础组件架构
This commit is contained in:
parent
411034902a
commit
850a4eeb7c
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from . import model as file_model
|
||||
from ..utils import context
|
||||
from .impls import pymodule, json as json_file
|
||||
|
|
|
@ -5,6 +5,9 @@ import asyncio
|
|||
|
||||
from ..qqbot import manager as qqbot_mgr
|
||||
from ..openai import manager as openai_mgr
|
||||
from ..openai.session import sessionmgr as llm_session_mgr
|
||||
from ..openai.requester import modelmgr as llm_model_mgr
|
||||
from ..openai.sysprompt import sysprompt as llm_prompt_mgr
|
||||
from ..config import manager as config_mgr
|
||||
from ..database import manager as database_mgr
|
||||
from ..utils.center import v2 as center_mgr
|
||||
|
@ -18,6 +21,12 @@ class Application:
|
|||
|
||||
llm_mgr: openai_mgr.OpenAIInteract = None
|
||||
|
||||
sess_mgr: llm_session_mgr.SessionManager = None
|
||||
|
||||
model_mgr: llm_model_mgr.ModelManager = None
|
||||
|
||||
prompt_mgr: llm_prompt_mgr.PromptManager = None
|
||||
|
||||
cfg_mgr: config_mgr.ConfigManager = None
|
||||
|
||||
tips_mgr: config_mgr.ConfigManager = None
|
||||
|
|
|
@ -15,7 +15,9 @@ from ..pipeline import stagemgr
|
|||
from ..audit import identifier
|
||||
from ..database import manager as db_mgr
|
||||
from ..openai import manager as llm_mgr
|
||||
from ..openai import session as llm_session
|
||||
from ..openai.session import sessionmgr as llm_session_mgr
|
||||
from ..openai.requester import modelmgr as llm_model_mgr
|
||||
from ..openai.sysprompt import sysprompt as llm_prompt_mgr
|
||||
from ..openai import dprompt as llm_dprompt
|
||||
from ..qqbot import manager as im_mgr
|
||||
from ..qqbot.cmds import aamgr as im_cmd_aamgr
|
||||
|
@ -112,8 +114,18 @@ async def make_app() -> app.Application:
|
|||
|
||||
llm_mgr_inst = llm_mgr.OpenAIInteract(ap)
|
||||
ap.llm_mgr = llm_mgr_inst
|
||||
# TODO make it async
|
||||
llm_session.load_sessions()
|
||||
|
||||
llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
|
||||
await llm_model_mgr_inst.initialize()
|
||||
ap.model_mgr = llm_model_mgr_inst
|
||||
|
||||
llm_session_mgr_inst = llm_session_mgr.SessionManager(ap)
|
||||
await llm_session_mgr_inst.initialize()
|
||||
ap.sess_mgr = llm_session_mgr_inst
|
||||
|
||||
llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap)
|
||||
await llm_prompt_mgr_inst.initialize()
|
||||
ap.prompt_mgr = llm_prompt_mgr_inst
|
||||
|
||||
im_mgr_inst = im_mgr.QQBotManager(first_time_init=True, ap=ap)
|
||||
await im_mgr_inst.initialize()
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from ...config import manager as config_mgr
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
from . import app, entities
|
||||
|
@ -24,25 +25,115 @@ class Controller:
|
|||
async def consumer(self):
|
||||
"""事件处理循环
|
||||
"""
|
||||
while True:
|
||||
selected_query: entities.Query = None
|
||||
try:
|
||||
while True:
|
||||
selected_query: entities.Query = None
|
||||
|
||||
# 取请求
|
||||
async with self.ap.query_pool:
|
||||
queries: list[entities.Query] = self.ap.query_pool.queries
|
||||
# 取请求
|
||||
async with self.ap.query_pool:
|
||||
queries: list[entities.Query] = self.ap.query_pool.queries
|
||||
|
||||
if queries:
|
||||
selected_query = queries.pop(0) # FCFS
|
||||
else:
|
||||
await self.ap.query_pool.condition.wait()
|
||||
continue
|
||||
for query in queries:
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
self.ap.logger.debug(f"Checking query {query} session {session}")
|
||||
|
||||
if selected_query:
|
||||
async def _process_query(selected_query):
|
||||
async with self.semaphore:
|
||||
await self.process_query(selected_query)
|
||||
|
||||
asyncio.create_task(_process_query(selected_query))
|
||||
if not session.semaphore.locked():
|
||||
selected_query = query
|
||||
await session.semaphore.acquire()
|
||||
|
||||
break
|
||||
|
||||
if selected_query: # 找到了
|
||||
queries.remove(selected_query)
|
||||
else: # 没找到 说明:没有请求 或者 所有query对应的session都已达到并发上限
|
||||
await self.ap.query_pool.condition.wait()
|
||||
continue
|
||||
|
||||
if selected_query:
|
||||
async def _process_query(selected_query):
|
||||
async with self.semaphore: # 总并发上限
|
||||
await self.process_query(selected_query)
|
||||
|
||||
async with self.ap.query_pool:
|
||||
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
|
||||
# 通知其他协程,有新的请求可以处理了
|
||||
self.ap.query_pool.condition.notify_all()
|
||||
|
||||
asyncio.create_task(_process_query(selected_query))
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"事件处理循环出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def _check_output(self, result: pipeline_entities.StageProcessResult):
|
||||
"""检查输出
|
||||
"""
|
||||
if result.user_notice:
|
||||
await self.ap.im_mgr.send(
|
||||
result.user_notice
|
||||
)
|
||||
if result.debug_notice:
|
||||
self.ap.logger.debug(result.debug_notice)
|
||||
if result.console_notice:
|
||||
self.ap.logger.info(result.console_notice)
|
||||
|
||||
async def _execute_from_stage(
|
||||
self,
|
||||
stage_index: int,
|
||||
query: entities.Query,
|
||||
):
|
||||
"""从指定阶段开始执行
|
||||
|
||||
如何看懂这里为什么这么写?
|
||||
去问 GPT-4:
|
||||
Q1: 现在有一个责任链,其中有多个stage,query对象在其中传递,stage.process可能返回Result也有可能返回typing.AsyncGenerator[Result, None],
|
||||
如果返回的是生成器,需要挨个生成result,检查是否result中是否要求继续,如果要求继续就进行下一个stage。如果此次生成器产生的result处理完了,就继续生成下一个result,
|
||||
调用后续的stage,直到该生成器全部生成完。责任链中可能有多个stage会返回生成器
|
||||
Q2: 不是这样的,你可能理解有误。如果我们责任链上有这些Stage:
|
||||
|
||||
A B C D E F G
|
||||
|
||||
如果所有的stage都返回Result,且所有Result都要求继续,那么执行顺序是:
|
||||
|
||||
A B C D E F G
|
||||
|
||||
现在假设C返回的是AsyncGenerator,那么执行顺序是:
|
||||
|
||||
A B C D E F G C D E F G C D E F G ...
|
||||
Q3: 但是如果不止一个stage会返回生成器呢?
|
||||
"""
|
||||
i = stage_index
|
||||
|
||||
while i < len(self.ap.stage_mgr.stage_containers):
|
||||
stage_container = self.ap.stage_mgr.stage_containers[i]
|
||||
|
||||
result = await stage_container.inst.process(query, stage_container.inst_name)
|
||||
|
||||
|
||||
if isinstance(result, pipeline_entities.StageProcessResult): # 直接返回结果
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {result}")
|
||||
await self._check_output(result)
|
||||
|
||||
if result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
break
|
||||
elif result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = result.new_query
|
||||
elif isinstance(result, typing.AsyncGenerator): # 生成器
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} gen")
|
||||
|
||||
async for sub_result in result:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} processed query {query} res {sub_result}")
|
||||
await self._check_output(sub_result)
|
||||
|
||||
if sub_result.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
break
|
||||
elif sub_result.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = sub_result.new_query
|
||||
await self._execute_from_stage(i + 1, query)
|
||||
break
|
||||
|
||||
i += 1
|
||||
|
||||
async def process_query(self, query: entities.Query):
|
||||
"""处理请求
|
||||
|
@ -50,28 +141,7 @@ class Controller:
|
|||
self.ap.logger.debug(f"Processing query {query}")
|
||||
|
||||
try:
|
||||
for stage_container in self.ap.stage_mgr.stage_containers:
|
||||
res = await stage_container.inst.process(query, stage_container.inst_name)
|
||||
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} res {res}")
|
||||
|
||||
if res.user_notice:
|
||||
await self.ap.im_mgr.send(
|
||||
query.message_event,
|
||||
res.user_notice
|
||||
)
|
||||
if res.debug_notice:
|
||||
self.ap.logger.debug(res.debug_notice)
|
||||
if res.console_notice:
|
||||
self.ap.logger.info(res.console_notice)
|
||||
|
||||
if res.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
break
|
||||
elif res.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = res.new_query
|
||||
continue
|
||||
|
||||
await self._execute_from_stage(0, query)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"处理请求时出错 {query}: {e}")
|
||||
traceback.print_exc()
|
||||
|
|
31
pkg/openai/entities.py
Normal file
31
pkg/openai/entities.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import enum
|
||||
import pydantic
|
||||
|
||||
|
||||
class MessageRole(enum.Enum):
|
||||
|
||||
SYSTEM = 'system'
|
||||
|
||||
USER = 'user'
|
||||
|
||||
ASSISTANT = 'assistant'
|
||||
|
||||
FUNCTION = 'function'
|
||||
|
||||
|
||||
class FunctionCall(pydantic.BaseModel):
|
||||
name: str
|
||||
|
||||
args: dict[str, typing.Any]
|
||||
|
||||
|
||||
class Message(pydantic.BaseModel):
|
||||
|
||||
role: MessageRole
|
||||
|
||||
content: typing.Optional[str] = None
|
||||
|
||||
function_call: typing.Optional[FunctionCall] = None
|
0
pkg/openai/requester/__init__.py
Normal file
0
pkg/openai/requester/__init__.py
Normal file
31
pkg/openai/requester/api.py
Normal file
31
pkg/openai/requester/api.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from ...core import entities as core_entities
|
||||
from .. import entities as llm_entities
|
||||
from ..session import entities as session_entities
|
||||
|
||||
class LLMAPIRequester(metaclass=abc.ABCMeta):
|
||||
"""LLM API请求器
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def request(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
conversation: session_entities.Conversation,
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""请求
|
||||
"""
|
||||
raise NotImplementedError
|
0
pkg/openai/requester/apis/__init__.py
Normal file
0
pkg/openai/requester/apis/__init__.py
Normal file
32
pkg/openai/requester/apis/chatcmpl.py
Normal file
32
pkg/openai/requester/apis/chatcmpl.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
import openai
|
||||
|
||||
from .. import api
|
||||
from ....core import entities as core_entities
|
||||
from ... import entities as llm_entities
|
||||
from ...session import entities as session_entities
|
||||
|
||||
|
||||
class OpenAIChatCompletion(api.LLMAPIRequester):
|
||||
|
||||
client: openai.Client
|
||||
|
||||
async def initialize(self):
|
||||
self.client = openai.Client(
|
||||
base_url=self.ap.cfg_mgr.data['openai_config']['reverse_proxy'],
|
||||
timeout=self.ap.cfg_mgr.data['process_message_timeout']
|
||||
)
|
||||
|
||||
async def request(self, query: core_entities.Query, conversation: session_entities.Conversation) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""请求
|
||||
"""
|
||||
await asyncio.sleep(10)
|
||||
|
||||
yield llm_entities.Message(
|
||||
role=llm_entities.MessageRole.ASSISTANT,
|
||||
content="hello"
|
||||
)
|
23
pkg/openai/requester/entities.py
Normal file
23
pkg/openai/requester/entities.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
import typing
|
||||
|
||||
import pydantic
|
||||
|
||||
from . import api
|
||||
from . import token
|
||||
|
||||
|
||||
class LLMModelInfo(pydantic.BaseModel):
|
||||
"""模型"""
|
||||
|
||||
name: str
|
||||
|
||||
provider: str
|
||||
|
||||
token_mgr: token.TokenManager
|
||||
|
||||
requester: api.LLMAPIRequester
|
||||
|
||||
function_call_supported: typing.Optional[bool] = False
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
40
pkg/openai/requester/modelmgr.py
Normal file
40
pkg/openai/requester/modelmgr.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from . import entities
|
||||
from ...core import app
|
||||
|
||||
from .apis import chatcmpl
|
||||
from . import token
|
||||
|
||||
|
||||
class ModelManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
model_list: list[entities.LLMModelInfo]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.model_list = []
|
||||
|
||||
async def initialize(self):
|
||||
openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap)
|
||||
openai_token_mgr = token.TokenManager(self.ap, self.ap.cfg_mgr.data['openai_config']['api_key'].values())
|
||||
|
||||
self.model_list.append(
|
||||
entities.LLMModelInfo(
|
||||
name="gpt-3.5-turbo",
|
||||
provider="openai",
|
||||
token_mgr=openai_token_mgr,
|
||||
requester=openai_chat_completion,
|
||||
function_call_supported=True
|
||||
)
|
||||
)
|
||||
|
||||
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
|
||||
"""通过名称获取模型
|
||||
"""
|
||||
for model in self.model_list:
|
||||
if model.name == name:
|
||||
return model
|
||||
raise ValueError(f"Model {name} not found")
|
25
pkg/openai/requester/token.py
Normal file
25
pkg/openai/requester/token.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class TokenManager():
|
||||
|
||||
provider: str
|
||||
|
||||
tokens: list[str]
|
||||
|
||||
using_token_index: typing.Optional[int] = 0
|
||||
|
||||
def __init__(self, provider: str, tokens: list[str]):
|
||||
self.provider = provider
|
||||
self.tokens = tokens
|
||||
self.using_token_index = 0
|
||||
|
||||
def get_token(self) -> str:
|
||||
return self.tokens[self.using_token_index]
|
||||
|
||||
def next_token(self):
|
||||
self.using_token_index = (self.using_token_index + 1) % len(self.tokens)
|
0
pkg/openai/session/__init__.py
Normal file
0
pkg/openai/session/__init__.py
Normal file
50
pkg/openai/session/entities.py
Normal file
50
pkg/openai/session/entities.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
|
||||
from ..sysprompt import entities as sysprompt_entities
|
||||
from .. import entities as llm_entities
|
||||
from ..requester import entities
|
||||
from ...core import entities as core_entities
|
||||
|
||||
|
||||
class Conversation(pydantic.BaseModel):
|
||||
"""对话"""
|
||||
|
||||
prompt: sysprompt_entities.Prompt
|
||||
|
||||
messages: list[llm_entities.Message]
|
||||
|
||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
use_model: entities.LLMModelInfo
|
||||
|
||||
|
||||
class Session(pydantic.BaseModel):
|
||||
"""会话"""
|
||||
launcher_type: core_entities.LauncherTypes
|
||||
|
||||
launcher_id: int
|
||||
|
||||
sender_id: typing.Optional[int] = 0
|
||||
|
||||
use_prompt_name: typing.Optional[str] = 'default'
|
||||
|
||||
using_conversation: typing.Optional[Conversation] = None
|
||||
|
||||
conversations: typing.Optional[list[Conversation]] = []
|
||||
|
||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
semaphore: typing.Optional[asyncio.Semaphore] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
50
pkg/openai/session/sessionmgr.py
Normal file
50
pkg/openai/session/sessionmgr.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from . import entities
|
||||
|
||||
|
||||
class SessionManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
session_list: list[entities.Session]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.session_list = []
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def get_session(self, query: core_entities.Query) -> entities.Session:
|
||||
"""获取会话
|
||||
"""
|
||||
for session in self.session_list:
|
||||
if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id:
|
||||
return session
|
||||
|
||||
session = entities.Session(
|
||||
launcher_type=query.launcher_type,
|
||||
launcher_id=query.launcher_id,
|
||||
semaphore=asyncio.Semaphore(1) if self.ap.cfg_mgr.data['wait_last_done'] else asyncio.Semaphore(10000)
|
||||
)
|
||||
self.session_list.append(session)
|
||||
return session
|
||||
|
||||
async def get_conversation(self, session: entities.Session) -> entities.Conversation:
|
||||
if not session.conversations:
|
||||
session.conversations = []
|
||||
|
||||
if session.using_conversation is None:
|
||||
conversation = entities.Conversation(
|
||||
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
|
||||
messages=[],
|
||||
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.cfg_mgr.data['completion_api_params']['model']),
|
||||
)
|
||||
session.conversations.append(conversation)
|
||||
session.using_conversation = conversation
|
||||
|
||||
return session.using_conversation
|
0
pkg/openai/sysprompt/__init__.py
Normal file
0
pkg/openai/sysprompt/__init__.py
Normal file
14
pkg/openai/sysprompt/entities.py
Normal file
14
pkg/openai/sysprompt/entities.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import pydantic
|
||||
|
||||
from ...openai import entities
|
||||
|
||||
|
||||
class Prompt(pydantic.BaseModel):
|
||||
"""供AI使用的Prompt"""
|
||||
|
||||
name: str
|
||||
|
||||
messages: list[entities.Message]
|
32
pkg/openai/sysprompt/loader.py
Normal file
32
pkg/openai/sysprompt/loader.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
from __future__ import annotations
|
||||
import abc
|
||||
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
|
||||
class PromptLoader(metaclass=abc.ABCMeta):
|
||||
"""Prompt加载器抽象类
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
prompts: list[entities.Prompt]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.prompts = []
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_prompts(self) -> list[entities.Prompt]:
|
||||
"""获取Prompt列表
|
||||
"""
|
||||
return self.prompts
|
0
pkg/openai/sysprompt/loaders/__init__.py
Normal file
0
pkg/openai/sysprompt/loaders/__init__.py
Normal file
43
pkg/openai/sysprompt/loaders/scenario.py
Normal file
43
pkg/openai/sysprompt/loaders/scenario.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from .. import loader
|
||||
from .. import entities
|
||||
from ....openai import entities as llm_entities
|
||||
|
||||
|
||||
class ScenarioPromptLoader(loader.PromptLoader):
|
||||
"""加载scenario目录下的json"""
|
||||
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
for file in os.listdir("scenarios"):
|
||||
with open("scenarios/{}".format(file), "r", encoding="utf-8") as f:
|
||||
file_str = f.read()
|
||||
file_name = file.split(".")[0]
|
||||
file_json = json.loads(file_str)
|
||||
messages = []
|
||||
for msg in file_json["prompt"]:
|
||||
role = llm_entities.MessageRole.SYSTEM
|
||||
if "role" in msg:
|
||||
if msg["role"] == "user":
|
||||
role = llm_entities.MessageRole.USER
|
||||
elif msg["role"] == "system":
|
||||
role = llm_entities.MessageRole.SYSTEM
|
||||
elif msg["role"] == "function":
|
||||
role = llm_entities.MessageRole.FUNCTION
|
||||
messages.append(
|
||||
llm_entities.Message(
|
||||
role=role,
|
||||
content=msg['content'],
|
||||
)
|
||||
)
|
||||
prompt = entities.Prompt(
|
||||
name=file_name,
|
||||
messages=messages
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
|
42
pkg/openai/sysprompt/loaders/single.py
Normal file
42
pkg/openai/sysprompt/loaders/single.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
from __future__ import annotations
|
||||
import os
|
||||
|
||||
from .. import loader
|
||||
from .. import entities
|
||||
from ....openai import entities as llm_entities
|
||||
|
||||
|
||||
class SingleSystemPromptLoader(loader.PromptLoader):
|
||||
"""配置文件中的单条system prompt的prompt加载器
|
||||
"""
|
||||
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""
|
||||
|
||||
for name, cnt in self.ap.cfg_mgr.data['default_prompt'].items():
|
||||
prompt = entities.Prompt(
|
||||
name=name,
|
||||
messages=[
|
||||
llm_entities.Message(
|
||||
role=llm_entities.MessageRole.SYSTEM,
|
||||
content=cnt
|
||||
)
|
||||
]
|
||||
)
|
||||
self.prompts.append(prompt)
|
||||
|
||||
for file in os.listdir("prompts"):
|
||||
with open("prompts/{}".format(file), "r", encoding="utf-8") as f:
|
||||
file_str = f.read()
|
||||
file_name = file.split(".")[0]
|
||||
prompt = entities.Prompt(
|
||||
name=file_name,
|
||||
messages=[
|
||||
llm_entities.Message(
|
||||
role=llm_entities.MessageRole.SYSTEM,
|
||||
content=file_str
|
||||
)
|
||||
]
|
||||
)
|
||||
self.prompts.append(prompt)
|
43
pkg/openai/sysprompt/sysprompt.py
Normal file
43
pkg/openai/sysprompt/sysprompt.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from ...core import app
|
||||
from . import loader
|
||||
from .loaders import single, scenario
|
||||
|
||||
|
||||
class PromptManager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
loader_inst: loader.PromptLoader
|
||||
|
||||
default_prompt: str = 'default'
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
loader_map = {
|
||||
"normal": single.SingleSystemPromptLoader,
|
||||
"full_scenario": scenario.ScenarioPromptLoader
|
||||
}
|
||||
|
||||
loader_cls = loader_map[self.ap.cfg_mgr.data['preset_mode']]
|
||||
|
||||
self.loader_inst: loader.PromptLoader = loader_cls(self.ap)
|
||||
|
||||
await self.loader_inst.initialize()
|
||||
await self.loader_inst.load()
|
||||
|
||||
def get_all_prompts(self) -> list[loader.entities.Prompt]:
|
||||
"""获取所有Prompt
|
||||
"""
|
||||
return self.loader_inst.get_prompts()
|
||||
|
||||
async def get_prompt(self, name: str) -> loader.entities.Prompt:
|
||||
"""获取Prompt
|
||||
"""
|
||||
for prompt in self.get_all_prompts():
|
||||
if prompt.name == name:
|
||||
return prompt
|
0
pkg/pipeline/process/__init__.py
Normal file
0
pkg/pipeline/process/__init__.py
Normal file
25
pkg/pipeline/process/handler.py
Normal file
25
pkg/pipeline/process/handler.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
|
||||
from ...core import app
|
||||
from ...core import entities as core_entities
|
||||
from .. import entities
|
||||
|
||||
|
||||
class MessageHandler(metaclass=abc.ABCMeta):
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def handle(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
) -> entities.StageProcessResult:
|
||||
raise NotImplementedError
|
0
pkg/pipeline/process/handlers/__init__.py
Normal file
0
pkg/pipeline/process/handlers/__init__.py
Normal file
38
pkg/pipeline/process/handlers/chat.py
Normal file
38
pkg/pipeline/process/handlers/chat.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import mirai
|
||||
|
||||
from .. import handler
|
||||
from ... import entities
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
class ChatMessageHandler(handler.MessageHandler):
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||
"""处理
|
||||
"""
|
||||
# 取session
|
||||
# 取conversation
|
||||
# 调API
|
||||
# 生成器
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
conversation = await self.ap.sess_mgr.get_conversation(session)
|
||||
|
||||
async for result in conversation.use_model.requester.request(query, conversation):
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(str(result))])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
35
pkg/pipeline/process/handlers/command.py
Normal file
35
pkg/pipeline/process/handlers/command.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
from __future__ import annotations
|
||||
import typing
|
||||
|
||||
import mirai
|
||||
|
||||
from .. import handler
|
||||
from ... import entities
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
class CommandHandler(handler.MessageHandler):
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||
"""处理
|
||||
"""
|
||||
query.resp_message_chain = mirai.MessageChain([
|
||||
mirai.Plain('CommandHandler')
|
||||
])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([
|
||||
mirai.Plain('The Second Message')
|
||||
])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
38
pkg/pipeline/process/process.py
Normal file
38
pkg/pipeline/process/process.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from . import handler
|
||||
from .handlers import chat, command
|
||||
from .. import entities
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
|
||||
@stage.stage_class("MessageProcessor")
|
||||
class Processor(stage.PipelineStage):
|
||||
|
||||
cmd_handler: handler.MessageHandler
|
||||
|
||||
chat_handler: handler.MessageHandler
|
||||
|
||||
async def initialize(self):
|
||||
self.cmd_handler = command.CommandHandler(self.ap)
|
||||
self.chat_handler = chat.ChatMessageHandler(self.ap)
|
||||
|
||||
await self.cmd_handler.initialize()
|
||||
await self.chat_handler.initialize()
|
||||
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
stage_inst_name: str,
|
||||
) -> entities.StageProcessResult:
|
||||
"""处理
|
||||
"""
|
||||
message_text = str(query.message_chain).strip()
|
||||
|
||||
if message_text.startswith('!') or message_text.startswith('!'):
|
||||
return self.cmd_handler.handle(query)
|
||||
else:
|
||||
return self.chat_handler.handle(query)
|
0
pkg/pipeline/respback/__init__.py
Normal file
0
pkg/pipeline/respback/__init__.py
Normal file
29
pkg/pipeline/respback/respback.py
Normal file
29
pkg/pipeline/respback/respback.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import mirai
|
||||
|
||||
from ...core import app
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
|
||||
@stage.stage_class("SendResponseBackStage")
|
||||
class SendResponseBackStage(stage.PipelineStage):
|
||||
"""发送响应消息
|
||||
"""
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""处理
|
||||
"""
|
||||
|
||||
await self.ap.im_mgr.send(
|
||||
query.message_event,
|
||||
query.resp_message_chain
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from . import entities
|
||||
|
@ -37,7 +38,10 @@ class PipelineStage(metaclass=abc.ABCMeta):
|
|||
self,
|
||||
query: core_entities.Query,
|
||||
stage_inst_name: str,
|
||||
) -> entities.StageProcessResult:
|
||||
) -> typing.Union[
|
||||
entities.StageProcessResult,
|
||||
typing.AsyncGenerator[entities.StageProcessResult, None],
|
||||
]:
|
||||
"""处理
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -7,7 +7,20 @@ from . import stage
|
|||
from .resprule import resprule
|
||||
from .bansess import bansess
|
||||
from .cntfilter import cntfilter
|
||||
from .process import process
|
||||
from .longtext import longtext
|
||||
from .respback import respback
|
||||
|
||||
|
||||
stage_order = [
|
||||
"GroupRespondRuleCheckStage",
|
||||
"BanSessionCheckStage",
|
||||
"PreContentFilterStage",
|
||||
"MessageProcessor",
|
||||
"PostContentFilterStage",
|
||||
"LongTextProcessStage",
|
||||
"SendResponseBackStage",
|
||||
]
|
||||
|
||||
|
||||
class StageInstContainer():
|
||||
|
@ -45,3 +58,6 @@ class StageManager:
|
|||
|
||||
for stage_containers in self.stage_containers:
|
||||
await stage_containers.inst.initialize()
|
||||
|
||||
# 按照 stage_order 排序
|
||||
self.stage_containers.sort(key=lambda x: stage_order.index(x.inst_name))
|
||||
|
|
|
@ -18,10 +18,6 @@ from ..plugin import host as plugin_host
|
|||
from ..plugin import models as plugin_models
|
||||
import tips as tips_custom
|
||||
from ..qqbot import adapter as msadapter
|
||||
from .resprule import resprule
|
||||
from .bansess import bansess
|
||||
from .cntfilter import cntfilter
|
||||
from .longtext import longtext
|
||||
from .ratelim import ratelim
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
|
@ -41,30 +37,18 @@ class QQBotManager:
|
|||
# modern
|
||||
ap: app.Application = None
|
||||
|
||||
bansess_mgr: bansess.SessionBanManager = None
|
||||
cntfilter_mgr: cntfilter.ContentFilterManager = None
|
||||
longtext_pcs: longtext.LongTextProcessor = None
|
||||
resprule_chkr: resprule.GroupRespondRuleChecker = None
|
||||
ratelimiter: ratelim.RateLimiter = None
|
||||
|
||||
def __init__(self, first_time_init=True, ap: app.Application = None):
|
||||
config = context.get_config_manager().data
|
||||
|
||||
self.ap = ap
|
||||
self.bansess_mgr = bansess.SessionBanManager(ap)
|
||||
self.cntfilter_mgr = cntfilter.ContentFilterManager(ap)
|
||||
self.longtext_pcs = longtext.LongTextProcessor(ap)
|
||||
self.resprule_chkr = resprule.GroupRespondRuleChecker(ap)
|
||||
self.ratelimiter = ratelim.RateLimiter(ap)
|
||||
|
||||
self.timeout = config['process_message_timeout']
|
||||
self.retry = config['retry_times']
|
||||
|
||||
async def initialize(self):
|
||||
await self.bansess_mgr.initialize()
|
||||
await self.cntfilter_mgr.initialize()
|
||||
await self.longtext_pcs.initialize()
|
||||
await self.resprule_chkr.initialize()
|
||||
await self.ratelimiter.initialize()
|
||||
|
||||
config = context.get_config_manager().data
|
||||
|
|
|
@ -15,7 +15,7 @@ from ..plugin import host as plugin_host
|
|||
from ..plugin import models as plugin_models
|
||||
import tips as tips_custom
|
||||
from ..core import app
|
||||
from .cntfilter import entities
|
||||
# from .cntfilter import entities
|
||||
|
||||
processing = []
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user