mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 03:32:33 +08:00
add ollama chat
This commit is contained in:
parent
45a10b4ac7
commit
987b3dc4ef
105
pkg/provider/modelmgr/apis/ollamachatcmpl.py
Normal file
105
pkg/provider/modelmgr/apis/ollamachatcmpl.py
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import typing
|
||||||
|
from typing import Union, Mapping, Any, AsyncIterator
|
||||||
|
|
||||||
|
import async_lru
|
||||||
|
import ollama
|
||||||
|
|
||||||
|
from .. import api, entities, errors
|
||||||
|
from ... import entities as llm_entities
|
||||||
|
from ...tools import entities as tools_entities
|
||||||
|
from ....core import app
|
||||||
|
from ....utils import image
|
||||||
|
|
||||||
|
REQUESTER_NAME: str = "ollama-chat-completions"
|
||||||
|
|
||||||
|
|
||||||
|
@api.requester_class(REQUESTER_NAME)
|
||||||
|
class OllamaChatCompletions(api.LLMAPIRequester):
|
||||||
|
"""Ollama平台 ChatCompletion API请求器"""
|
||||||
|
client: ollama.AsyncClient
|
||||||
|
request_cfg: dict
|
||||||
|
|
||||||
|
def __init__(self, ap: app.Application):
|
||||||
|
super().__init__(ap)
|
||||||
|
self.ap = ap
|
||||||
|
self.request_cfg = self.ap.provider_cfg.data['requester'][REQUESTER_NAME]
|
||||||
|
|
||||||
|
async def initialize(self):
|
||||||
|
os.environ['OLLAMA_HOST'] = self.request_cfg['base-url']
|
||||||
|
self.client = ollama.AsyncClient(
|
||||||
|
timeout=self.request_cfg['timeout']
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _req(self,
|
||||||
|
args: dict,
|
||||||
|
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||||
|
return await self.client.chat(
|
||||||
|
**args
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _closure(self, req_messages: list[dict], use_model: entities.LLMModelInfo,
|
||||||
|
user_funcs: list[tools_entities.LLMFunction] = None) -> (
|
||||||
|
llm_entities.Message):
|
||||||
|
args: Any = self.request_cfg['args'].copy()
|
||||||
|
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||||
|
|
||||||
|
messages: list[dict] = req_messages.copy()
|
||||||
|
for msg in messages:
|
||||||
|
if 'content' in msg and isinstance(msg["content"], list):
|
||||||
|
text_content: list = []
|
||||||
|
image_urls: list = []
|
||||||
|
for me in msg["content"]:
|
||||||
|
if me["type"] == "text":
|
||||||
|
text_content.append(me["text"])
|
||||||
|
elif me["type"] == "image_url":
|
||||||
|
image_url = await self.get_base64_str(me["image_url"]['url'])
|
||||||
|
image_urls.append(image_url)
|
||||||
|
msg["content"] = "\n".join(text_content)
|
||||||
|
msg["images"] = [url.split(',')[1] for url in image_urls]
|
||||||
|
args["messages"] = messages
|
||||||
|
|
||||||
|
resp: Mapping[str, Any] | AsyncIterator[Mapping[str, Any]] = await self._req(args)
|
||||||
|
message: llm_entities.Message = await self._make_msg(resp)
|
||||||
|
return message
|
||||||
|
|
||||||
|
async def _make_msg(
|
||||||
|
self,
|
||||||
|
chat_completions: Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]) -> llm_entities.Message:
|
||||||
|
message: Any = chat_completions.pop('message', None)
|
||||||
|
if message is None:
|
||||||
|
raise ValueError("chat_completions must contain a 'message' field")
|
||||||
|
|
||||||
|
message.update(chat_completions)
|
||||||
|
ret_msg: llm_entities.Message = llm_entities.Message(**message)
|
||||||
|
return ret_msg
|
||||||
|
|
||||||
|
async def call(
|
||||||
|
self,
|
||||||
|
model: entities.LLMModelInfo,
|
||||||
|
messages: typing.List[llm_entities.Message],
|
||||||
|
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||||
|
) -> llm_entities.Message:
|
||||||
|
req_messages: list = []
|
||||||
|
for m in messages:
|
||||||
|
msg_dict: dict = m.dict(exclude_none=True)
|
||||||
|
content: Any = msg_dict.get("content")
|
||||||
|
if isinstance(content, list):
|
||||||
|
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
|
||||||
|
msg_dict["content"] = "\n".join(part["text"] for part in content)
|
||||||
|
req_messages.append(msg_dict)
|
||||||
|
try:
|
||||||
|
return await self._closure(req_messages, model)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
raise errors.RequesterError('请求超时')
|
||||||
|
|
||||||
|
@async_lru.alru_cache(maxsize=128)
|
||||||
|
async def get_base64_str(
|
||||||
|
self,
|
||||||
|
original_url: str,
|
||||||
|
) -> str:
|
||||||
|
base64_image: str = await image.qq_image_url_to_base64(original_url)
|
||||||
|
return f"data:image/jpeg;base64,{base64_image}"
|
|
@ -6,7 +6,7 @@ from . import entities
|
||||||
from ...core import app
|
from ...core import app
|
||||||
|
|
||||||
from . import token, api
|
from . import token, api
|
||||||
from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl
|
from .apis import chatcmpl, anthropicmsgs, moonshotchatcmpl, deepseekchatcmpl, ollamachatcmpl
|
||||||
|
|
||||||
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"
|
FETCH_MODEL_LIST_URL = "https://api.qchatgpt.rockchin.top/api/v2/fetch/model_list"
|
||||||
|
|
||||||
|
|
|
@ -15,3 +15,4 @@ websockets
|
||||||
urllib3
|
urllib3
|
||||||
psutil
|
psutil
|
||||||
async-lru
|
async-lru
|
||||||
|
ollama
|
Loading…
Reference in New Issue
Block a user