support xinference tts (#6746)

This commit is contained in:
Weaxs 2024-08-01 11:59:15 +08:00 committed by GitHub
parent 08f922d8c9
commit f6e8e120a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 246 additions and 5 deletions

View File

@ -0,0 +1,240 @@
import concurrent.futures
from functools import reduce
from io import BytesIO
from typing import Optional
from flask import Response
from pydub import AudioSegment
from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
class XinferenceText2SpeechModel(TTSModel):
def __init__(self):
# preset voices, need support custom voice
self.model_voices = {
'chattts': {
'all': [
{'name': 'Alloy', 'value': 'alloy'},
{'name': 'Echo', 'value': 'echo'},
{'name': 'Fable', 'value': 'fable'},
{'name': 'Onyx', 'value': 'onyx'},
{'name': 'Nova', 'value': 'nova'},
{'name': 'Shimmer', 'value': 'shimmer'},
]
},
'cosyvoice': {
'zh-Hans': [
{'name': '中文男', 'value': '中文男'},
{'name': '中文女', 'value': '中文女'},
{'name': '粤语女', 'value': '粤语女'},
],
'zh-Hant': [
{'name': '中文男', 'value': '中文男'},
{'name': '中文女', 'value': '中文女'},
{'name': '粤语女', 'value': '粤语女'},
],
'en-US': [
{'name': '英文男', 'value': '英文男'},
{'name': '英文女', 'value': '英文女'},
],
'ja-JP': [
{'name': '日语男', 'value': '日语男'},
],
'ko-KR': [
{'name': '韩语女', 'value': '韩语女'},
]
}
}
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
if ("/" in credentials['model_uid'] or
"?" in credentials['model_uid'] or
"#" in credentials['model_uid']):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
# initialize client
client = Client(
base_url=credentials['server_url']
)
xinference_client = client.get_model(model_uid=credentials['model_uid'])
if not isinstance(xinference_client, RESTfulAudioModelHandle):
raise InvokeBadRequestError(
'please check model type, the model you want to invoke is not a audio model')
self._tts_invoke(
model=model,
credentials=credentials,
content_text='Hello Dify!',
voice=self._get_model_default_voice(model, credentials),
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
user: Optional[str] = None):
"""
_invoke text2speech model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param voice: model timbre
:param content_text: text content to be translated
:param user: unique user id
:return: text translated to audio file
"""
return self._tts_invoke(model, credentials, content_text, voice)
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TTS,
model_properties={},
parameter_rules=[]
)
return entity
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError,
KeyError,
ValueError
]
}
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
for key, voices in self.model_voices.items():
if key in model.lower():
if language in voices:
return voices[language]
elif 'all' in voices:
return voices['all']
return []
def _get_model_default_voice(self, model: str, credentials: dict) -> any:
return ""
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
return 3500
def _get_model_audio_type(self, model: str, credentials: dict) -> str:
return "mp3"
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
return 5
def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
"""
_tts_invoke text2speech model
:param model: model name
:param credentials: model credentials
:param voice: model timbre
:param content_text: text content to be translated
:return: text translated to audio file
"""
if credentials['server_url'].endswith('/'):
credentials['server_url'] = credentials['server_url'][:-1]
word_limit = self._get_model_word_limit(model, credentials)
audio_type = self._get_model_audio_type(model, credentials)
handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
try:
sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
audio_bytes_list = []
with concurrent.futures.ThreadPoolExecutor(max_workers=min((3, len(sentences)))) as executor:
futures = [executor.submit(
handle.speech, input=sentence, voice=voice, response_format="mp3", speed=1.0, stream=False)
for sentence in sentences]
for future in futures:
try:
if future.result():
audio_bytes_list.append(future.result())
except Exception as ex:
raise InvokeBadRequestError(str(ex))
if len(audio_bytes_list) > 0:
audio_segments = [AudioSegment.from_file(
BytesIO(audio_bytes), format=audio_type) for audio_bytes in
audio_bytes_list if audio_bytes]
combined_segment = reduce(lambda x, y: x + y, audio_segments)
buffer: BytesIO = BytesIO()
combined_segment.export(buffer, format=audio_type)
buffer.seek(0)
return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
except Exception as ex:
raise InvokeBadRequestError(str(ex))
def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any:
"""
_tts_invoke_streaming text2speech model
Attention: stream api may return error [Parallel generation is not supported by ggml]
:param model: model name
:param credentials: model credentials
:param voice: model timbre
:param content_text: text content to be translated
:return: text translated to audio file
"""
pass

View File

@ -17,6 +17,7 @@ supported_model_types:
- text-embedding
- rerank
- speech2text
- tts
configurate_methods:
- customizable-model
model_credential_schema:

8
api/poetry.lock generated
View File

@ -9098,13 +9098,13 @@ h11 = ">=0.9.0,<1"
[[package]]
name = "xinference-client"
version = "0.9.4"
version = "0.13.3"
description = "Client for Xinference"
optional = false
python-versions = "*"
files = [
{file = "xinference-client-0.9.4.tar.gz", hash = "sha256:21934bc9f3142ade66aaed33c2b6cf244c274d5b4b3163f9981bebdddacf205f"},
{file = "xinference_client-0.9.4-py3-none-any.whl", hash = "sha256:6d3f1df3537a011f0afee5f9c9ca4f3ff564ca32cc999cf7038b324c0b907d0c"},
{file = "xinference-client-0.13.3.tar.gz", hash = "sha256:822b722100affdff049c27760be7d62ac92de58c87a40d3361066df446ba648f"},
{file = "xinference_client-0.13.3-py3-none-any.whl", hash = "sha256:f0eff3858b1ebcef2129726f82b09259c177e11db466a7ca23def3d4849c419f"},
]
[package.dependencies]
@ -9502,4 +9502,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "a8b61d74d9322302b7447b6f8728ad606abc160202a8a122a05a8ef3cec7055b"
content-hash = "ca55e4a4bb354fe969cc73c823557525c7598b0375e8791fcd77febc59e03b96"

View File

@ -173,7 +173,7 @@ transformers = "~4.35.0"
unstructured = { version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] }
websocket-client = "~1.7.0"
werkzeug = "~3.0.1"
xinference-client = "0.9.4"
xinference-client = "0.13.3"
yarl = "~1.9.4"
zhipuai = "1.0.7"
rank-bm25 = "~0.2.2"