diff --git a/pkg/provider/modelmgr/apis/ollamachatcmpl.py b/pkg/provider/modelmgr/apis/ollamachatcmpl.py new file mode 100644 index 0000000..150a2af --- /dev/null +++ b/pkg/provider/modelmgr/apis/ollamachatcmpl.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import asyncio +import os +import typing +from typing import Union, Mapping, Any, AsyncIterator + +import async_lru +import ollama + +from .. import api, entities, errors +from ... import entities as llm_entities +from ...tools import entities as tools_entities +from ....core import app +from ....utils import image + +REQUESTER_NAME: str = "ollama-chat-completions" + + +@api.requester_class(REQUESTER_NAME) +class OllamaChatCompletions(api.LLMAPIRequester): + """Ollama平台 ChatCompletion API请求器""" + client: ollama.AsyncClient + request_cfg: dict + + def __init__(self, ap: app.Application): + super().__init__(ap) + self.ap = ap + self.request_cfg = self.ap.provider_cfg.data['requester'][REQUESTER_NAME] + + async def initialize(self): + os.environ['OLLAMA_HOST'] = self.request_cfg['base-url'] + self.client = ollama.AsyncClient( + timeout=self.request_cfg['timeout'] + ) + + async def _req(self, + args: dict, + ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: + return await self.client.chat( + **args + ) + + async def _closure(self, req_messages: list[dict], use_model: entities.LLMModelInfo, + user_funcs: list[tools_entities.LLMFunction] = None) -> ( + llm_entities.Message): + args: Any = self.request_cfg['args'].copy() + args["model"] = use_model.name if use_model.model_name is None else use_model.model_name + + messages: list[dict] = req_messages.copy() + for msg in messages: + if 'content' in msg and isinstance(msg["content"], list): + text_content: list = [] + image_urls: list = [] + for me in msg["content"]: + if me["type"] == "text": + text_content.append(me["text"]) + elif me["type"] == "image_url": + image_url = await self.get_base64_str(me["image_url"]['url']) + image_urls.append(image_url) + msg["content"] = "\n".join(text_content) + msg["images"] = [url.split(',')[1] for url in image_urls] + args["messages"] = messages + + resp: Mapping[str, Any] | AsyncIterator[Mapping[str, Any]] = await self._req(args) + message: llm_entities.Message = await self._make_msg(resp) + return message + + async def _make_msg( + self, + chat_completions: Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]) -> llm_entities.Message: + message: Any = chat_completions.pop('message', None) + if message is None: + raise ValueError("chat_completions must contain a 'message' field") + + message.update(chat_completions) + ret_msg: llm_entities.Message = llm_entities.Message(**message) + return ret_msg + + async def call( + self, + model: entities.LLMModelInfo, + messages: typing.List[llm_entities.Message], + funcs: typing.List[tools_entities.LLMFunction] = None, + ) -> llm_entities.Message: + req_messages: list = [] + for m in messages: + msg_dict: dict = m.dict(exclude_none=True) + content: Any = msg_dict.get("content") + if isinstance(content, list): + if all(isinstance(part, dict) and part.get('type') == 'text' for part in content): + msg_dict["content"] = "\n".join(part["text"] for part in content) + req_messages.append(msg_dict) + try: + return await self._closure(req_messages, model) + except asyncio.TimeoutError: + raise errors.RequesterError('请求超时') + + @async_lru.alru_cache(maxsize=128) + async def get_base64_str( + self, + original_url: str, + ) -> str: + base64_image: str = await image.qq_image_url_to_base64(original_url) + return f"data:image/jpeg;base64,{base64_image}" diff --git a/pkg/provider/modelmgr/modelmgr.py b/pkg/provider/modelmgr/modelmgr.py index 79e467a..6a221d2 100644 --- a/pkg/provider/modelmgr/modelmgr.py +++ b/pkg/provider/modelmgr/modelmgr.py @@ -6,7 +6,7 @@ from . import entities from ...core import app from . import token, api -from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl +from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachatcmpl FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list" diff --git a/requirements.txt b/requirements.txt index 44bc285..1b554d2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,5 @@ pydantic websockets urllib3 psutil -async-lru \ No newline at end of file +async-lru +ollama \ No newline at end of file