feat: add openllm support (#928)

This commit is contained in:
takatost 2023-08-20 19:04:33 +08:00 committed by GitHub
parent da3f10a55e
commit 3ea8d7a019
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 412 additions and 3 deletions

View File

@ -60,6 +60,9 @@ class ModelProviderFactory:
elif provider_name == 'xinference':
from core.model_providers.providers.xinference_provider import XinferenceProvider
return XinferenceProvider
elif provider_name == 'openllm':
from core.model_providers.providers.openllm_provider import OpenLLMProvider
return OpenLLMProvider
else:
raise NotImplementedError

View File

@ -0,0 +1,60 @@
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.llms import OpenLLM
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class OpenLLMModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
def _init_client(self) -> Any:
self.provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
client = OpenLLM(
server_url=self.credentials.get('server_url'),
callbacks=self.callbacks,
**self.provider_model_kwargs
)
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
pass
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"OpenLLM: {str(ex)}")
@classmethod
def support_streaming(cls):
return False

View File

@ -0,0 +1,137 @@
import json
from typing import Type
from langchain.llms import OpenLLM
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from models.provider import ProviderType
class OpenLLMProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'openllm'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = OpenLLMModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=1),
top_p=KwargRule[float](min=0, max=1, default=0.7),
presence_penalty=KwargRule[float](min=-2, max=2, default=0),
frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
max_tokens=KwargRule[int](min=10, max=4000, default=128),
)
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('OpenLLM Server URL must be provided.')
try:
credential_kwargs = {
'server_url': credentials['server_url']
}
llm = OpenLLM(
max_tokens=10,
**credential_kwargs
)
llm("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
return credentials
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
if self.provider.provider_type != ProviderType.CUSTOM.value:
raise NotImplementedError
provider_model = self._get_provider_model(model_name, model_type)
if not provider_model.encrypted_config:
return {
'server_url': None
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['server_url']:
credentials['server_url'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['server_url']
)
if obfuscated:
credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
return credentials
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
return
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return {}
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}

View File

@ -9,5 +9,6 @@
"chatglm",
"replicate",
"huggingface_hub",
"xinference"
"xinference",
"openllm"
]

View File

@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
}

View File

@ -49,4 +49,5 @@ huggingface_hub~=0.16.4
transformers~=4.31.0
stripe~=5.5.0
pandas==1.5.3
xinference==0.2.0
xinference==0.2.0
openllm~=0.2.26

View File

@ -36,4 +36,7 @@ CHATGLM_API_BASE=
# Xinference Credentials
XINFERENCE_SERVER_URL=
XINFERENCE_MODEL_UID=
XINFERENCE_MODEL_UID=
# OpenLLM Credentials
OPENLLM_SERVER_URL=

View File

@ -0,0 +1,72 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.openllm_provider import OpenLLMProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='openllm',
provider_type=ProviderType.CUSTOM.value,
encrypted_config='',
is_valid=True,
)
def get_mock_model(model_name, mocker):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0.01
)
server_url = os.environ['OPENLLM_SERVER_URL']
model_provider = OpenLLMProvider(provider=get_mock_provider())
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='openllm',
model_name=model_name,
model_type=ModelType.TEXT_GENERATION.value,
encrypted_config=json.dumps({
'server_url': server_url
}),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
return OpenLLMModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
model = get_mock_model('facebook/opt-125m', mocker)
rst = model.get_num_tokens([
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst == 5
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
model = get_mock_model('facebook/opt-125m', mocker)
messages = [PromptMessage(content='Human: who are you? \nAnswer: ')]
rst = model.run(
messages
)
assert len(rst.content) > 0

View File

@ -0,0 +1,125 @@
import pytest
from unittest.mock import patch, MagicMock
import json
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.openllm_provider import OpenLLMProvider
from models.provider import ProviderType, Provider, ProviderModel
PROVIDER_NAME = 'openllm'
MODEL_PROVIDER_CLASS = OpenLLMProvider
VALIDATE_CREDENTIAL = {
'server_url': 'http://127.0.0.1:3333/'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_credentials_valid_or_raise_valid(mocker):
mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
mocker.patch('langchain.llms.openllm.OpenLLM._call',
return_value="abc")
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='username/test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials=VALIDATE_CREDENTIAL.copy()
)
def test_is_credentials_valid_or_raise_invalid(mocker):
mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
# raise CredentialsValidateFailedError if credential is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials={}
)
# raise CredentialsValidateFailedError if credential is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials={'server_url': 'invalid'})
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_model_credentials(mock_encrypt):
api_key = 'http://127.0.0.1:3333/'
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
tenant_id='tenant_id',
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
credentials=VALIDATE_CREDENTIAL.copy()
)
mock_encrypt.assert_called_with('tenant_id', api_key)
assert result['server_url'] == f'encrypted_{api_key}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_custom(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION
)
assert result['server_url'] == 'http://127.0.0.1:3333/'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.TEXT_GENERATION,
obfuscated=True
)
middle_token = result['server_url'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0)
assert all(char == '*' for char in middle_token)