Feat/add triton inference server (#2928)

Yeuoly 2024-03-22 15:15:48 +08:00 committed by GitHub
parent 16af509c46
commit 240a94182e
8 changed files with 365 additions and 1 deletion

View File

@@ -11,6 +11,8 @@
 - groq
 - replicate
 - huggingface_hub
+- xinference
+- triton_inference_server
 - zhipuai
 - baichuan
 - spark
@@ -20,7 +22,6 @@
 - moonshot
 - jina
 - chatglm
-- xinference
 - yi
 - openllm
 - localai

Binary file not shown.


View File

@@ -0,0 +1,3 @@
<svg width="567" height="376" viewBox="0 0 567 376" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M58.0366 161.868C58.0366 161.868 109.261 86.2912 211.538 78.4724V51.053C98.2528 60.1511 0.152344 156.098 0.152344 156.098C0.152344 156.098 55.7148 316.717 211.538 331.426V302.282C97.1876 287.896 58.0366 161.868 58.0366 161.868ZM211.538 244.32V271.013C125.114 255.603 101.125 165.768 101.125 165.768C101.125 165.768 142.621 119.799 211.538 112.345V141.633C211.486 141.633 211.449 141.617 211.406 141.617C175.235 137.276 146.978 171.067 146.978 171.067C146.978 171.067 162.816 227.949 211.538 244.32ZM211.538 0.47998V51.053C214.864 50.7981 218.189 50.5818 221.533 50.468C350.326 46.1273 434.243 156.098 434.243 156.098C434.243 156.098 337.861 273.296 237.448 273.296C228.245 273.296 219.63 272.443 211.538 271.009V302.282C218.695 303.201 225.903 303.667 233.119 303.675C326.56 303.675 394.134 255.954 459.566 199.474C470.415 208.162 514.828 229.299 523.958 238.55C461.745 290.639 316.752 332.626 234.551 332.626C226.627 332.626 219.018 332.148 211.538 331.426V375.369H566.701V0.47998H211.538ZM211.538 112.345V78.4724C214.829 78.2425 218.146 78.0672 221.533 77.9602C314.148 75.0512 374.909 157.548 374.909 157.548C374.909 157.548 309.281 248.693 238.914 248.693C228.787 248.693 219.707 247.065 211.536 244.318V141.631C247.591 145.987 254.848 161.914 276.524 198.049L324.737 157.398C324.737 157.398 289.544 111.243 230.219 111.243C223.768 111.241 217.597 111.696 211.538 112.345Z" fill="#77B900"/>
</svg>


View File

@@ -0,0 +1,267 @@
from collections.abc import Generator
from httpx import Response, post
from yarl import URL
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
class TritonInferenceAILargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
"""
invoke LLM
see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke`
"""
return self._generate(
model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters,
tools=tools, stop=stop, stream=stream, user=user,
)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
validate credentials
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('server_url is required in credentials')
try:
self._invoke(model=model, credentials=credentials, prompt_messages=[
UserPromptMessage(content='ping')
], model_parameters={}, stream=False)
except InvokeError as ex:
raise CredentialsValidateFailedError(f'An error occurred during connection: {str(ex)}')
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None) -> int:
"""
get number of tokens
Because Triton Inference Server serves customized models, we cannot detect which tokenizer the model uses,
so we fall back to the GPT-2 tokenizer by default
"""
return self._get_num_tokens_by_gpt2(self._convert_prompt_message_to_text(prompt_messages))
def _convert_prompt_message_to_text(self, message: list[PromptMessage]) -> str:
"""
convert prompt message to text
"""
text = ''
for item in message:
if isinstance(item, UserPromptMessage):
text += f'User: {item.content}'
elif isinstance(item, SystemPromptMessage):
text += f'System: {item.content}'
elif isinstance(item, AssistantPromptMessage):
text += f'Assistant: {item.content}'
else:
raise NotImplementedError(f'PromptMessage type {type(item)} is not supported')
return text
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
rules = [
ParameterRule(
name='temperature',
type=ParameterType.FLOAT,
use_template='temperature',
label=I18nObject(
zh_Hans='温度',
en_US='Temperature'
),
),
ParameterRule(
name='top_p',
type=ParameterType.FLOAT,
use_template='top_p',
label=I18nObject(
zh_Hans='Top P',
en_US='Top P'
)
),
ParameterRule(
name='max_tokens',
type=ParameterType.INT,
use_template='max_tokens',
min=1,
max=int(credentials.get('context_size', 2048)),
default=min(512, int(credentials.get('context_size', 2048))),
label=I18nObject(
zh_Hans='最大生成长度',
en_US='Max Tokens'
)
)
]
completion_type = None
if 'completion_type' in credentials:
if credentials['completion_type'] == 'chat':
completion_type = LLMMode.CHAT.value
elif credentials['completion_type'] == 'completion':
completion_type = LLMMode.COMPLETION.value
else:
raise ValueError(f'completion_type {credentials["completion_type"]} is not supported')
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
parameter_rules=rules,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
model_properties={
ModelPropertyKey.MODE: completion_type,
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 2048)),
},
)
return entity
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
-> LLMResult | Generator:
"""
generate text from LLM
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('server_url is required in credentials')
if 'stream' in credentials and not bool(credentials['stream']) and stream:
raise ValueError(f'stream is not supported by model {model}')
try:
parameters = {}
if 'temperature' in model_parameters:
parameters['temperature'] = model_parameters['temperature']
if 'top_p' in model_parameters:
parameters['top_p'] = model_parameters['top_p']
if 'top_k' in model_parameters:
parameters['top_k'] = model_parameters['top_k']
if 'presence_penalty' in model_parameters:
parameters['presence_penalty'] = model_parameters['presence_penalty']
if 'frequency_penalty' in model_parameters:
parameters['frequency_penalty'] = model_parameters['frequency_penalty']
response = post(str(URL(credentials['server_url']) / 'v2' / 'models' / model / 'generate'), json={
'text_input': self._convert_prompt_message_to_text(prompt_messages),
'max_tokens': model_parameters.get('max_tokens', 512),
'parameters': {
'stream': False,  # the HTTP call itself is never streamed; when stream=True the full text is re-emitted below as a single chunk
**parameters
},
}, timeout=(10, 120))
response.raise_for_status()
if response.status_code != 200:
raise InvokeBadRequestError(f'Invoke failed with status code {response.status_code}, {response.text}')
if stream:
return self._handle_chat_stream_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
tools=tools, resp=response)
return self._handle_chat_generate_response(model=model, credentials=credentials, prompt_messages=prompt_messages,
tools=tools, resp=response)
except Exception as ex:
raise InvokeConnectionError(f'An error occurred during connection: {str(ex)}')
def _handle_chat_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
resp: Response) -> LLMResult:
"""
handle normal chat generate response
"""
text = resp.json()['text_output']
usage = LLMUsage.empty_usage()
usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
usage.completion_tokens = self._get_num_tokens_by_gpt2(text)
return LLMResult(
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=text
),
usage=usage
)
def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
resp: Response) -> Generator:
"""
handle chat stream response (the full text is yielded as a single chunk)
"""
text = resp.json()['text_output']
usage = LLMUsage.empty_usage()
usage.prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
usage.completion_tokens = self._get_num_tokens_by_gpt2(text)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=text
),
usage=usage
)
)
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
],
InvokeServerUnavailableError: [
],
InvokeRateLimitError: [
],
InvokeAuthorizationError: [
],
InvokeBadRequestError: [
ValueError
]
}
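For orientation, the request that _generate assembles above goes to Triton's HTTP generate endpoint at /v2/models/{model}/generate, and the generated text comes back in the text_output field of the JSON response. Below is a minimal standalone sketch of that exchange, assuming a Triton server reachable at http://localhost:8000 and a deployed model named example-model (both placeholder values).

from httpx import post
from yarl import URL

server_url = 'http://localhost:8000'  # placeholder: your Triton Inference Server
model = 'example-model'               # placeholder: the name of the deployed model

# Same endpoint and payload shape that _generate() builds above.
response = post(str(URL(server_url) / 'v2' / 'models' / model / 'generate'), json={
    'text_input': 'User: ping',
    'max_tokens': 16,
    'parameters': {
        'stream': False,
        'temperature': 0.7,
    },
}, timeout=(10, 120))
response.raise_for_status()
print(response.json()['text_output'])  # generated text returned by Triton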

View File

@@ -0,0 +1,9 @@
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class TritonInferenceAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
    """
    Provider-level validation is a no-op: models are added as customizable models,
    so credentials are validated per model by validate_credentials in llm.py.
    """
    pass

View File

@@ -0,0 +1,84 @@
provider: triton_inference_server
label:
en_US: Triton Inference Server
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.png
background: "#EFFDFD"
help:
title:
en_US: How to deploy Triton Inference Server
zh_Hans: 如何部署 Triton Inference Server
url:
en_US: https://github.com/triton-inference-server/server
supported_model_types:
- llm
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: server_url
label:
zh_Hans: 服务器URL
en_US: Server URL
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入 Triton Inference Server 的服务器地址,如 http://192.168.1.100:8000
en_US: Enter the url of your Triton Inference Server, e.g. http://192.168.1.100:8000
- variable: context_size
label:
zh_Hans: 上下文大小
en_US: Context size
type: text-input
required: true
placeholder:
zh_Hans: 在此输入您的上下文大小
en_US: Enter the context size
default: 2048
- variable: completion_type
label:
zh_Hans: 补全类型
en_US: Completion type
type: select
required: true
default: chat
placeholder:
zh_Hans: 在此输入您的补全类型
en_US: Enter the completion type
options:
- label:
zh_Hans: 补全模型
en_US: Completion model
value: completion
- label:
zh_Hans: 对话模型
en_US: Chat model
value: chat
- variable: stream
label:
zh_Hans: 流式输出
en_US: Stream output
type: select
required: true
default: true
placeholder:
zh_Hans: 是否支持流式输出
en_US: Whether to support stream output
options:
- label:
zh_Hans: 是
en_US: "Yes"
value: true
- label:
zh_Hans: 否
en_US: "No"
value: false
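
For illustration, the credential form above yields a mapping keyed by the variable names, and that mapping is the credentials dict that validate_credentials, get_customizable_model_schema, and _generate in llm.py receive. A hypothetical example, with every value a placeholder:

# Hypothetical credentials assembled from the form above; every value is a placeholder.
credentials = {
    'server_url': 'http://192.168.1.100:8000',  # HTTP endpoint of the Triton server
    'context_size': '2048',                     # context window used for the max_tokens rule
    'completion_type': 'chat',                  # selects LLMMode.CHAT in get_customizable_model_schema
    'stream': 'true',                           # whether streamed output is allowed
}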