From 2851a9f04e369d87acdf5fc1e89094dc1cc7cf97 Mon Sep 17 00:00:00 2001
From: takatost
Date: Wed, 11 Oct 2023 20:17:41 +0800
Subject: [PATCH] feat: optimize minimax llm call (#1312)

---
 .../models/llm/minimax_model.py               |  19 +-
 .../providers/minimax_provider.py             |   7 +-
 .../third_party/langchain/llms/minimax_llm.py | 273 ++++++++++++++++++
 3 files changed, 287 insertions(+), 12 deletions(-)
 create mode 100644 api/core/third_party/langchain/llms/minimax_llm.py

diff --git a/api/core/model_providers/models/llm/minimax_model.py b/api/core/model_providers/models/llm/minimax_model.py
index e2252d7edc..83fca8fd75 100644
--- a/api/core/model_providers/models/llm/minimax_model.py
+++ b/api/core/model_providers/models/llm/minimax_model.py
@@ -1,26 +1,23 @@
-import decimal
 from typing import List, Optional, Any
 
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import Minimax
 from langchain.schema import LLMResult
 
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
 
 
 class MinimaxModel(BaseLLM):
-    model_mode: ModelMode = ModelMode.COMPLETION
+    model_mode: ModelMode = ModelMode.CHAT
 
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
-        return Minimax(
+        return MinimaxChatLLM(
             model=self.name,
-            model_kwargs={
-                'stream': False
-            },
+            streaming=self.streaming,
             callbacks=self.callbacks,
             **self.credentials,
             **provider_model_kwargs
@@ -49,7 +46,7 @@ class MinimaxModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return max(self._client.get_num_tokens(prompts), 0)
+        return max(self._client.get_num_tokens_from_messages(prompts), 0)
 
     def get_currency(self):
         return 'RMB'
@@ -65,3 +62,7 @@ class MinimaxModel(BaseLLM):
             return LLMBadRequestError(f"Minimax: {str(ex)}")
         else:
             return ex
+
+    @property
+    def support_streaming(self):
+        return True
diff --git a/api/core/model_providers/providers/minimax_provider.py b/api/core/model_providers/providers/minimax_provider.py
index 488e6438b4..c13165d602 100644
--- a/api/core/model_providers/providers/minimax_provider.py
+++ b/api/core/model_providers/providers/minimax_provider.py
@@ -2,7 +2,7 @@ import json
 from json import JSONDecodeError
 from typing import Type
 
-from langchain.llms import Minimax
+from langchain.schema import HumanMessage
 
 from core.helper import encrypter
 from core.model_providers.models.base import BaseProviderModel
@@ -10,6 +10,7 @@ from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbed
 from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
 from core.model_providers.models.llm.minimax_model import MinimaxModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
+from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
 from models.provider import ProviderType, ProviderQuotaType
 
 
@@ -98,14 +99,14 @@ class MinimaxProvider(BaseModelProvider):
                 'minimax_api_key': credentials['minimax_api_key'],
             }
 
-            llm = Minimax(
+            llm = MinimaxChatLLM(
                 model='abab5.5-chat',
                 max_tokens=10,
                 temperature=0.01,
                 **credential_kwargs
            )
 
-            llm("ping")
+            llm([HumanMessage(content='ping')])
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
diff --git a/api/core/third_party/langchain/llms/minimax_llm.py b/api/core/third_party/langchain/llms/minimax_llm.py
new file mode 100644
index 0000000000..c17c63a4a9
--- /dev/null
+++ b/api/core/third_party/langchain/llms/minimax_llm.py
@@ -0,0 +1,273 @@
+import json
+from typing import Dict, Any, Optional, List, Tuple, Iterator
+
+import requests
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.chat_models.base import BaseChatModel
+from langchain.llms.utils import enforce_stop_tokens
+from langchain.schema import BaseMessage, ChatResult, HumanMessage, AIMessage, SystemMessage
+from langchain.schema.messages import AIMessageChunk
+from langchain.schema.output import ChatGenerationChunk, ChatGeneration
+from langchain.utils import get_from_dict_or_env
+from pydantic import root_validator, Field, BaseModel
+
+
+class _MinimaxEndpointClient(BaseModel):
+    """An API client that talks to a Minimax llm endpoint."""
+
+    host: str
+    group_id: str
+    api_key: str
+    api_url: str
+
+    @root_validator(pre=True)
+    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        if "api_url" not in values:
+            host = values["host"]
+            group_id = values["group_id"]
+            api_url = f"{host}/v1/text/chatcompletion?GroupId={group_id}"
+            values["api_url"] = api_url
+        return values
+
+    def post(self, **request: Any) -> Any:
+        stream = 'stream' in request and request['stream']
+
+        headers = {"Authorization": f"Bearer {self.api_key}"}
+        response = requests.post(self.api_url, headers=headers, json=request, stream=stream, timeout=(5, 60))
+        if not response.ok:
+            raise ValueError(f"HTTP {response.status_code} error: {response.text}")
+
+        if not stream:
+            if response.json()["base_resp"]["status_code"] > 0:
+                raise ValueError(
+                    f"API {response.json()['base_resp']['status_code']}"
+                    f" error: {response.json()['base_resp']['status_msg']}"
+                )
+            return response.json()
+        else:
+            return response
+
+
+class MinimaxChatLLM(BaseChatModel):
+
+    _client: _MinimaxEndpointClient
+    model: str = "abab5.5-chat"
+    """Model name to use."""
+    max_tokens: int = 256
+    """Denotes the number of tokens to predict per generation."""
+    temperature: float = 0.7
+    """A non-negative float that tunes the degree of randomness in generation."""
+    top_p: float = 0.95
+    """Total probability mass of tokens to consider at each step."""
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Holds any model parameters valid for `create` call not explicitly specified."""
+    streaming: bool = False
+    """Whether to stream the response or return it all at once."""
+    minimax_api_host: Optional[str] = None
+    minimax_group_id: Optional[str] = None
+    minimax_api_key: Optional[str] = None
+
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"minimax_api_key": "MINIMAX_API_KEY"}
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        values["minimax_api_key"] = get_from_dict_or_env(
+            values, "minimax_api_key", "MINIMAX_API_KEY"
+        )
+        values["minimax_group_id"] = get_from_dict_or_env(
+            values, "minimax_group_id", "MINIMAX_GROUP_ID"
+        )
+        # Get custom api url from environment.
+        values["minimax_api_host"] = get_from_dict_or_env(
+            values,
+            "minimax_api_host",
+            "MINIMAX_API_HOST",
+            default="https://api.minimax.chat",
+        )
+        values["_client"] = _MinimaxEndpointClient(
+            host=values["minimax_api_host"],
+            api_key=values["minimax_api_key"],
+            group_id=values["minimax_group_id"],
+        )
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling OpenAI API."""
+        return {
+            "model": self.model,
+            "tokens_to_generate": self.max_tokens,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "role_meta": {"user_name": "我", "bot_name": "专家"},
+            **self.model_kwargs,
+        }
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {**{"model": self.model}, **self._default_params}
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "minimax"
+
+    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
+        if isinstance(message, HumanMessage):
+            message_dict = {"sender_type": "USER", "text": message.content}
+        elif isinstance(message, AIMessage):
+            message_dict = {"sender_type": "BOT", "text": message.content}
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        return message_dict
+
+    def _create_messages_and_prompt(
+        self, messages: List[BaseMessage]
+    ) -> Tuple[List[Dict[str, Any]], str]:
+        prompt = ""
+        dict_messages = []
+        for m in messages:
+            if isinstance(m, SystemMessage):
+                if prompt:
+                    prompt += "\n"
+                prompt += f"{m.content}"
+                continue
+
+            message = self._convert_message_to_dict(m)
+            dict_messages.append(message)
+
+        prompt = prompt if prompt else ' '
+
+        return dict_messages, prompt
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        if self.streaming:
+            generation: Optional[ChatGenerationChunk] = None
+            llm_output: Optional[Dict] = None
+            for chunk in self._stream(
+                    messages=messages, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+
+                if chunk.generation_info is not None \
+                        and 'token_usage' in chunk.generation_info:
+                    llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
+
+            assert generation is not None
+            return ChatResult(generations=[generation], llm_output=llm_output)
+        else:
+            message_dicts, prompt = self._create_messages_and_prompt(messages)
+            params = self._default_params
+            params["messages"] = message_dicts
+            params["prompt"] = prompt
+            params.update(kwargs)
+            response = self._client.post(**params)
+            return self._create_chat_result(response, stop)
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        message_dicts, prompt = self._create_messages_and_prompt(messages)
+        params = self._default_params
+        params["messages"] = message_dicts
+        params["prompt"] = prompt
+        params["stream"] = True
+        params.update(kwargs)
+
+        for token in self._client.post(**params).iter_lines():
+            if token:
+                token = token.decode("utf-8")
+
+                if not token.startswith("data:"):
+                    data = json.loads(token)
+                    if "base_resp" in data and data["base_resp"]["status_code"] > 0:
+                        raise ValueError(
+                            f"API {data['base_resp']['status_code']}"
+                            f" error: {data['base_resp']['status_msg']}"
+                        )
+                    else:
+                        continue
+
+                token = token.lstrip("data:").strip()
+                data = json.loads(token)
+                content = data['choices'][0]['delta']
+
+                chunk_kwargs = {
+                    'message': AIMessageChunk(content=content),
+                }
+
+                if 'usage' in data:
+                    token_usage = data['usage']
+                    overall_token_usage = {
+                        'prompt_tokens': 0,
+                        'completion_tokens': token_usage.get('total_tokens', 0),
+                        'total_tokens': token_usage.get('total_tokens', 0)
+                    }
+                    chunk_kwargs['generation_info'] = {'token_usage': overall_token_usage}
+
+                yield ChatGenerationChunk(**chunk_kwargs)
+                if run_manager:
+                    run_manager.on_llm_new_token(content)
+
+    def _create_chat_result(self, response: Dict[str, Any], stop: Optional[List[str]] = None) -> ChatResult:
+        text = response['reply']
+        if stop is not None:
+            # This is required since the stop tokens
+            # are not enforced by the model parameters
+            text = enforce_stop_tokens(text, stop)
+
+        generations = [ChatGeneration(message=AIMessage(content=text))]
+        usage = response.get("usage")
+
+        # only return total_tokens in minimax response
+        token_usage = {
+            'prompt_tokens': 0,
+            'completion_tokens': usage.get('total_tokens', 0),
+            'total_tokens': usage.get('total_tokens', 0)
+        }
+        llm_output = {"token_usage": token_usage, "model_name": self.model}
+        return ChatResult(generations=generations, llm_output=llm_output)
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Get the number of tokens in the messages.
+
+        Useful for checking if an input will fit in a model's context window.
+
+        Args:
+            messages: The message inputs to tokenize.
+
+        Returns:
+            The sum of the number of tokens across the messages.
+        """
+        return sum([self.get_num_tokens(m.content) for m in messages])
+
+    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
+        token_usage: dict = {}
+        for output in llm_outputs:
+            if output is None:
+                # Happens in streaming
+                continue
+            token_usage = output["token_usage"]
+
+        return {"token_usage": token_usage, "model_name": self.model}
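
Editor's note (not part of the patch): a minimal usage sketch of the new MinimaxChatLLM wrapper, mirroring the credential-validation call in minimax_provider.py above. Only the class, its fields, and the message-list call come from the patch; the group id and API key values are placeholder assumptions, and in Dify itself the class is normally constructed by MinimaxModel._init_client rather than directly.

# Illustrative sketch only; credentials below are placeholders.
from langchain.schema import HumanMessage

from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM

llm = MinimaxChatLLM(
    model='abab5.5-chat',
    max_tokens=10,
    temperature=0.01,
    minimax_group_id='<your-group-id>',   # assumed placeholder
    minimax_api_key='<your-api-key>',     # assumed placeholder
    streaming=False,
)

# A non-streaming call returns an AIMessage; .content holds the reply text.
result = llm([HumanMessage(content='ping')])
print(result.content)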