mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
support xinference tts (#6746)
This commit is contained in:
parent
08f922d8c9
commit
f6e8e120a1
240
api/core/model_runtime/model_providers/xinference/tts/tts.py
Normal file
240
api/core/model_runtime/model_providers/xinference/tts/tts.py
Normal file
|
@ -0,0 +1,240 @@
|
||||||
|
import concurrent.futures
from functools import reduce
from io import BytesIO
from typing import Any, Optional

from flask import Response
from pydub import AudioSegment
from xinference_client.client.restful.restful_client import Client, RESTfulAudioModelHandle

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||||
|
|
||||||
|
|
||||||
|
class XinferenceText2SpeechModel(TTSModel):
    """
    Text-to-speech model served through an Xinference REST endpoint.

    The input text is split into chunks, each chunk is synthesised with
    ``RESTfulAudioModelHandle.speech`` and the resulting clips are joined
    into one audio response.
    """

    def __init__(self):
        # preset voices, need support custom voice
        # Outer keys are matched as substrings of the lower-cased model name
        # (see get_tts_model_voices). Inner keys are language codes; 'all'
        # is the fallback used when the requested language has no own entry.
        self.model_voices = {
            'chattts': {
                'all': [
                    {'name': 'Alloy', 'value': 'alloy'},
                    {'name': 'Echo', 'value': 'echo'},
                    {'name': 'Fable', 'value': 'fable'},
                    {'name': 'Onyx', 'value': 'onyx'},
                    {'name': 'Nova', 'value': 'nova'},
                    {'name': 'Shimmer', 'value': 'shimmer'},
                ]
            },
            'cosyvoice': {
                'zh-Hans': [
                    {'name': '中文男', 'value': '中文男'},
                    {'name': '中文女', 'value': '中文女'},
                    {'name': '粤语女', 'value': '粤语女'},
                ],
                'zh-Hant': [
                    {'name': '中文男', 'value': '中文男'},
                    {'name': '中文女', 'value': '中文女'},
                    {'name': '粤语女', 'value': '粤语女'},
                ],
                'en-US': [
                    {'name': '英文男', 'value': '英文男'},
                    {'name': '英文女', 'value': '英文女'},
                ],
                'ja-JP': [
                    {'name': '日语男', 'value': '日语男'},
                ],
                'ko-KR': [
                    {'name': '韩语女', 'value': '韩语女'},
                ]
            }
        }

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        Checks that ``model_uid`` contains none of ``/``, ``?`` or ``#``,
        that the server resolves it to an audio model handle, and that a
        short synthesis request goes through.

        NOTE: a single trailing ``/`` on ``credentials['server_url']`` is
        removed in place, i.e. the caller's dict is modified.

        :param model: model name
        :param credentials: model credentials; ``server_url`` and ``model_uid`` are read
        :return:
        :raises CredentialsValidateFailedError: if any check or the test request fails
        """
        try:
            model_uid = credentials['model_uid']
            if "/" in model_uid or "?" in model_uid or "#" in model_uid:
                raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")

            if credentials['server_url'].endswith('/'):
                credentials['server_url'] = credentials['server_url'][:-1]

            # initialize client
            client = Client(base_url=credentials['server_url'])

            xinference_client = client.get_model(model_uid=credentials['model_uid'])

            if not isinstance(xinference_client, RESTfulAudioModelHandle):
                raise InvokeBadRequestError(
                    'please check model type, the model you want to invoke is not a audio model')

            # End-to-end probe: synthesise a short phrase with the default voice.
            self._tts_invoke(
                model=model,
                credentials=credentials,
                content_text='Hello Dify!',
                voice=self._get_model_default_voice(model, credentials),
            )
        except Exception as ex:
            # Every failure (including the checks above) is reported as a
            # credentials error; the original exception is kept as the cause.
            raise CredentialsValidateFailedError(str(ex)) from ex

    def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
                user: Optional[str] = None):
        """
        _invoke text2speech model

        Delegates to ``_tts_invoke``; ``tenant_id`` and ``user`` are accepted
        but not used here.

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param voice: model timbre
        :param content_text: text content to be converted to speech
        :param user: unique user id
        :return: whatever ``_tts_invoke`` returns (audio response, or None if no audio was produced)
        """
        return self._tts_invoke(model, credentials, content_text, voice)

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema

        :param model: model name, used both as the model id and as its en_US label
        :param credentials: model credentials (not read here)
        :return: TTS model entity with empty properties and parameter rules
        """
        return AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TTS,
            model_properties={},
            parameter_rules=[]
        )

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError,
                KeyError,
                ValueError
            ]
        }

    def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
        """
        Return the preset voices for a model.

        Every preset family whose key is a substring of the lower-cased model
        name is tried in turn: the first one that has an entry for
        ``language`` (or, failing that, an ``'all'`` entry) wins.

        :param model: model name
        :param credentials: model credentials (not read here)
        :param language: language code such as ``'zh-Hans'`` or ``'en-US'``
        :return: list of ``{'name': ..., 'value': ...}`` dicts, empty if nothing matches
        """
        model_name = model.lower()
        for key, voices in self.model_voices.items():
            if key not in model_name:
                continue
            if language in voices:
                return voices[language]
            if 'all' in voices:
                return voices['all']
        return []

    def _get_model_default_voice(self, model: str, credentials: dict) -> Any:
        """Return the default voice: an empty string for every model."""
        return ""

    def _get_model_word_limit(self, model: str, credentials: dict) -> int:
        """Return the maximum chunk length used when splitting the input text."""
        return 3500

    def _get_model_audio_type(self, model: str, credentials: dict) -> str:
        """Return the audio format used for synthesis, decoding and the response MIME type."""
        return "mp3"

    def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
        """Return the workers limit (5). Not consulted by ``_tts_invoke`` in this class."""
        return 5

    def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Any:
        """
        _tts_invoke text2speech model

        NOTE: a single trailing ``/`` on ``credentials['server_url']`` is
        removed in place, i.e. the caller's dict is modified.

        :param model: model name
        :param credentials: model credentials; ``server_url`` and ``model_uid`` are read
        :param voice: model timbre
        :param content_text: text content to be converted to speech
        :return: flask ``Response`` carrying the combined audio, or None if the
                 server returned no audio for any chunk
        :raises InvokeBadRequestError: if the text yields no chunks, or if
                 synthesis, decoding or export fails
        """
        if credentials['server_url'].endswith('/'):
            credentials['server_url'] = credentials['server_url'][:-1]

        word_limit = self._get_model_word_limit(model, credentials)
        audio_type = self._get_model_audio_type(model, credentials)
        handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})

        try:
            sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit))
            if not sentences:
                # Without this guard the pool below would be created with
                # max_workers=0 and fail with an unrelated-looking ValueError.
                raise InvokeBadRequestError('content_text does not contain any text to synthesize')

            # NOTE(review): the pool is capped at 3 here rather than at
            # _get_model_workers_limit() (5) — confirm which is intended.
            with concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) as executor:
                futures = [
                    executor.submit(
                        handle.speech,
                        input=sentence,
                        voice=voice,
                        # Request the same format that is decoded below.
                        response_format=audio_type,
                        speed=1.0,
                        stream=False,
                    )
                    for sentence in sentences
                ]
                # Read results in submission order so clips keep the text order.
                results = [future.result() for future in futures]

            audio_bytes_list = [audio_bytes for audio_bytes in results if audio_bytes]
            if not audio_bytes_list:
                return None

            audio_segments = [
                AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type)
                for audio_bytes in audio_bytes_list
            ]
            combined_segment = reduce(lambda x, y: x + y, audio_segments)
            buffer: BytesIO = BytesIO()
            combined_segment.export(buffer, format=audio_type)
            buffer.seek(0)
            return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}")
        except Exception as ex:
            raise InvokeBadRequestError(str(ex)) from ex

    def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any:
        """
        _tts_invoke_streaming text2speech model

        Not implemented: the body is empty and the method returns None.

        Attention: stream api may return error [Parallel generation is not supported by ggml]

        :param model: model name
        :param credentials: model credentials
        :param voice: model timbre
        :param content_text: text content to be converted to speech
        :return: None
        """
        pass
|
|
@ -17,6 +17,7 @@ supported_model_types:
|
||||||
- text-embedding
|
- text-embedding
|
||||||
- rerank
|
- rerank
|
||||||
- speech2text
|
- speech2text
|
||||||
|
- tts
|
||||||
configurate_methods:
|
configurate_methods:
|
||||||
- customizable-model
|
- customizable-model
|
||||||
model_credential_schema:
|
model_credential_schema:
|
||||||
|
|
8
api/poetry.lock
generated
8
api/poetry.lock
generated
|
@ -9098,13 +9098,13 @@ h11 = ">=0.9.0,<1"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "xinference-client"
|
name = "xinference-client"
|
||||||
version = "0.9.4"
|
version = "0.13.3"
|
||||||
description = "Client for Xinference"
|
description = "Client for Xinference"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "*"
|
python-versions = "*"
|
||||||
files = [
|
files = [
|
||||||
{file = "xinference-client-0.9.4.tar.gz", hash = "sha256:21934bc9f3142ade66aaed33c2b6cf244c274d5b4b3163f9981bebdddacf205f"},
|
{file = "xinference-client-0.13.3.tar.gz", hash = "sha256:822b722100affdff049c27760be7d62ac92de58c87a40d3361066df446ba648f"},
|
||||||
{file = "xinference_client-0.9.4-py3-none-any.whl", hash = "sha256:6d3f1df3537a011f0afee5f9c9ca4f3ff564ca32cc999cf7038b324c0b907d0c"},
|
{file = "xinference_client-0.13.3-py3-none-any.whl", hash = "sha256:f0eff3858b1ebcef2129726f82b09259c177e11db466a7ca23def3d4849c419f"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -9502,4 +9502,4 @@ cffi = ["cffi (>=1.11)"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.10"
|
python-versions = "^3.10"
|
||||||
content-hash = "a8b61d74d9322302b7447b6f8728ad606abc160202a8a122a05a8ef3cec7055b"
|
content-hash = "ca55e4a4bb354fe969cc73c823557525c7598b0375e8791fcd77febc59e03b96"
|
||||||
|
|
|
@ -173,7 +173,7 @@ transformers = "~4.35.0"
|
||||||
unstructured = { version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] }
|
unstructured = { version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] }
|
||||||
websocket-client = "~1.7.0"
|
websocket-client = "~1.7.0"
|
||||||
werkzeug = "~3.0.1"
|
werkzeug = "~3.0.1"
|
||||||
xinference-client = "0.9.4"
|
xinference-client = "0.13.3"
|
||||||
yarl = "~1.9.4"
|
yarl = "~1.9.4"
|
||||||
zhipuai = "1.0.7"
|
zhipuai = "1.0.7"
|
||||||
rank-bm25 = "~0.2.2"
|
rank-bm25 = "~0.2.2"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user