Merge pull request #844 from canyuan0801/pr

Feat: Ollama平台集成
This commit is contained in:
Junyan Qin 2024-07-10 00:09:48 +08:00 committed by GitHub
commit 9d91c13b12
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 253 additions and 3 deletions

View File

@ -8,7 +8,7 @@ from . import entities, operator, errors
from ..config import manager as cfg_mgr
# 引入所有算子以便注册
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update, ollama
class CommandManager:

View File

@ -0,0 +1,115 @@
from __future__ import annotations
import json
import typing
import ollama
from .. import operator, entities
@operator.operator_class(
    name="ollama",
    help="ollama平台操作",
    usage="!ollama\n!ollama show <模型名>\n!ollama pull <模型名>\n!ollama del <模型名>"
)
class OllamaOperator(operator.CommandOperator):
    """List the models available on the local Ollama server."""

    async def execute(
        self, context: entities.ExecuteContext
    ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
        # Build one text report covering every model the server knows about.
        # NOTE(review): ollama.list() is a blocking call inside an async
        # method — confirm this is acceptable for the event loop here.
        parts: list = ['模型列表:\n']
        for model in ollama.list().get('models', []):
            parts.append(f"name: {model['name']}\n")
            parts.append(f"modified_at: {model['modified_at']}\n")
            parts.append(f"size: {bytes_to_mb(model['size'])}MB\n\n")
        yield entities.CommandReturn(text=''.join(parts).strip())
def bytes_to_mb(num_bytes):
    """Convert a byte count to a mebibyte string with two decimal places."""
    return f"{num_bytes / (1024 * 1024):.2f}"
@operator.operator_class(
    name="show",
    help="ollama模型详情",
    privilege=2,
    parent_class=OllamaOperator
)
class OllamaShowOperator(operator.CommandOperator):
    """Show detailed information for one local Ollama model."""

    async def execute(
        self, context: entities.ExecuteContext
    ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
        content: str = '模型详情:\n'
        try:
            show: dict = ollama.show(model=context.crt_params[0])
            model_info: dict = show.get('model_info', {})
            ignore_show: str = 'too long to show...'
            # Replace oversized fields with a placeholder, but only when the
            # server actually returned them — assigning unconditionally (as
            # before) fabricated keys that do not exist for this model.
            for key in ['license', 'modelfile']:
                if key in show:
                    show[key] = ignore_show
            for key in ['tokenizer.chat_template.rag', 'tokenizer.chat_template.tool_use']:
                if key in model_info:
                    model_info[key] = ignore_show
            content += json.dumps(show, indent=4)
        except ollama.ResponseError as e:
            # Surface the server-side error text to the user.
            content = f"{e.error}"
        yield entities.CommandReturn(text=content.strip())
@operator.operator_class(
    name="pull",
    help="ollama模型拉取",
    privilege=2,
    parent_class=OllamaOperator
)
class OllamaPullOperator(operator.CommandOperator):
    """Pull a model onto the local Ollama server, reporting progress."""

    async def execute(
        self, context: entities.ExecuteContext
    ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
        model_list: list = ollama.list().get('models', [])
        if context.crt_params[0] in [model['name'] for model in model_list]:
            yield entities.CommandReturn(text="模型已存在")
            return
        on_progress: bool = False   # True while inside a sized download phase
        progress_count: int = 0     # next percentage threshold to report at
        try:
            for resp in ollama.pull(model=context.crt_params[0], stream=True):
                total: typing.Any = resp.get('total')
                if not on_progress:
                    if total is not None:
                        on_progress = True
                    # Status-only chunks (e.g. "pulling manifest").
                    yield entities.CommandReturn(text=resp.get('status'))
                else:
                    if total is None:
                        on_progress = False
                    completed: typing.Any = resp.get('completed')
                    if isinstance(completed, int) and isinstance(total, int):
                        percentage_completed = (completed / total) * 100
                        if percentage_completed > progress_count:
                            # Snap the threshold past the current percentage.
                            # Advancing by a flat +10 (as before) lagged behind
                            # fast downloads and then emitted one message per
                            # chunk until the counter caught up.
                            progress_count = (int(percentage_completed) // 10 + 1) * 10
                            yield entities.CommandReturn(
                                text=f"下载进度: {completed}/{total} ({percentage_completed:.2f}%)")
        except ollama.ResponseError as e:
            yield entities.CommandReturn(text=f"拉取失败: {e.error}")
@operator.operator_class(
    name="del",
    help="ollama模型删除",
    privilege=2,
    parent_class=OllamaOperator
)
class OllamaDelOperator(operator.CommandOperator):
    """Delete a model from the local Ollama server."""

    async def execute(
        self, context: entities.ExecuteContext
    ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
        result: str
        try:
            # On success the server replies with a status string.
            result = ollama.delete(model=context.crt_params[0])['status']
        except ollama.ResponseError as e:
            result = f"{e.error}"
        yield entities.CommandReturn(text=result)

View File

@ -0,0 +1,23 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("ollama-requester-config", 10)
class MsgTruncatorConfigMigration(migration.Migration):
    """Seed default configuration for the ollama-chat requester (migration #10).

    NOTE(review): the class name looks copy-pasted from the m009
    msg-truncator migration; the registered name in the decorator is what
    identifies it, but consider renaming to OllamaRequesterConfigMigration.
    """

    async def need_migrate(self) -> bool:
        """Return True when the ollama-chat requester config is absent."""
        return 'ollama-chat' not in self.ap.provider_cfg.data['requester']

    async def run(self):
        """Write the default ollama-chat requester settings and persist them."""
        self.ap.provider_cfg.data['requester']['ollama-chat'] = {
            "base-url": "http://127.0.0.1:11434",
            "args": {},
            "timeout": 600
        }
        await self.ap.provider_cfg.dump_config()

View File

@ -6,6 +6,7 @@ from .. import stage, app
from ...config import migration
from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
from ...config.migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
from ...config.migrations import m010_ollama_requester_config
@stage.stage_class("MigrationStage")

View File

@ -0,0 +1,105 @@
from __future__ import annotations
import asyncio
import os
import typing
from typing import Union, Mapping, Any, AsyncIterator
import async_lru
import ollama
from .. import api, entities, errors
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from ....core import app
from ....utils import image
REQUESTER_NAME: str = "ollama-chat"
@api.requester_class(REQUESTER_NAME)
class OllamaChatCompletions(api.LLMAPIRequester):
    """Ollama ChatCompletion API requester."""

    # Async HTTP client bound to the configured host.
    client: ollama.AsyncClient
    # The requester's section of the provider config.
    request_cfg: dict

    def __init__(self, ap: app.Application):
        super().__init__(ap)
        self.ap = ap
        self.request_cfg = self.ap.provider_cfg.data['requester'][REQUESTER_NAME]

    async def initialize(self):
        """Create the async client for the configured base URL.

        Passes the host directly instead of mutating the process-wide
        OLLAMA_HOST environment variable (as before), which would leak the
        setting into any other ollama client created in this process.
        """
        self.client = ollama.AsyncClient(
            host=self.request_cfg['base-url'],
            timeout=self.request_cfg['timeout']
        )

    async def _req(self,
                   args: dict,
                   ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
        """Forward the prepared argument dict to the Ollama chat endpoint."""
        return await self.client.chat(
            **args
        )

    async def _closure(self, req_messages: list[dict], use_model: entities.LLMModelInfo,
                       user_funcs: list[tools_entities.LLMFunction] = None) -> (
            llm_entities.Message):
        """Prepare request args, send the chat request and build the reply Message.

        NOTE(review): `user_funcs` is accepted but not forwarded to Ollama.
        """
        args: Any = self.request_cfg['args'].copy()
        args["model"] = use_model.name if use_model.model_name is None else use_model.model_name

        messages: list[dict] = req_messages.copy()
        for msg in messages:
            # Multi-part content: join the text parts and move images into the
            # separate `images` field Ollama expects.
            if 'content' in msg and isinstance(msg["content"], list):
                text_content: list = []
                image_urls: list = []
                for me in msg["content"]:
                    if me["type"] == "text":
                        text_content.append(me["text"])
                    elif me["type"] == "image_url":
                        image_url = await self.get_base64_str(me["image_url"]['url'])
                        image_urls.append(image_url)
                msg["content"] = "\n".join(text_content)
                # Keep only the base64 payload, dropping the
                # "data:image/...;base64," prefix.
                msg["images"] = [url.split(',')[1] for url in image_urls]
        args["messages"] = messages

        resp: Mapping[str, Any] | AsyncIterator[Mapping[str, Any]] = await self._req(args)
        message: llm_entities.Message = await self._make_msg(resp)
        return message

    async def _make_msg(
            self,
            chat_completions: Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]) -> llm_entities.Message:
        """Flatten an Ollama chat response into an llm_entities.Message.

        Raises:
            ValueError: when the response has no 'message' field.
        """
        message: Any = chat_completions.pop('message', None)
        if message is None:
            raise ValueError("chat_completions must contain a 'message' field")
        # Merge the response's remaining top-level metadata into the message
        # dict before constructing the Message model.
        message.update(chat_completions)
        ret_msg: llm_entities.Message = llm_entities.Message(**message)
        return ret_msg

    async def call(
            self,
            model: entities.LLMModelInfo,
            messages: typing.List[llm_entities.Message],
            funcs: typing.List[tools_entities.LLMFunction] = None,
    ) -> llm_entities.Message:
        """Entry point: serialize messages and request a completion.

        NOTE(review): `funcs` is accepted for interface compatibility but is
        not forwarded to Ollama here.
        """
        req_messages: list = []
        for m in messages:
            msg_dict: dict = m.dict(exclude_none=True)
            content: Any = msg_dict.get("content")
            # Collapse text-only multipart content into a single string; mixed
            # text+image content is handled later in _closure.
            if isinstance(content, list):
                if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
                    msg_dict["content"] = "\n".join(part["text"] for part in content)
            req_messages.append(msg_dict)
        try:
            return await self._closure(req_messages, model)
        except asyncio.TimeoutError:
            raise errors.RequesterError('请求超时')

    @async_lru.alru_cache(maxsize=128)
    async def get_base64_str(
            self,
            original_url: str,
    ) -> str:
        """Fetch the image behind a QQ image URL as a base64 data URL (cached)."""
        base64_image: str = await image.qq_image_url_to_base64(original_url)
        return f"data:image/jpeg;base64,{base64_image}"

View File

@ -6,7 +6,7 @@ from . import entities
from ...core import app
from . import token, api
from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl
from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachat
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"

View File

@ -14,4 +14,5 @@ pydantic
websockets
urllib3
psutil
async-lru
async-lru
ollama

View File

@ -37,6 +37,11 @@
"base-url": "https://api.deepseek.com",
"args": {},
"timeout": 120
},
"ollama-chat": {
"base-url": "http://127.0.0.1:11434",
"args": {},
"timeout": 600
}
},
"model": "gpt-3.5-turbo",