mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
feat:xinference audio model support (#3045)
This commit is contained in:
parent
12782cad4d
commit
e215aae39a
|
@ -0,0 +1,148 @@
|
|||
from typing import IO, Optional
|
||||
|
||||
from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
|
||||
|
||||
class XinferenceSpeech2TextModel(Speech2TextModel):
|
||||
"""
|
||||
Model class for Xinference speech to text model.
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
file: IO[bytes], user: Optional[str] = None) \
|
||||
-> str:
|
||||
"""
|
||||
Invoke speech2text model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param file: audio file
|
||||
:param user: unique user id
|
||||
:return: text for given audio file
|
||||
"""
|
||||
return self._speech2text_invoke(model, credentials, file)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']:
|
||||
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
|
||||
|
||||
audio_file_path = self._get_demo_file_path()
|
||||
|
||||
with open(audio_file_path, 'rb') as audio_file:
|
||||
self.invoke(model, credentials, audio_file)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
InvokeConnectionError
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InvokeServerUnavailableError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
InvokeRateLimitError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvokeAuthorizationError
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
InvokeBadRequestError,
|
||||
KeyError,
|
||||
ValueError
|
||||
]
|
||||
}
|
||||
|
||||
def _speech2text_invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
file: IO[bytes],
|
||||
language: Optional[str] = None,
|
||||
prompt: Optional[str] = None,
|
||||
response_format: Optional[str] = "json",
|
||||
temperature: Optional[float] = 0,
|
||||
) -> str:
|
||||
"""
|
||||
Invoke speech2text model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param file: The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpe g,mpga, m4a, ogg, wav, or webm.
|
||||
:param language: The language of the input audio. Supplying the input language in ISO-639-1
|
||||
:param prompt: An optional text to guide the model's style or continue a previous audio segment.
|
||||
The prompt should match the audio language.
|
||||
:param response_format: The format of the transcript output, in one of these options: json, text, srt, verbose _json, or vtt.
|
||||
:param temperature: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output mor e random,while lower values like 0.2 will make it more focused and deterministic.If set to 0, the model wi ll use log probability to automatically increase the temperature until certain thresholds are hit.
|
||||
:return: text for given audio file
|
||||
"""
|
||||
if credentials['server_url'].endswith('/'):
|
||||
credentials['server_url'] = credentials['server_url'][:-1]
|
||||
|
||||
# initialize client
|
||||
client = Client(
|
||||
base_url=credentials['server_url']
|
||||
)
|
||||
|
||||
xinference_client = client.get_model(model_uid=credentials['model_uid'])
|
||||
|
||||
if not isinstance(xinference_client, RESTfulAudioModelHandle):
|
||||
raise InvokeBadRequestError('please check model type, the model you want to invoke is not a audio model')
|
||||
|
||||
response = xinference_client.transcriptions(
|
||||
audio=file,
|
||||
language = language,
|
||||
prompt = prompt,
|
||||
response_format = response_format,
|
||||
temperature = temperature
|
||||
)
|
||||
|
||||
return response["text"]
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
used to define customizable model schema
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=model
|
||||
),
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_type=ModelType.SPEECH2TEXT,
|
||||
model_properties={ },
|
||||
parameter_rules=[]
|
||||
)
|
||||
|
||||
return entity
|
|
@ -16,6 +16,7 @@ supported_model_types:
|
|||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
- speech2text
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
|
|
|
@ -48,7 +48,7 @@ dashscope[tokenizer]~=1.14.0
|
|||
huggingface_hub~=0.16.4
|
||||
transformers~=4.31.0
|
||||
pandas==1.5.3
|
||||
xinference-client==0.8.4
|
||||
xinference-client==0.9.4
|
||||
safetensors==0.3.2
|
||||
zhipuai==1.0.7
|
||||
werkzeug~=3.0.1
|
||||
|
@ -73,4 +73,4 @@ yarl~=1.9.4
|
|||
twilio==9.0.0
|
||||
qrcode~=7.4.2
|
||||
azure-storage-blob==12.9.0
|
||||
azure-identity==1.15.0
|
||||
azure-identity==1.15.0
|
||||
|
|
Loading…
Reference in New Issue
Block a user