feat: 添加 model 命令

This commit is contained in:
RockChinQ 2024-07-28 15:46:09 +08:00
parent 70583f5ba0
commit 68ddb3a6e1
3 changed files with 108 additions and 16 deletions

View File

@ -8,7 +8,7 @@ from . import entities, operator, errors
from ..config import manager as cfg_mgr
# 引入所有算子以便注册
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama, model
class CommandManager:

View File

@ -0,0 +1,86 @@
from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="model",
help='显示和切换模型列表',
usage='!model\n!model show <模型名>\n!model set <模型名>',
privilege=2
)
class ModelOperator(operator.CommandOperator):
"""Model命令"""
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
content = '模型列表:\n'
model_list = self.ap.model_mgr.model_list
for model in model_list:
content += f"\n名称: {model.name}\n"
content += f"请求器: {model.requester.name}\n"
content += f"\n当前对话使用模型: {context.query.use_model.name}\n"
content += f"新对话默认使用模型: {self.ap.provider_cfg.data.get('model')}\n"
yield entities.CommandReturn(text=content.strip())
@operator.operator_class(
name="show",
help='显示模型详情',
privilege=2,
parent_class=ModelOperator
)
class ModelShowOperator(operator.CommandOperator):
"""Model Show命令"""
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
model_name = context.crt_params[0]
model = None
for _model in self.ap.model_mgr.model_list:
if model_name == _model.name:
model = _model
break
if model is None:
yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}"))
else:
content = f"模型详情\n"
content += f"名称: {model.name}\n"
if model.model_name is not None:
content += f"请求模型名称: {model.model_name}\n"
content += f"请求器: {model.requester.name}\n"
content += f"密钥组: {model.token_mgr.provider}\n"
content += f"支持视觉: {model.vision_supported}\n"
content += f"支持工具: {model.tool_call_supported}\n"
yield entities.CommandReturn(text=content.strip())
@operator.operator_class(
name="set",
help='设置默认使用模型',
privilege=2,
parent_class=ModelOperator
)
class ModelSetOperator(operator.CommandOperator):
"""Model Set命令"""
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
model_name = context.crt_params[0]
model = None
for _model in self.ap.model_mgr.model_list:
if model_name == _model.name:
model = _model
break
if model is None:
yield entities.CommandReturn(error=errors.CommandError(f"未找到模型 {model_name}"))
else:
self.ap.provider_cfg.data['model'] = model_name
await self.ap.provider_cfg.dump_config()
yield entities.CommandReturn(text=f"已设置当前使用模型为 {model_name},重置会话以生效")

View File

@ -2,9 +2,10 @@ from __future__ import annotations
import json
import typing
import traceback
import ollama
from .. import operator, entities
from .. import operator, entities, errors
@operator.operator_class(
@ -16,13 +17,16 @@ class OllamaOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
content: str = '模型列表:\n'
model_list: list = ollama.list().get('models', [])
for model in model_list:
content += f"name: {model['name']}\n"
content += f"modified_at: {model['modified_at']}\n"
content += f"size: {bytes_to_mb(model['size'])}MB\n\n"
yield entities.CommandReturn(text=f"{content.strip()}")
try:
content: str = '模型列表:\n'
model_list: list = ollama.list().get('models', [])
for model in model_list:
content += f"名称: {model['name']}\n"
content += f"修改时间: {model['modified_at']}\n"
content += f"大小: {bytes_to_mb(model['size'])}MB\n\n"
yield entities.CommandReturn(text=f"{content.strip()}")
except ollama.ResponseError as e:
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常"))
def bytes_to_mb(num_bytes):
@ -53,11 +57,9 @@ class OllamaShowOperator(operator.CommandOperator):
model_info[key] = ignore_show
content += json.dumps(show, indent=4)
yield entities.CommandReturn(text=content.strip())
except ollama.ResponseError as e:
content = f"{e.error}"
yield entities.CommandReturn(text=content.strip())
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型详情,请确认 Ollama 服务正常"))
@operator.operator_class(
name="pull",
@ -69,9 +71,13 @@ class OllamaPullOperator(operator.CommandOperator):
async def execute(
self, context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
model_list: list = ollama.list().get('models', [])
if context.crt_params[0] in [model['name'] for model in model_list]:
yield entities.CommandReturn(text="模型已存在")
try:
model_list: list = ollama.list().get('models', [])
if context.crt_params[0] in [model['name'] for model in model_list]:
yield entities.CommandReturn(text="模型已存在")
return
except ollama.ResponseError as e:
yield entities.CommandReturn(error=errors.CommandError(f"无法获取模型列表,请确认 Ollama 服务正常"))
return
on_progress: bool = False