diff --git a/api/core/model_runtime/model_providers/fishaudio/__init__.py b/api/core/model_runtime/model_providers/fishaudio/__init__.py
new file mode 100644
index 0000000000..5f282702bb
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fishaudio/__init__.py
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_l_en.svg b/api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_l_en.svg
new file mode 100644
index 0000000000..d6f7723bd5
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_l_en.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_s_en.svg b/api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_s_en.svg
new file mode 100644
index 0000000000..d6f7723bd5
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_s_en.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/fishaudio/fishaudio.py b/api/core/model_runtime/model_providers/fishaudio/fishaudio.py
new file mode 100644
index 0000000000..9f80996d9d
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fishaudio/fishaudio.py
@@ -0,0 +1,27 @@
+import logging
+
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.model_provider import ModelProvider
+
+logger = logging.getLogger(__name__)
+
+
+class FishAudioProvider(ModelProvider):
+    def validate_provider_credentials(self, credentials: dict) -> None:
+        """
+        Validate provider credentials
+
+        Credentials are checked by delegating to the TTS model's `validate_credentials`.
+
+        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
+        :raises CredentialsValidateFailedError: re-raised unchanged when the TTS model rejects the credentials.
+        """
+        try:
+            model_instance = self.get_model_instance(ModelType.TTS)
+            model_instance.validate_credentials(credentials=credentials)
+        except CredentialsValidateFailedError:
+            raise
+        except Exception:
+            logger.exception("%s credentials validate failed", self.get_provider_schema().provider)
+            raise
diff --git a/api/core/model_runtime/model_providers/fishaudio/fishaudio.yaml b/api/core/model_runtime/model_providers/fishaudio/fishaudio.yaml
new file mode 100644
index 0000000000..479eb7fb85
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fishaudio/fishaudio.yaml
@@ -0,0 +1,76 @@
+provider: fishaudio
+label:
+  en_US: Fish Audio
+description:
+  en_US: Models provided by Fish Audio, currently only support TTS.
+  zh_Hans: Fish Audio 提供的模型,目前仅支持 TTS。
+icon_small:
+  en_US: fishaudio_s_en.svg
+icon_large:
+  en_US: fishaudio_l_en.svg
+background: "#E5E7EB"
+help:
+  title:
+    en_US: Get your API key from Fish Audio
+    zh_Hans: 从 Fish Audio 获取你的 API Key
+  url:
+    en_US: https://fish.audio/go-api/
+supported_model_types:
+  - tts
+configurate_methods:
+  - predefined-model
+provider_credential_schema:
+  credential_form_schemas:
+    - variable: api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
+    - variable: api_base
+      label:
+        en_US: API URL
+      type: text-input
+      required: false
+      default: https://api.fish.audio
+      placeholder:
+        en_US: Enter your API URL
+        zh_Hans: 在此输入您的 API URL
+    - variable: use_public_models
+      label:
+        en_US: Use Public Models
+      type: select
+      required: false
+      default: "false"
+      placeholder:
+        en_US: Toggle to use public models
+        zh_Hans: 切换以使用公共模型
+      options:
+        - value: "true"
+          label:
+            en_US: Allow Public Models
+            zh_Hans: 使用公共模型
+        - value: "false"
+          label:
+            en_US: Private Models Only
+            zh_Hans: 仅使用私有模型
+    - variable: latency
+      label:
+        en_US: Latency
+      type: select
+      required: false
+      default: "normal"
+      placeholder:
+        en_US: Choose the latency mode
+        zh_Hans: 切换以调整延迟
+      options:
+        - value: "balanced"
+          label:
+            en_US: Low (may affect quality)
+            zh_Hans: 低延迟 (可能降低质量)
+        - value: "normal"
+          label:
+            en_US: Normal
+            zh_Hans: 标准
diff --git a/api/core/model_runtime/model_providers/fishaudio/tts/__init__.py b/api/core/model_runtime/model_providers/fishaudio/tts/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/fishaudio/tts/tts.py b/api/core/model_runtime/model_providers/fishaudio/tts/tts.py
new file mode 100644
index 0000000000..5b673ce186
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fishaudio/tts/tts.py
@@ -0,0 +1,201 @@
+from typing import Any, Optional
+
+import httpx
+
+from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.tts_model import TTSModel
+
+DEFAULT_API_BASE = "https://api.fish.audio"
+
+
+def _resolve_api_base(credentials: dict) -> str:
+    """
+    Resolve the API base URL from credentials
+
+    Falls back to the public endpoint when `api_base` is missing or empty,
+    because the provider schema declares the field as optional.
+
+    :param credentials: model credentials
+    :return: API base URL without trailing slashes
+    """
+    return (credentials.get("api_base") or DEFAULT_API_BASE).rstrip("/")
+
+
+class FishAudioText2SpeechModel(TTSModel):
+    """
+    Model class for Fish.audio Text to Speech model.
+    """
+
+    def get_tts_model_voices(
+        self, model: str, credentials: dict, language: Optional[str] = None
+    ) -> list:
+        """
+        List the voices (Fish Audio models) available to the account
+
+        :param model: model name (not used by the request)
+        :param credentials: model credentials
+        :param language: optional language tag; only the part before "-" is sent
+        :return: list of {"name": ..., "value": ...} dicts, one per voice
+        """
+        api_base = _resolve_api_base(credentials)
+        api_key = credentials.get("api_key")
+        use_public_models = credentials.get("use_public_models", "false") == "true"
+
+        params = {
+            "self": str(not use_public_models).lower(),
+            "page_size": "100",
+        }
+
+        if language is not None:
+            params["language"] = language.split("-")[0]
+
+        results = httpx.get(
+            f"{api_base}/model",
+            headers={"Authorization": f"Bearer {api_key}"},
+            params=params,
+        )
+
+        results.raise_for_status()
+        data = results.json()
+
+        return [{"name": i["title"], "value": i["_id"]} for i in data["items"]]
+
+    def _invoke(
+        self,
+        model: str,
+        tenant_id: str,
+        credentials: dict,
+        content_text: str,
+        voice: str,
+        user: Optional[str] = None,
+    ) -> Any:
+        """
+        Invoke text2speech model
+
+        :param model: model name
+        :param tenant_id: user tenant id
+        :param credentials: model credentials
+        :param voice: model timbre
+        :param content_text: text content to be translated
+        :param user: unique user id
+        :return: generator yielding audio chunks
+        """
+
+        return self._tts_invoke_streaming(
+            model=model,
+            credentials=credentials,
+            content_text=content_text,
+            voice=voice,
+        )
+
+    def validate_credentials(
+        self, credentials: dict, user: Optional[str] = None
+    ) -> None:
+        """
+        Validate credentials for text2speech model
+
+        :param credentials: model credentials
+        :param user: unique user id
+        """
+
+        try:
+            self.get_tts_model_voices(
+                None,
+                credentials={
+                    "api_key": credentials["api_key"],
+                    # api_base is optional in the provider schema, so it may be absent
+                    "api_base": credentials.get("api_base"),
+                    # Disable public models will trigger a 403 error if user is not logged in
+                    "use_public_models": "false",
+                },
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex)) from ex
+
+    def _tts_invoke_streaming(
+        self, model: str, credentials: dict, content_text: str, voice: str
+    ) -> Any:
+        """
+        Invoke streaming text2speech model
+        :param model: model name
+        :param credentials: model credentials
+        :param content_text: text content to be translated
+        :param voice: ID of the reference audio (if any)
+        :return: generator yielding audio chunks
+        """
+
+        try:
+            word_limit = self._get_model_word_limit(model, credentials)
+            if len(content_text) > word_limit:
+                sentences = self._split_text_into_sentences(
+                    content_text, max_length=word_limit
+                )
+            else:
+                sentences = [content_text.strip()]
+
+            for sentence in sentences:
+                yield from self._tts_invoke_streaming_sentence(
+                    credentials=credentials, content_text=sentence, voice=voice
+                )
+
+        except InvokeError:
+            # Already a unified error; pass it through unchanged
+            raise
+        except Exception as ex:
+            raise InvokeBadRequestError(str(ex)) from ex
+
+    def _tts_invoke_streaming_sentence(
+        self, credentials: dict, content_text: str, voice: Optional[str] = None
+    ) -> Any:
+        """
+        Invoke streaming text2speech model
+
+        :param credentials: model credentials
+        :param content_text: text content to be translated
+        :param voice: ID of the reference audio (if any)
+        :return: generator yielding audio chunks
+        """
+        api_key = credentials.get("api_key")
+        api_url = _resolve_api_base(credentials)
+        latency = credentials.get("latency")
+
+        if not api_key:
+            raise InvokeBadRequestError("API key is required")
+
+        with httpx.stream(
+            "POST",
+            api_url + "/v1/tts",
+            json={
+                "text": content_text,
+                "reference_id": voice,
+                "latency": latency,
+            },
+            headers={
+                "Authorization": f"Bearer {api_key}",
+            },
+            timeout=None,
+        ) as response:
+            if response.status_code != 200:
+                # A streamed body must be read before `.text` can be accessed
+                response.read()
+                raise InvokeBadRequestError(
+                    f"Error: {response.status_code} - {response.text}"
+                )
+            yield from response.iter_bytes()
+
+    @property
+    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
+        """
+        Map model invoke error to unified error
+        The key is the error type thrown to the caller
+        The value is the error type thrown by the model,
+        which needs to be converted into a unified error type for the caller.
+
+        :return: Invoke error mapping
+        """
+        return {
+            InvokeBadRequestError: [
+                httpx.HTTPStatusError,
+            ],
+        }
diff --git a/api/core/model_runtime/model_providers/fishaudio/tts/tts.yaml b/api/core/model_runtime/model_providers/fishaudio/tts/tts.yaml
new file mode 100644
index 0000000000..b4a446a957
--- /dev/null
+++ b/api/core/model_runtime/model_providers/fishaudio/tts/tts.yaml
@@ -0,0 +1,5 @@
+model: tts-default
+model_type: tts
+model_properties:
+  word_limit: 1000
+  audio_type: 'mp3'
diff --git a/api/tests/integration_tests/model_runtime/__mock/fishaudio.py b/api/tests/integration_tests/model_runtime/__mock/fishaudio.py
new file mode 100644
index 0000000000..bec3babeaf
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/__mock/fishaudio.py
@@ -0,0 +1,82 @@
+import os
+from collections.abc import Callable
+from typing import Literal
+
+import httpx
+import pytest
+from _pytest.monkeypatch import MonkeyPatch
+
+
+def mock_get(*args, **kwargs):
+    if kwargs.get("headers", {}).get("Authorization") != "Bearer test":
+        raise httpx.HTTPStatusError(
+            "Invalid API key",
+            request=httpx.Request("GET", ""),
+            response=httpx.Response(401),
+        )
+
+    return httpx.Response(
+        200,
+        json={
+            "items": [
+                {"title": "Model 1", "_id": "model1"},
+                {"title": "Model 2", "_id": "model2"},
+            ]
+        },
+        request=httpx.Request("GET", ""),
+    )
+
+
+def mock_stream(*args, **kwargs):
+    class MockStreamResponse:
+        def __init__(self):
+            self.status_code = 200
+
+        def __enter__(self):
+            return self
+
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            pass
+
+        def iter_bytes(self):
+            yield b"Mocked audio data"
+
+    return MockStreamResponse()
+
+
+def mock_fishaudio(
+    monkeypatch: MonkeyPatch,
+    methods: list[Literal["list-models", "tts"]],
+) -> Callable[[], None]:
+    """
+    mock fishaudio module
+
+    :param monkeypatch: pytest monkeypatch fixture
+    :return: unpatch function
+    """
+
+    def unpatch() -> None:
+        monkeypatch.undo()
+
+    if "list-models" in methods:
+        monkeypatch.setattr(httpx, "get", mock_get)
+
+    if "tts" in methods:
+        monkeypatch.setattr(httpx, "stream", mock_stream)
+
+    return unpatch
+
+
+MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
+
+
+@pytest.fixture
+def setup_fishaudio_mock(request, monkeypatch):
+    methods = request.param if hasattr(request, "param") else []
+    if MOCK:
+        unpatch = mock_fishaudio(monkeypatch, methods=methods)
+
+    yield
+
+    if MOCK:
+        unpatch()
diff --git a/api/tests/integration_tests/model_runtime/fishaudio/__init__.py b/api/tests/integration_tests/model_runtime/fishaudio/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/integration_tests/model_runtime/fishaudio/test_provider.py b/api/tests/integration_tests/model_runtime/fishaudio/test_provider.py
new file mode 100644
index 0000000000..3526574b61
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/fishaudio/test_provider.py
@@ -0,0 +1,31 @@
+import os
+
+import pytest
+
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.fishaudio.fishaudio import FishAudioProvider
+from tests.integration_tests.model_runtime.__mock.fishaudio import setup_fishaudio_mock
+
+
+@pytest.mark.parametrize("setup_fishaudio_mock", [["list-models"]], indirect=True)
+def test_validate_provider_credentials(setup_fishaudio_mock):
+    provider = FishAudioProvider()
+
+    with pytest.raises(CredentialsValidateFailedError):
+        provider.validate_provider_credentials(
+            credentials={
+                "api_key": "bad_api_key",
+                "api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"),
+                "use_public_models": "false",
+                "latency": "normal",
+            }
+        )
+
+    provider.validate_provider_credentials(
+        credentials={
+            "api_key": os.environ.get("FISH_AUDIO_API_KEY", "test"),
+            "api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"),
+            "use_public_models": "false",
+            "latency": "normal",
+        }
+    )
diff --git a/api/tests/integration_tests/model_runtime/fishaudio/test_tts.py b/api/tests/integration_tests/model_runtime/fishaudio/test_tts.py
new file mode 100644
index 0000000000..f61fee28b9
--- /dev/null
+++ b/api/tests/integration_tests/model_runtime/fishaudio/test_tts.py
@@ -0,0 +1,30 @@
+import os
+
+import pytest
+
+from core.model_runtime.model_providers.fishaudio.tts.tts import (
+    FishAudioText2SpeechModel,
+)
+from tests.integration_tests.model_runtime.__mock.fishaudio import setup_fishaudio_mock
+
+
+@pytest.mark.parametrize("setup_fishaudio_mock", [["tts"]], indirect=True)
+def test_invoke_model(setup_fishaudio_mock):
+    model = FishAudioText2SpeechModel()
+
+    result = model.invoke(
+        model="tts-default",
+        tenant_id="test",
+        credentials={
+            "api_key": os.environ.get("FISH_AUDIO_API_KEY", "test"),
+            "api_base": os.environ.get("FISH_AUDIO_API_BASE", "https://api.fish.audio"),
+            "use_public_models": "false",
+            "latency": "normal",
+        },
+        content_text="Hello, world!",
+        voice="03397b4c4be74759b72533b663fbd001",
+    )
+
+    content = b"".join(result)
+
+    assert content != b""