feat: support fish audio TTS (#7982)

This commit is contained in:
Leng Yue 2024-09-04 23:18:39 -07:00 committed by GitHub
parent 3e7597f2bd
commit bd0992275c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 433 additions and 0 deletions

View File

@ -0,0 +1 @@


View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="61.1 180.15 377.8 139.718"><path d="M431.911 245.181c3.842 0 6.989 1.952 6.989 4.337v14.776c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-14.776c0-2.385 3.144-4.337 6.99-4.337ZM404.135 250.955c3.846 0 6.989 1.952 6.989 4.337v32.528c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-32.528c0-2.385 3.147-4.337 6.989-4.337ZM376.363 257.688c3.842 0 6.989 1.952 6.989 4.337v36.562c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-36.562c0-2.386 3.147-4.337 6.993-4.337ZM348.587 263.26c3.846 0 6.989 1.952 6.989 4.337v36.159c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-36.159c0-2.385 3.147-4.337 6.989-4.337ZM320.811 268.177c3.846 0 6.989 1.952 6.989 4.337v31.318c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-31.318c0-2.385 3.147-4.337 6.989-4.337ZM293.179 288.148c3.846 0 6.989 1.952 6.989 4.337v9.935c0 2.384-3.147 4.336-6.989 4.336s-6.99-1.951-6.99-4.336v-9.935c0-2.386 3.144-4.337 6.99-4.337Z" style="fill:#b1b3b4;fill-rule:evenodd"></path><path d="M431.911 205.441c3.842 0 6.989 1.952 6.989 4.337v24.459c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-24.459c0-2.385 3.144-4.337 6.99-4.337ZM404.135 189.026c3.846 0 6.989 1.952 6.989 4.337v43.622c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-43.622c0-2.385 3.147-4.337 6.989-4.337ZM376.363 182.848c3.842 0 6.989 1.953 6.989 4.337v56.937c0 2.384-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-56.937c0-2.385 3.147-4.337 6.993-4.337ZM348.587 180.15c3.846 0 6.989 1.952 6.989 4.337v66.619c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-66.619c0-2.385 3.147-4.337 6.989-4.337ZM320.811 181.84c3.846 0 6.989 1.952 6.989 4.337v67.627c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-67.627c0-2.386 3.147-4.337 6.989-4.337ZM293.179 186.076c3.846 0 6.989 1.952 6.989 4.337v84.37c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.951-6.99-4.337v-84.37c0-2.386 3.144-4.337 6.99-4.337ZM264.829 193.262c3.846 0 6.989 1.953 6.989 4.337v95.667c0 2.385-3.143 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-95.667c0-2.385 3.147-4.337 6.99-4.337ZM237.057 205.441c3.842 0 6.989 1.953 6.989 4.337v92.036c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.951-6.99-4.337v-92.036c0-2.385 3.144-4.337 6.99-4.337ZM209.281 221.302c3.846 0 6.989 1.952 6.989 4.337v80.134c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.952-6.99-4.337v-80.134c0-2.386 3.144-4.337 6.99-4.337ZM181.505 232.271c3.846 0 6.993 1.952 6.993 4.336v78.924c0 2.385-3.147 4.337-6.993 4.337-3.842 0-6.989-1.951-6.989-4.337v-78.924c0-2.385 3.147-4.336 6.989-4.336ZM153.873 241.348c3.846 0 6.989 1.953 6.989 4.337v42.009c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-42.009c0-2.385 3.147-4.337 6.99-4.337ZM125.266 200.398c3.842 0 6.989 1.953 6.989 4.337v58.55c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-58.55c0-2.385 3.144-4.337 6.99-4.337ZM96.7 204.231c3.842 0 6.989 1.953 6.989 4.337v18.004c0 2.384-3.147 4.337-6.989 4.337s-6.989-1.952-6.989-4.337v-18.004c0-2.385 3.143-4.337 6.989-4.337ZM68.089 201.81c3.846 0 6.99 1.953 6.99 4.337v8.12c0 2.384-3.147 4.336-6.99 4.336-3.842 0-6.989-1.951-6.989-4.336v-8.12c0-2.385 3.143-4.337 6.989-4.337ZM153.873 194.94c3.846 0 6.989 1.953 6.989 4.337v6.102c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-6.102c0-2.385 3.147-4.337 6.99-4.337Z" style="fill:#000;fill-rule:evenodd"></path></svg>

After

Width:  |  Height:  |  Size: 3.4 KiB

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="61.1 180.15 377.8 139.718"><path d="M431.911 245.181c3.842 0 6.989 1.952 6.989 4.337v14.776c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-14.776c0-2.385 3.144-4.337 6.99-4.337ZM404.135 250.955c3.846 0 6.989 1.952 6.989 4.337v32.528c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-32.528c0-2.385 3.147-4.337 6.989-4.337ZM376.363 257.688c3.842 0 6.989 1.952 6.989 4.337v36.562c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-36.562c0-2.386 3.147-4.337 6.993-4.337ZM348.587 263.26c3.846 0 6.989 1.952 6.989 4.337v36.159c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-36.159c0-2.385 3.147-4.337 6.989-4.337ZM320.811 268.177c3.846 0 6.989 1.952 6.989 4.337v31.318c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-31.318c0-2.385 3.147-4.337 6.989-4.337ZM293.179 288.148c3.846 0 6.989 1.952 6.989 4.337v9.935c0 2.384-3.147 4.336-6.989 4.336s-6.99-1.951-6.99-4.336v-9.935c0-2.386 3.144-4.337 6.99-4.337Z" style="fill:#b1b3b4;fill-rule:evenodd"></path><path d="M431.911 205.441c3.842 0 6.989 1.952 6.989 4.337v24.459c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-24.459c0-2.385 3.144-4.337 6.99-4.337ZM404.135 189.026c3.846 0 6.989 1.952 6.989 4.337v43.622c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-43.622c0-2.385 3.147-4.337 6.989-4.337ZM376.363 182.848c3.842 0 6.989 1.953 6.989 4.337v56.937c0 2.384-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-56.937c0-2.385 3.147-4.337 6.993-4.337ZM348.587 180.15c3.846 0 6.989 1.952 6.989 4.337v66.619c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-66.619c0-2.385 3.147-4.337 6.989-4.337ZM320.811 181.84c3.846 0 6.989 1.952 6.989 4.337v67.627c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-67.627c0-2.386 3.147-4.337 6.989-4.337ZM293.179 186.076c3.846 0 6.989 1.952 6.989 4.337v84.37c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.951-6.99-4.337v-84.37c0-2.386 3.144-4.337 6.99-4.337ZM264.829 193.262c3.846 0 6.989 1.953 6.989 4.337v95.667c0 2.385-3.143 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-95.667c0-2.385 3.147-4.337 6.99-4.337ZM237.057 205.441c3.842 0 6.989 1.953 6.989 4.337v92.036c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.951-6.99-4.337v-92.036c0-2.385 3.144-4.337 6.99-4.337ZM209.281 221.302c3.846 0 6.989 1.952 6.989 4.337v80.134c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.952-6.99-4.337v-80.134c0-2.386 3.144-4.337 6.99-4.337ZM181.505 232.271c3.846 0 6.993 1.952 6.993 4.336v78.924c0 2.385-3.147 4.337-6.993 4.337-3.842 0-6.989-1.951-6.989-4.337v-78.924c0-2.385 3.147-4.336 6.989-4.336ZM153.873 241.348c3.846 0 6.989 1.953 6.989 4.337v42.009c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-42.009c0-2.385 3.147-4.337 6.99-4.337ZM125.266 200.398c3.842 0 6.989 1.953 6.989 4.337v58.55c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-58.55c0-2.385 3.144-4.337 6.99-4.337ZM96.7 204.231c3.842 0 6.989 1.953 6.989 4.337v18.004c0 2.384-3.147 4.337-6.989 4.337s-6.989-1.952-6.989-4.337v-18.004c0-2.385 3.143-4.337 6.989-4.337ZM68.089 201.81c3.846 0 6.99 1.953 6.99 4.337v8.12c0 2.384-3.147 4.336-6.99 4.336-3.842 0-6.989-1.951-6.989-4.336v-8.12c0-2.385 3.143-4.337 6.989-4.337ZM153.873 194.94c3.846 0 6.989 1.953 6.989 4.337v6.102c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-6.102c0-2.385 3.147-4.337 6.99-4.337Z" style="fill:#000;fill-rule:evenodd"></path></svg>

After

Width:  |  Height:  |  Size: 3.4 KiB

View File

@ -0,0 +1,28 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class FishAudioProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
For debugging purposes, this method now always passes validation.
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.TTS)
model_instance.validate_credentials(
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex

View File

@ -0,0 +1,76 @@
provider: fishaudio
label:
en_US: Fish Audio
description:
en_US: Models provided by Fish Audio, currently only support TTS.
zh_Hans: Fish Audio 提供的模型,目前仅支持 TTS。
icon_small:
en_US: fishaudio_s_en.svg
icon_large:
en_US: fishaudio_l_en.svg
background: "#E5E7EB"
help:
title:
en_US: Get your API key from Fish Audio
zh_Hans: 从 Fish Audio 获取你的 API Key
url:
en_US: https://fish.audio/go-api/
supported_model_types:
- tts
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: api_base
label:
en_US: API URL
type: text-input
required: false
default: https://api.fish.audio
placeholder:
en_US: Enter your API URL
zh_Hans: 在此输入您的 API URL
- variable: use_public_models
label:
en_US: Use Public Models
type: select
required: false
default: "false"
placeholder:
en_US: Toggle to use public models
zh_Hans: 切换以使用公共模型
options:
- value: "true"
label:
en_US: Allow Public Models
zh_Hans: 使用公共模型
- value: "false"
label:
en_US: Private Models Only
zh_Hans: 仅使用私有模型
- variable: latency
label:
en_US: Latency
type: select
required: false
default: "normal"
placeholder:
en_US: Toggle to choice latency
zh_Hans: 切换以调整延迟
options:
- value: "balanced"
label:
en_US: Low (may affect quality)
zh_Hans: 低延迟 (可能降低质量)
- value: "normal"
label:
en_US: Normal
zh_Hans: 标准

View File

@ -0,0 +1,174 @@
from typing import Optional
import httpx
from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
class FishAudioText2SpeechModel(TTSModel):
"""
Model class for Fish.audio Text to Speech model.
"""
def get_tts_model_voices(
self, model: str, credentials: dict, language: Optional[str] = None
) -> list:
api_base = credentials.get("api_base", "https://api.fish.audio")
api_key = credentials.get("api_key")
use_public_models = credentials.get("use_public_models", "false") == "true"
params = {
"self": str(not use_public_models).lower(),
"page_size": "100",
}
if language is not None:
if "-" in language:
language = language.split("-")[0]
params["language"] = language
results = httpx.get(
f"{api_base}/model",
headers={"Authorization": f"Bearer {api_key}"},
params=params,
)
results.raise_for_status()
data = results.json()
return [{"name": i["title"], "value": i["_id"]} for i in data["items"]]
def _invoke(
self,
model: str,
tenant_id: str,
credentials: dict,
content_text: str,
voice: str,
user: Optional[str] = None,
) -> any:
"""
Invoke text2speech model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param voice: model timbre
:param content_text: text content to be translated
:param user: unique user id
:return: generator yielding audio chunks
"""
return self._tts_invoke_streaming(
model=model,
credentials=credentials,
content_text=content_text,
voice=voice,
)
def validate_credentials(
self, credentials: dict, user: Optional[str] = None
) -> None:
"""
Validate credentials for text2speech model
:param credentials: model credentials
:param user: unique user id
"""
try:
self.get_tts_model_voices(
None,
credentials={
"api_key": credentials["api_key"],
"api_base": credentials["api_base"],
# Disable public models will trigger a 403 error if user is not logged in
"use_public_models": "false",
},
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _tts_invoke_streaming(
self, model: str, credentials: dict, content_text: str, voice: str
) -> any:
"""
Invoke streaming text2speech model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: ID of the reference audio (if any)
:return: generator yielding audio chunks
"""
try:
word_limit = self._get_model_word_limit(model, credentials)
if len(content_text) > word_limit:
sentences = self._split_text_into_sentences(
content_text, max_length=word_limit
)
else:
sentences = [content_text.strip()]
for i in range(len(sentences)):
yield from self._tts_invoke_streaming_sentence(
credentials=credentials, content_text=sentences[i], voice=voice
)
except Exception as ex:
raise InvokeBadRequestError(str(ex))
def _tts_invoke_streaming_sentence(
self, credentials: dict, content_text: str, voice: Optional[str] = None
) -> any:
"""
Invoke streaming text2speech model
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: ID of the reference audio (if any)
:return: generator yielding audio chunks
"""
api_key = credentials.get("api_key")
api_url = credentials.get("api_base", "https://api.fish.audio")
latency = credentials.get("latency")
if not api_key:
raise InvokeBadRequestError("API key is required")
with httpx.stream(
"POST",
api_url + "/v1/tts",
json={
"text": content_text,
"reference_id": voice,
"latency": latency
},
headers={
"Authorization": f"Bearer {api_key}",
},
timeout=None,
) as response:
if response.status_code != 200:
raise InvokeBadRequestError(
f"Error: {response.status_code} - {response.text}"
)
yield from response.iter_bytes()
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeBadRequestError: [
httpx.HTTPStatusError,
],
}

View File

@ -0,0 +1,5 @@
model: tts-default
model_type: tts
model_properties:
word_limit: 1000
audio_type: 'mp3'

View File

@ -0,0 +1,82 @@
import os
from collections.abc import Callable
from typing import Literal
import httpx
import pytest
from _pytest.monkeypatch import MonkeyPatch
def mock_get(*args, **kwargs):
if kwargs.get("headers", {}).get("Authorization") != "Bearer test":
raise httpx.HTTPStatusError(
"Invalid API key",
request=httpx.Request("GET", ""),
response=httpx.Response(401),
)
return httpx.Response(
200,
json={
"items": [
{"title": "Model 1", "_id": "model1"},
{"title": "Model 2", "_id": "model2"},
]
},
request=httpx.Request("GET", ""),
)
def mock_stream(*args, **kwargs):
class MockStreamResponse:
def __init__(self):
self.status_code = 200
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
pass
def iter_bytes(self):
yield b"Mocked audio data"
return MockStreamResponse()
def mock_fishaudio(
monkeypatch: MonkeyPatch,
methods: list[Literal["list-models", "tts"]],
) -> Callable[[], None]:
"""
mock fishaudio module
:param monkeypatch: pytest monkeypatch fixture
:return: unpatch function
"""
def unpatch() -> None:
monkeypatch.undo()
if "list-models" in methods:
monkeypatch.setattr(httpx, "get", mock_get)
if "tts" in methods:
monkeypatch.setattr(httpx, "stream", mock_stream)
return unpatch
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_fishaudio_mock(request, monkeypatch):
methods = request.param if hasattr(request, "param") else []
if MOCK:
unpatch = mock_fishaudio(monkeypatch, methods=methods)
yield
if MOCK:
unpatch()

View File

@ -0,0 +1,33 @@
import os
import httpx
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.fishaudio.fishaudio import FishAudioProvider
from tests.integration_tests.model_runtime.__mock.fishaudio import setup_fishaudio_mock
@pytest.mark.parametrize("setup_fishaudio_mock", [["list-models"]], indirect=True)
def test_validate_provider_credentials(setup_fishaudio_mock):
print("-----", httpx.get)
provider = FishAudioProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={
"api_key": "bad_api_key",
"api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"),
"use_public_models": "false",
"latency": "normal",
}
)
provider.validate_provider_credentials(
credentials={
"api_key": os.environ.get("FISH_AUDIO_API_KEY", "test"),
"api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"),
"use_public_models": "false",
"latency": "normal",
}
)

View File

@ -0,0 +1,32 @@
import os
import pytest
from core.model_runtime.model_providers.fishaudio.tts.tts import (
FishAudioText2SpeechModel,
)
from tests.integration_tests.model_runtime.__mock.fishaudio import setup_fishaudio_mock
@pytest.mark.parametrize("setup_fishaudio_mock", [["tts"]], indirect=True)
def test_invoke_model(setup_fishaudio_mock):
model = FishAudioText2SpeechModel()
result = model.invoke(
model="tts-default",
tenant_id="test",
credentials={
"api_key": os.environ.get("FISH_AUDIO_API_KEY", "test"),
"api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"),
"use_public_models": "false",
"latency": "normal",
},
content_text="Hello, world!",
voice="03397b4c4be74759b72533b663fbd001",
)
content = b""
for chunk in result:
content += chunk
assert content != b""