diff --git a/api/core/model_providers/models/llm/chatglm_model.py b/api/core/model_providers/models/llm/chatglm_model.py index 5f22cdf6af..cb6e98d5e2 100644 --- a/api/core/model_providers/models/llm/chatglm_model.py +++ b/api/core/model_providers/models/llm/chatglm_model.py @@ -1,27 +1,45 @@ -import decimal +import logging from typing import List, Optional, Any +import openai from langchain.callbacks.manager import Callbacks -from langchain.llms import ChatGLM -from langchain.schema import LLMResult +from langchain.schema import LLMResult, get_buffer_string -from core.model_providers.error import LLMBadRequestError +from core.model_providers.error import LLMBadRequestError, LLMRateLimitError, LLMAuthorizationError, \ + LLMAPIUnavailableError, LLMAPIConnectionError from core.model_providers.models.llm.base import BaseLLM from core.model_providers.models.entity.message import PromptMessage, MessageType from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs +from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI class ChatGLMModel(BaseLLM): - model_mode: ModelMode = ModelMode.COMPLETION + model_mode: ModelMode = ModelMode.CHAT def _init_client(self) -> Any: provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) - return ChatGLM( + + extra_model_kwargs = { + 'top_p': provider_model_kwargs.get('top_p') + } + + if provider_model_kwargs.get('max_length') is not None: + extra_model_kwargs['max_length'] = provider_model_kwargs.get('max_length') + + client = EnhanceChatOpenAI( + model_name=self.name, + temperature=provider_model_kwargs.get('temperature'), + max_tokens=provider_model_kwargs.get('max_tokens'), + model_kwargs=extra_model_kwargs, + streaming=self.streaming, callbacks=self.callbacks, - endpoint_url=self.credentials.get('api_base'), - **provider_model_kwargs + request_timeout=60, + openai_api_key="1", + openai_api_base=self.credentials['api_base'] + '/v1' ) + return client + def _run(self, messages: List[PromptMessage], stop: Optional[List[str]] = None, callbacks: Callbacks = None, @@ -45,19 +63,40 @@ class ChatGLMModel(BaseLLM): :return: """ prompts = self._get_prompt_from_messages(messages) - return max(self._client.get_num_tokens(prompts), 0) + return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0) def get_currency(self): return 'RMB' def _set_model_kwargs(self, model_kwargs: ModelKwargs): provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) - for k, v in provider_model_kwargs.items(): - if hasattr(self.client, k): - setattr(self.client, k, v) + extra_model_kwargs = { + 'top_p': provider_model_kwargs.get('top_p') + } + + self.client.temperature = provider_model_kwargs.get('temperature') + self.client.max_tokens = provider_model_kwargs.get('max_tokens') + self.client.model_kwargs = extra_model_kwargs def handle_exceptions(self, ex: Exception) -> Exception: - if isinstance(ex, ValueError): - return LLMBadRequestError(f"ChatGLM: {str(ex)}") + if isinstance(ex, openai.error.InvalidRequestError): + logging.warning("Invalid request to ChatGLM API.") + return LLMBadRequestError(str(ex)) + elif isinstance(ex, openai.error.APIConnectionError): + logging.warning("Failed to connect to ChatGLM API.") + return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex)) + elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)): + logging.warning("ChatGLM service unavailable.") + return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex)) + elif isinstance(ex, openai.error.RateLimitError): + return LLMRateLimitError(str(ex)) + elif isinstance(ex, openai.error.AuthenticationError): + return LLMAuthorizationError(str(ex)) + elif isinstance(ex, openai.error.OpenAIError): + return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex)) else: return ex + + @classmethod + def support_streaming(cls): + return True \ No newline at end of file diff --git a/api/core/model_providers/providers/chatglm_provider.py b/api/core/model_providers/providers/chatglm_provider.py index d3c83e37ce..e3db27b2d2 100644 --- a/api/core/model_providers/providers/chatglm_provider.py +++ b/api/core/model_providers/providers/chatglm_provider.py @@ -2,6 +2,7 @@ import json from json import JSONDecodeError from typing import Type +import requests from langchain.llms import ChatGLM from core.helper import encrypter @@ -25,21 +26,26 @@ class ChatGLMProvider(BaseModelProvider): if model_type == ModelType.TEXT_GENERATION: return [ { - 'id': 'chatglm2-6b', - 'name': 'ChatGLM2-6B', - 'mode': ModelMode.COMPLETION.value, + 'id': 'chatglm3-6b', + 'name': 'ChatGLM3-6B', + 'mode': ModelMode.CHAT.value, }, { - 'id': 'chatglm-6b', - 'name': 'ChatGLM-6B', - 'mode': ModelMode.COMPLETION.value, + 'id': 'chatglm3-6b-32k', + 'name': 'ChatGLM3-6B-32K', + 'mode': ModelMode.CHAT.value, + }, + { + 'id': 'chatglm2-6b', + 'name': 'ChatGLM2-6B', + 'mode': ModelMode.CHAT.value, } ] else: return [] def _get_text_generation_model_mode(self, model_name) -> str: - return ModelMode.COMPLETION.value + return ModelMode.CHAT.value def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: """ @@ -64,16 +70,19 @@ class ChatGLMProvider(BaseModelProvider): :return: """ model_max_tokens = { - 'chatglm-6b': 2000, - 'chatglm2-6b': 32000, + 'chatglm3-6b-32k': 32000, + 'chatglm3-6b': 8000, + 'chatglm2-6b': 8000, } + max_tokens_alias = 'max_length' if model_name == 'chatglm2-6b' else 'max_tokens' + return ModelKwargsRules( temperature=KwargRule[float](min=0, max=2, default=1, precision=2), top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2), presence_penalty=KwargRule[float](enabled=False), frequency_penalty=KwargRule[float](enabled=False), - max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0), + max_tokens=KwargRule[int](alias=max_tokens_alias, min=10, max=model_max_tokens.get(model_name), default=2048, precision=0), ) @classmethod @@ -85,16 +94,10 @@ class ChatGLMProvider(BaseModelProvider): raise CredentialsValidateFailedError('ChatGLM Endpoint URL must be provided.') try: - credential_kwargs = { - 'endpoint_url': credentials['api_base'] - } + response = requests.get(f"{credentials['api_base']}/v1/models", timeout=5) - llm = ChatGLM( - max_token=10, - **credential_kwargs - ) - - llm("ping") + if response.status_code != 200: + raise Exception('ChatGLM Endpoint URL is invalid.') except Exception as ex: raise CredentialsValidateFailedError(str(ex))