from __future__ import annotations

import aiohttp

from . import entities
from ...core import app
from . import token, api

# Imported so its requester class gets registered into
# api.preregistered_requesters (consumed in ModelManager.initialize below).
from .apis import chatcmpl

# Remote endpoint serving the latest model metadata list.
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"


class ModelManager:
    """Model manager."""

    ap: app.Application

    # All models known to this manager, built from the metadata file.
    model_list: list[entities.LLMModelInfo]

    # Requester name -> initialized requester instance.
    requesters: dict[str, api.LLMAPIRequester]

    # Key-group name (from the provider config) -> its token manager.
    token_mgrs: dict[str, token.TokenManager]

    def __init__(self, ap: app.Application):
        self.ap = ap
        self.model_list = []
        self.requesters = {}
        self.token_mgrs = {}

    async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
        """Look up a model by name."""
        for model in self.model_list:
            if model.name == name:
                return model

        raise ValueError(f"Cannot determine info for model {name}; please configure it in the metadata")

    async def initialize(self):
        # Initialize the token managers and requesters.
        for k, v in self.ap.provider_cfg.data['keys'].items():
            self.token_mgrs[k] = token.TokenManager(k, v)
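
        # A hedged sketch of the provider config shape the loop above
        # assumes (illustrative values only; the authoritative schema is
        # the provider config template):
        #
        #   "keys": {
        #       "openai": ["sk-..."]
        #   }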

        # Instantiate and initialize every pre-registered requester class.
        for api_cls in api.preregistered_requesters:
            api_inst = api_cls(self.ap)
            await api_inst.initialize()
            self.requesters[api_inst.name] = api_inst
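
        # What the loop above expects of each registered class (a hedged
        # sketch; `MyRequester` and its name are hypothetical):
        #
        #   class MyRequester(api.LLMAPIRequester):
        #       name = 'my-requester'
        #       def __init__(self, ap): ...
        #       async def initialize(self): ...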

        # Try to fetch the latest model metadata from the remote API.
        try:
            async with aiohttp.ClientSession() as session:
                async with session.request(
                    method="GET",
                    url=FETCH_MODEL_LIST_URL,
                ) as resp:
                    model_list = (await resp.json())['data']['list']

                    # Merge into the local metadata: replace entries with
                    # the same name, append previously unknown ones.
                    for model in model_list:
                        for index, local_model in enumerate(self.ap.llm_models_meta.data['list']):
                            if model['name'] == local_model['name']:
                                self.ap.llm_models_meta.data['list'][index] = model
                                break
                        else:
                            self.ap.llm_models_meta.data['list'].append(model)

                    await self.ap.llm_models_meta.dump_config()

        except Exception as e:
            self.ap.logger.debug(f'Failed to fetch the latest model list: {e}')
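
        # Assumed shape of a metadata entry (field names are taken from the
        # code below; the values here are only illustrative):
        #
        #   {
        #       "name": "default",
        #       "model_name": "gpt-3.5-turbo",
        #       "token_mgr": "openai",
        #       "requester": "openai-chat-completions",
        #       "tool_call_supported": true
        #   }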

        # Resolve the 'default' entry first; other entries fall back to it
        # for any field they do not set.
        default_model_info: entities.LLMModelInfo | None = None

        for model in self.ap.llm_models_meta.data['list']:
            if model['name'] == 'default':
                default_model_info = entities.LLMModelInfo(
                    name=model['name'],
                    model_name=None,
                    token_mgr=self.token_mgrs[model['token_mgr']],
                    requester=self.requesters[model['requester']],
                    tool_call_supported=model['tool_call_supported']
                )
                break

        # Build an LLMModelInfo for every entry, inheriting unspecified
        # fields from the 'default' entry.
        for model in self.ap.llm_models_meta.data['list']:
            try:
                model_name = model.get('model_name', default_model_info.model_name)
                token_mgr = self.token_mgrs[model['token_mgr']] if 'token_mgr' in model else default_model_info.token_mgr
                requester = self.requesters[model['requester']] if 'requester' in model else default_model_info.requester
                tool_call_supported = model.get('tool_call_supported', default_model_info.tool_call_supported)

                model_info = entities.LLMModelInfo(
                    name=model['name'],
                    model_name=model_name,
                    token_mgr=token_mgr,
                    requester=requester,
                    tool_call_supported=tool_call_supported
                )
                self.model_list.append(model_info)

            except Exception as e:
                self.ap.logger.error(f"Failed to initialize model {model['name']}: {e}; please check the configuration file")
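

# A minimal usage sketch (assumes `application` is an already-initialized
# app.Application; the real wiring happens in the core bootstrap):
#
#   mgr = ModelManager(application)
#   await mgr.initialize()
#   model = await mgr.get_model_by_name('default')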