From 6ef401a9f0b557a37838f6cfd0446f2e97bd5b1f Mon Sep 17 00:00:00 2001
From: chenxu9741
Date: Tue, 9 Jul 2024 11:33:58 +0800
Subject: [PATCH] feat: add tts-streaming config and feature (#5492)

---
 api/constants/tts_auto_play_timeout.py | 4 +
 api/controllers/console/app/audio.py | 31 ++-
 api/controllers/console/explore/audio.py | 30 +-
 api/controllers/service_api/app/audio.py | 33 ++-
 api/controllers/web/audio.py | 29 +-
 .../app_generator_tts_publisher.py | 135 +++++++++
 .../advanced_chat/generate_task_pipeline.py | 97 +++++--
 api/core/app/apps/base_app_queue_manager.py | 1 -
 .../apps/workflow/generate_task_pipeline.py | 63 ++++-
 api/core/app/entities/task_entities.py | 37 +++
 .../easy_ui_based_generate_task_pipeline.py | 66 ++++-
 api/core/model_manager.py | 5 +-
 .../model_providers/__base/tts_model.py | 51 ++--
 .../model_providers/azure_openai/tts/tts.py | 66 ++---
 .../model_providers/openai/tts/tts-1-hd.yaml | 2 +-
 .../model_providers/openai/tts/tts-1.yaml | 2 +-
 .../model_providers/openai/tts/tts.py | 62 ++---
 .../model_providers/tongyi/tts/tts-1.yaml | 2 +-
 .../model_providers/tongyi/tts/tts.py | 97 +++++--
 api/pyproject.toml | 2 +-
 api/services/app_service.py | 2 +
 api/services/audio_service.py | 104 ++++---
 .../config-voice/param-config-content.tsx | 64 ++++-
 .../chat-group/text-to-speech/index.tsx | 1 -
 .../app/text-generate/item/index.tsx | 3 +-
 .../base/audio-btn/audio.player.manager.ts | 53 ++++
 web/app/components/base/audio-btn/audio.ts | 263 ++++++++++++++++++
 web/app/components/base/audio-btn/index.tsx | 133 +++------
 .../base/chat/chat/answer/index.tsx | 17 +-
 .../base/chat/chat/answer/operation.tsx | 2 +-
 web/app/components/base/chat/chat/hooks.ts | 28 +-
 .../text-to-speech/param-config-content.tsx | 66 ++++-
 web/app/components/base/features/types.ts | 3 +-
 .../workflow/hooks/use-workflow-run.ts | 27 ++
 web/i18n/en-US/app-debug.ts | 3 +
 web/i18n/ja-JP/app-debug.ts | 3 +
 web/i18n/zh-Hans/app-debug.ts | 3 +
 web/i18n/zh-Hant/app-debug.ts | 3 +
 web/models/debug.ts | 3 +-
 web/next.config.js | 1 +
 web/service/apps.ts | 1 +
 web/service/base.ts | 22 +-
 web/service/share.ts | 12 +-
 web/types/app.ts | 6 +
 44 files changed, 1280 insertions(+), 358 deletions(-)
 create mode 100644 api/constants/tts_auto_play_timeout.py
 create mode 100644 api/core/app/apps/advanced_chat/app_generator_tts_publisher.py
 create mode 100644 web/app/components/base/audio-btn/audio.player.manager.ts
 create mode 100644 web/app/components/base/audio-btn/audio.ts

diff --git a/api/constants/tts_auto_play_timeout.py b/api/constants/tts_auto_play_timeout.py
new file mode 100644
index 0000000000..d5ed30830a
--- /dev/null
+++ b/api/constants/tts_auto_play_timeout.py
@@ -0,0 +1,4 @@
+TTS_AUTO_PLAY_TIMEOUT = 5
+
+# sleep 20 ms (40ms => 1280 byte audio file, 20ms => 640 byte audio file)
+TTS_AUTO_PLAY_YIELD_CPU_TIME = 0.02
diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py
index 51322c92d3..1de08afa4e 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -81,15 +81,36 @@ class ChatMessageTextApi(Resource):
     @account_initialization_required
     @get_app_model
     def post(self, app_model):
+        from werkzeug.exceptions import InternalServerError
+
         try:
+            parser = reqparse.RequestParser()
+            parser.add_argument('message_id', type=str, location='json')
+            parser.add_argument('text', type=str, location='json')
+            parser.add_argument('voice', type=str, location='json')
+            parser.add_argument('streaming', type=bool, location='json')
+            args = parser.parse_args()
+
+            message_id = args.get('message_id', None)
+            text = args.get('text', None)
+            if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                    and app_model.workflow
+                    and app_model.workflow.features_dict):
+                text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
+                voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
+            else:
+                try:
+                    voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get(
+                        'voice')
+                except Exception:
+                    voice = None
             response = AudioService.transcript_tts(
                 app_model=app_model,
-                text=request.form['text'],
-                voice=request.form['voice'],
-                streaming=False
+                text=text,
+                message_id=message_id,
+                voice=voice
             )
-
-            return {'data': response.data.decode('latin1')}
+            return response
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py
index d869cd38ed..920b1d8383 100644
--- a/api/controllers/console/explore/audio.py
+++ b/api/controllers/console/explore/audio.py
@@ -19,6 +19,7 @@ from controllers.console.app.error import (
 from controllers.console.explore.wraps import InstalledAppResource
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
+from models.model import AppMode
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
@@ -70,16 +71,33 @@ class ChatAudioApi(InstalledAppResource):
 
 class ChatTextApi(InstalledAppResource):
     def post(self, installed_app):
-        app_model = installed_app.app
+        from flask_restful import reqparse
 
+        app_model = installed_app.app
         try:
+            parser = reqparse.RequestParser()
+            parser.add_argument('message_id', type=str, required=False, location='json')
+            parser.add_argument('voice', type=str, location='json')
+            parser.add_argument('streaming', type=bool, location='json')
+            args = parser.parse_args()
+
+            message_id = args.get('message_id')
+            if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]
+                    and app_model.workflow
+                    and app_model.workflow.features_dict):
+                text_to_speech = app_model.workflow.features_dict.get('text_to_speech')
+                voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice')
+            else:
+                try:
+                    voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice')
+                except Exception:
+                    voice = None
             response = AudioService.transcript_tts(
                 app_model=app_model,
-                text=request.form['text'],
-                voice=request.form['voice'] if request.form.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice'),
-                streaming=False
+                message_id=message_id,
+                voice=voice
             )
-            return {'data': response.data.decode('latin1')}
+            return response
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
@@ -108,3 +126,5 @@ class ChatTextApi(InstalledAppResource):
 
 api.add_resource(ChatAudioApi, '/installed-apps/<uuid:installed_app_id>/audio-to-text', endpoint='installed_app_audio')
 api.add_resource(ChatTextApi, '/installed-apps/<uuid:installed_app_id>/text-to-audio', endpoint='installed_app_text')
+# api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id',
+#                  endpoint='installed_app_text_with_message_id')
diff --git a/api/controllers/service_api/app/audio.py 
b/api/controllers/service_api/app/audio.py index 15c0a153b8..607d71598f 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -20,7 +20,7 @@ from controllers.service_api.app.error import ( from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from models.model import App, EndUser +from models.model import App, AppMode, EndUser from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, @@ -72,19 +72,30 @@ class AudioApi(Resource): class TextApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) def post(self, app_model: App, end_user: EndUser): - parser = reqparse.RequestParser() - parser.add_argument('text', type=str, required=True, nullable=False, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json') - args = parser.parse_args() - try: + parser = reqparse.RequestParser() + parser.add_argument('message_id', type=str, required=False, location='json') + parser.add_argument('voice', type=str, location='json') + parser.add_argument('streaming', type=bool, location='json') + args = parser.parse_args() + + message_id = args.get('message_id') + if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + and app_model.workflow + and app_model.workflow.features_dict): + text_to_speech = app_model.workflow.features_dict.get('text_to_speech') + voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + else: + try: + voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get( + 'voice') + except Exception: + voice = None response = AudioService.transcript_tts( app_model=app_model, - text=args['text'], - end_user=end_user, - voice=args.get('voice'), - streaming=args['streaming'] + message_id=message_id, + end_user=end_user.external_user_id, + voice=voice ) return response diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 3fed7895e2..8be872f5f9 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -19,7 +19,7 @@ from controllers.web.error import ( from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from models.model import App +from models.model import App, AppMode from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, @@ -69,16 +69,35 @@ class AudioApi(WebApiResource): class TextApi(WebApiResource): def post(self, app_model: App, end_user): + from flask_restful import reqparse try: + parser = reqparse.RequestParser() + parser.add_argument('message_id', type=str, required=False, location='json') + parser.add_argument('voice', type=str, location='json') + parser.add_argument('streaming', type=bool, location='json') + args = parser.parse_args() + + message_id = args.get('message_id') + if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + and app_model.workflow + and app_model.workflow.features_dict): + text_to_speech = app_model.workflow.features_dict.get('text_to_speech') + 
voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + else: + try: + voice = args.get('voice') if args.get( + 'voice') else app_model.app_model_config.text_to_speech_dict.get('voice') + except Exception: + voice = None + response = AudioService.transcript_tts( app_model=app_model, - text=request.form['text'], + message_id=message_id, end_user=end_user.external_user_id, - voice=request.form['voice'] if request.form.get('voice') else None, - streaming=False + voice=voice ) - return {'data': response.data.decode('latin1')} + return response except services.errors.app_model_config.AppModelConfigBrokenError: logging.exception("App model config broken.") raise AppUnavailableError() diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py new file mode 100644 index 0000000000..8325994608 --- /dev/null +++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py @@ -0,0 +1,135 @@ +import base64 +import concurrent.futures +import logging +import queue +import re +import threading + +from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueTextChunkEvent +from core.model_manager import ModelManager +from core.model_runtime.entities.model_entities import ModelType + + +class AudioTrunk: + def __init__(self, status: str, audio): + self.audio = audio + self.status = status + + +def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str): + if not text_content or text_content.isspace(): + return + return model_instance.invoke_tts( + content_text=text_content.strip(), + user="responding_tts", + tenant_id=tenant_id, + voice=voice + ) + + +def _process_future(future_queue, audio_queue): + while True: + try: + future = future_queue.get() + if future is None: + break + for audio in future.result(): + audio_base64 = base64.b64encode(bytes(audio)) + audio_queue.put(AudioTrunk("responding", audio=audio_base64)) + except Exception as e: + logging.getLogger(__name__).warning(e) + break + audio_queue.put(AudioTrunk("finish", b'')) + + +class AppGeneratorTTSPublisher: + + def __init__(self, tenant_id: str, voice: str): + self.logger = logging.getLogger(__name__) + self.tenant_id = tenant_id + self.msg_text = '' + self._audio_queue = queue.Queue() + self._msg_queue = queue.Queue() + self.match = re.compile(r'[。.!?]') + self.model_manager = ModelManager() + self.model_instance = self.model_manager.get_default_model_instance( + tenant_id=self.tenant_id, + model_type=ModelType.TTS + ) + self.voices = self.model_instance.get_tts_voices() + values = [voice.get('value') for voice in self.voices] + self.voice = voice + if not voice or voice not in values: + self.voice = self.voices[0].get('value') + self.MAX_SENTENCE = 2 + self._last_audio_event = None + self._runtime_thread = threading.Thread(target=self._runtime).start() + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3) + + def publish(self, message): + try: + self._msg_queue.put(message) + except Exception as e: + self.logger.warning(e) + + def _runtime(self): + future_queue = queue.Queue() + threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start() + while True: + try: + message = self._msg_queue.get() + if message is None: + if self.msg_text and len(self.msg_text.strip()) > 0: + futures_result = self.executor.submit(_invoiceTTS, self.msg_text, + self.model_instance, self.tenant_id, self.voice) + future_queue.put(futures_result) + break + elif 
isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent): + self.msg_text += message.event.chunk.delta.message.content + elif isinstance(message.event, QueueTextChunkEvent): + self.msg_text += message.event.text + self.last_message = message + sentence_arr, text_tmp = self._extract_sentence(self.msg_text) + if len(sentence_arr) >= min(self.MAX_SENTENCE, 7): + self.MAX_SENTENCE += 1 + text_content = ''.join(sentence_arr) + futures_result = self.executor.submit(_invoiceTTS, text_content, + self.model_instance, + self.tenant_id, + self.voice) + future_queue.put(futures_result) + if text_tmp: + self.msg_text = text_tmp + else: + self.msg_text = '' + + except Exception as e: + self.logger.warning(e) + break + future_queue.put(None) + + def checkAndGetAudio(self) -> AudioTrunk | None: + try: + if self._last_audio_event and self._last_audio_event.status == "finish": + if self.executor: + self.executor.shutdown(wait=False) + return self.last_message + audio = self._audio_queue.get_nowait() + if audio and audio.status == "finish": + self.executor.shutdown(wait=False) + self._runtime_thread = None + if audio: + self._last_audio_event = audio + return audio + except queue.Empty: + return None + + def _extract_sentence(self, org_text): + tx = self.match.finditer(org_text) + start = 0 + result = [] + for i in tx: + end = i.regs[0][1] + result.append(org_text[start:end]) + start = end + return result, org_text[start:] diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 5ca0fe2191..4b089f033f 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -4,6 +4,8 @@ import time from collections.abc import Generator from typing import Any, Optional, Union, cast +from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME +from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, @@ -33,6 +35,8 @@ from core.app.entities.task_entities import ( ChatbotAppStreamResponse, ChatflowStreamGenerateRoute, ErrorStreamResponse, + MessageAudioEndStreamResponse, + MessageAudioStreamResponse, MessageEndStreamResponse, StreamResponse, ) @@ -71,13 +75,13 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc _iteration_nested_relations: dict[str, list[str]] def __init__( - self, application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, - queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, - user: Union[Account, EndUser], - stream: bool + self, application_generate_entity: AdvancedChatAppGenerateEntity, + workflow: Workflow, + queue_manager: AppQueueManager, + conversation: Conversation, + message: Message, + user: Union[Account, EndUser], + stream: bool ) -> None: """ Initialize AdvancedChatAppGenerateTaskPipeline. 
@@ -129,7 +133,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._application_generate_entity.query ) - generator = self._process_stream_response( + generator = self._wrapper_process_stream_response( trace_manager=self._application_generate_entity.trace_manager ) if self._stream: @@ -138,7 +142,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc return self._to_blocking_response(generator) def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \ - -> ChatbotAppBlockingResponse: + -> ChatbotAppBlockingResponse: """ Process blocking response. :return: @@ -169,7 +173,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc raise Exception('Queue listening stopped unexpectedly.') def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \ - -> Generator[ChatbotAppStreamResponse, None, None]: + -> Generator[ChatbotAppStreamResponse, None, None]: """ To stream response. :return: @@ -182,14 +186,68 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc stream_response=stream_response ) + def _listenAudioMsg(self, publisher, task_id: str): + if not publisher: + return None + audio_msg: AudioTrunk = publisher.checkAndGetAudio() + if audio_msg and audio_msg.status != "finish": + return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) + return None + + def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ + Generator[StreamResponse, None, None]: + + publisher = None + task_id = self._application_generate_entity.task_id + tenant_id = self._application_generate_entity.app_config.tenant_id + features_dict = self._workflow.features_dict + + if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ + 'text_to_speech'].get('autoPlay') == 'enabled': + publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) + for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): + while True: + audio_response = self._listenAudioMsg(publisher, task_id=task_id) + if audio_response: + yield audio_response + else: + break + yield response + + start_listener_time = time.time() + # timeout + while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: + try: + if not publisher: + break + audio_trunk = publisher.checkAndGetAudio() + if audio_trunk is None: + # release cpu + # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) + time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME) + continue + if audio_trunk.status == "finish": + break + else: + start_listener_time = time.time() + yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) + except Exception as e: + logger.error(e) + break + yield MessageAudioEndStreamResponse(audio='', task_id=task_id) + def _process_stream_response( - self, trace_manager: Optional[TraceQueueManager] = None + self, + publisher: AppGeneratorTTSPublisher, + trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. 
:return: """ for message in self._queue_manager.listen(): + if publisher: + publisher.publish(message=message) event = message.event if isinstance(event, QueueErrorEvent): @@ -301,7 +359,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc continue if not self._is_stream_out_support( - event=event + event=event ): continue @@ -318,7 +376,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc yield self._ping_stream_response() else: continue - + if publisher: + publisher.publish(None) if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() @@ -402,7 +461,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc return stream_generate_routes def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \ - -> list[str]: + -> list[str]: """ Get answer start at node id. :param graph: graph @@ -457,7 +516,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc start_node_id = target_node_id start_node_ids.append(start_node_id) elif node_type == NodeType.START.value or \ - node_iteration_id is not None and iteration_start_node_id == source_node.get('id'): + node_iteration_id is not None and iteration_start_node_id == source_node.get('id'): start_node_id = source_node_id start_node_ids.append(start_node_id) else: @@ -515,7 +574,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc # all route chunks are generated if self._task_state.current_stream_generate_state.current_route_position == len( - self._task_state.current_stream_generate_state.generate_route + self._task_state.current_stream_generate_state.generate_route ): self._task_state.current_stream_generate_state = None @@ -525,7 +584,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc :return: """ if not self._task_state.current_stream_generate_state: - return None + return route_chunks = self._task_state.current_stream_generate_state.generate_route[ self._task_state.current_stream_generate_state.current_route_position:] @@ -573,7 +632,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc # get route chunk node execution info route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id] if (route_chunk_node_execution_info.node_type == NodeType.LLM - and latest_node_execution_info.node_type == NodeType.LLM): + and latest_node_execution_info.node_type == NodeType.LLM): # only LLM support chunk stream output self._task_state.current_stream_generate_state.current_route_position += 1 continue @@ -643,7 +702,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc # all route chunks are generated if self._task_state.current_stream_generate_state.current_route_position == len( - self._task_state.current_stream_generate_state.generate_route + self._task_state.current_stream_generate_state.generate_route ): self._task_state.current_stream_generate_state = None diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index b5c49d65c2..dd2343d0b1 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -51,7 +51,6 @@ class AppQueueManager: listen_timeout = current_app.config.get("APP_MAX_EXECUTION_TIME") start_time = time.time() last_ping_time = 0 - while True: try: message = self._q.get(timeout=1) diff --git 
a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index f4bd396f46..2b4362150f 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -1,7 +1,10 @@ import logging +import time from collections.abc import Generator from typing import Any, Optional, Union +from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME +from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import ( InvokeFrom, @@ -25,6 +28,8 @@ from core.app.entities.queue_entities import ( ) from core.app.entities.task_entities import ( ErrorStreamResponse, + MessageAudioEndStreamResponse, + MessageAudioStreamResponse, StreamResponse, TextChunkStreamResponse, TextReplaceStreamResponse, @@ -105,7 +110,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa db.session.refresh(self._user) db.session.close() - generator = self._process_stream_response( + generator = self._wrapper_process_stream_response( trace_manager=self._application_generate_entity.trace_manager ) if self._stream: @@ -161,8 +166,58 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa stream_response=stream_response ) + def _listenAudioMsg(self, publisher, task_id: str): + if not publisher: + return None + audio_msg: AudioTrunk = publisher.checkAndGetAudio() + if audio_msg and audio_msg.status != "finish": + return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) + return None + + def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ + Generator[StreamResponse, None, None]: + + publisher = None + task_id = self._application_generate_entity.task_id + tenant_id = self._application_generate_entity.app_config.tenant_id + features_dict = self._workflow.features_dict + + if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[ + 'text_to_speech'].get('autoPlay') == 'enabled': + publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice')) + for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): + while True: + audio_response = self._listenAudioMsg(publisher, task_id=task_id) + if audio_response: + yield audio_response + else: + break + yield response + + start_listener_time = time.time() + while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: + try: + if not publisher: + break + audio_trunk = publisher.checkAndGetAudio() + if audio_trunk is None: + # release cpu + # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) + time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME) + continue + if audio_trunk.status == "finish": + break + else: + yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id) + except Exception as e: + logger.error(e) + break + yield MessageAudioEndStreamResponse(audio='', task_id=task_id) + + def _process_stream_response( self, + publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ @@ -170,6 +225,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa :return: """ for message in self._queue_manager.listen(): + if publisher: + 
publisher.publish(message=message) event = message.event if isinstance(event, QueueErrorEvent): @@ -251,6 +308,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa else: continue + if publisher: + publisher.publish(None) + + def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: """ Save workflow app log. diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index af93399a24..7bc5598984 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -69,6 +69,7 @@ class WorkflowTaskState(TaskState): iteration_nested_node_ids: list[str] = None + class AdvancedChatTaskState(WorkflowTaskState): """ AdvancedChatTaskState entity @@ -86,6 +87,8 @@ class StreamEvent(Enum): ERROR = "error" MESSAGE = "message" MESSAGE_END = "message_end" + TTS_MESSAGE = "tts_message" + TTS_MESSAGE_END = "tts_message_end" MESSAGE_FILE = "message_file" MESSAGE_REPLACE = "message_replace" AGENT_THOUGHT = "agent_thought" @@ -130,6 +133,22 @@ class MessageStreamResponse(StreamResponse): answer: str +class MessageAudioStreamResponse(StreamResponse): + """ + MessageStreamResponse entity + """ + event: StreamEvent = StreamEvent.TTS_MESSAGE + audio: str + + +class MessageAudioEndStreamResponse(StreamResponse): + """ + MessageStreamResponse entity + """ + event: StreamEvent = StreamEvent.TTS_MESSAGE_END + audio: str + + class MessageEndStreamResponse(StreamResponse): """ MessageEndStreamResponse entity @@ -186,6 +205,7 @@ class WorkflowStartStreamResponse(StreamResponse): """ WorkflowStartStreamResponse entity """ + class Data(BaseModel): """ Data entity @@ -205,6 +225,7 @@ class WorkflowFinishStreamResponse(StreamResponse): """ WorkflowFinishStreamResponse entity """ + class Data(BaseModel): """ Data entity @@ -232,6 +253,7 @@ class NodeStartStreamResponse(StreamResponse): """ NodeStartStreamResponse entity """ + class Data(BaseModel): """ Data entity @@ -273,6 +295,7 @@ class NodeFinishStreamResponse(StreamResponse): """ NodeFinishStreamResponse entity """ + class Data(BaseModel): """ Data entity @@ -323,10 +346,12 @@ class NodeFinishStreamResponse(StreamResponse): } } + class IterationNodeStartStreamResponse(StreamResponse): """ NodeStartStreamResponse entity """ + class Data(BaseModel): """ Data entity @@ -344,10 +369,12 @@ class IterationNodeStartStreamResponse(StreamResponse): workflow_run_id: str data: Data + class IterationNodeNextStreamResponse(StreamResponse): """ NodeStartStreamResponse entity """ + class Data(BaseModel): """ Data entity @@ -365,10 +392,12 @@ class IterationNodeNextStreamResponse(StreamResponse): workflow_run_id: str data: Data + class IterationNodeCompletedStreamResponse(StreamResponse): """ NodeCompletedStreamResponse entity """ + class Data(BaseModel): """ Data entity @@ -393,10 +422,12 @@ class IterationNodeCompletedStreamResponse(StreamResponse): workflow_run_id: str data: Data + class TextChunkStreamResponse(StreamResponse): """ TextChunkStreamResponse entity """ + class Data(BaseModel): """ Data entity @@ -411,6 +442,7 @@ class TextReplaceStreamResponse(StreamResponse): """ TextReplaceStreamResponse entity """ + class Data(BaseModel): """ Data entity @@ -473,6 +505,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse): """ ChatbotAppBlockingResponse entity """ + class Data(BaseModel): """ Data entity @@ -492,6 +525,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse): """ CompletionAppBlockingResponse entity """ + class Data(BaseModel): """ Data entity @@ 
-510,6 +544,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): """ WorkflowAppBlockingResponse entity """ + class Data(BaseModel): """ Data entity @@ -528,10 +563,12 @@ class WorkflowAppBlockingResponse(AppBlockingResponse): workflow_run_id: str data: Data + class WorkflowIterationState(BaseModel): """ WorkflowIterationState entity """ + class Data(BaseModel): """ Data entity diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 7d16d015bf..c9644c7d4c 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -4,6 +4,8 @@ import time from collections.abc import Generator from typing import Optional, Union, cast +from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME +from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import ( AgentChatAppGenerateEntity, @@ -32,6 +34,8 @@ from core.app.entities.task_entities import ( CompletionAppStreamResponse, EasyUITaskState, ErrorStreamResponse, + MessageAudioEndStreamResponse, + MessageAudioStreamResponse, MessageEndStreamResponse, StreamResponse, ) @@ -87,6 +91,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan """ super().__init__(application_generate_entity, queue_manager, user, stream) self._model_config = application_generate_entity.model_conf + self._app_config = application_generate_entity.app_config self._conversation = conversation self._message = message @@ -102,7 +107,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._conversation_name_generate_thread = None def process( - self, + self, ) -> Union[ ChatbotAppBlockingResponse, CompletionAppBlockingResponse, @@ -123,7 +128,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._application_generate_entity.query ) - generator = self._process_stream_response( + generator = self._wrapper_process_stream_response( trace_manager=self._application_generate_entity.trace_manager ) if self._stream: @@ -202,14 +207,64 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan stream_response=stream_response ) + def _listenAudioMsg(self, publisher, task_id: str): + if publisher is None: + return None + audio_msg: AudioTrunk = publisher.checkAndGetAudio() + if audio_msg and audio_msg.status != "finish": + # audio_str = audio_msg.audio.decode('utf-8', errors='ignore') + return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) + return None + + def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \ + Generator[StreamResponse, None, None]: + + tenant_id = self._application_generate_entity.app_config.tenant_id + task_id = self._application_generate_entity.task_id + publisher = None + text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech') + if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'): + publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None)) + for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): + while True: + audio_response = 
self._listenAudioMsg(publisher, task_id) + if audio_response: + yield audio_response + else: + break + yield response + + start_listener_time = time.time() + # timeout + while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: + if publisher is None: + break + audio = publisher.checkAndGetAudio() + if audio is None: + # release cpu + # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) + time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME) + continue + if audio.status == "finish": + break + else: + start_listener_time = time.time() + yield MessageAudioStreamResponse(audio=audio.audio, + task_id=task_id) + yield MessageAudioEndStreamResponse(audio='', task_id=task_id) + def _process_stream_response( - self, trace_manager: Optional[TraceQueueManager] = None + self, + publisher: AppGeneratorTTSPublisher, + trace_manager: Optional[TraceQueueManager] = None ) -> Generator[StreamResponse, None, None]: """ Process stream response. :return: """ for message in self._queue_manager.listen(): + if publisher: + publisher.publish(message) event = message.event if isinstance(event, QueueErrorEvent): @@ -272,12 +327,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan yield self._ping_stream_response() else: continue - + if publisher: + publisher.publish(None) if self._conversation_name_generate_thread: self._conversation_name_generate_thread.join() def _save_message( - self, trace_manager: Optional[TraceQueueManager] = None + self, trace_manager: Optional[TraceQueueManager] = None ) -> None: """ Save message. diff --git a/api/core/model_manager.py b/api/core/model_manager.py index e0b6960c23..d64db890f9 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -264,7 +264,7 @@ class ModelInstance: user=user ) - def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \ + def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) \ -> str: """ Invoke large language tts model @@ -287,8 +287,7 @@ class ModelInstance: content_text=content_text, user=user, tenant_id=tenant_id, - voice=voice, - streaming=streaming + voice=voice ) def _round_robin_invoke(self, function: Callable, *args, **kwargs): diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index bc6a475f36..086a189246 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -1,4 +1,6 @@ import hashlib +import logging +import re import subprocess import uuid from abc import abstractmethod @@ -10,7 +12,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelTy from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.model_providers.__base.ai_model import AIModel - +logger = logging.getLogger(__name__) class TTSModel(AIModel): """ Model class for ttstext model. 
@@ -20,7 +22,7 @@ class TTSModel(AIModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool, + def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None): """ Invoke large language model @@ -35,14 +37,15 @@ class TTSModel(AIModel): :return: translated audio file """ try: + logger.info(f"Invoke TTS model: {model} , invoke content : {content_text}") self._is_ffmpeg_installed() - return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming, + return self._invoke(model=model, credentials=credentials, user=user, content_text=content_text, voice=voice, tenant_id=tenant_id) except Exception as e: raise self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool, + def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None): """ Invoke large language model @@ -123,26 +126,26 @@ class TTSModel(AIModel): return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS] @staticmethod - def _split_text_into_sentences(text: str, limit: int, delimiters=None): - if delimiters is None: - delimiters = set('。!?;\n') - - buf = [] - word_count = 0 - for char in text: - buf.append(char) - if char in delimiters: - if word_count >= limit: - yield ''.join(buf) - buf = [] - word_count = 0 - else: - word_count += 1 - else: - word_count += 1 - - if buf: - yield ''.join(buf) + def _split_text_into_sentences(org_text, max_length=2000, pattern=r'[。.!?]'): + match = re.compile(pattern) + tx = match.finditer(org_text) + start = 0 + result = [] + one_sentence = '' + for i in tx: + end = i.regs[0][1] + tmp = org_text[start:end] + if len(one_sentence + tmp) > max_length: + result.append(one_sentence) + one_sentence = '' + one_sentence += tmp + start = end + last_sens = org_text[start:] + if last_sens: + one_sentence += last_sens + if one_sentence != '': + result.append(one_sentence) + return result @staticmethod def _is_ffmpeg_installed(): diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py index dcd154cff0..50c125b873 100644 --- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -4,7 +4,7 @@ from functools import reduce from io import BytesIO from typing import Optional -from flask import Response, stream_with_context +from flask import Response from openai import AzureOpenAI from pydub import AudioSegment @@ -14,7 +14,6 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_MODELS, AzureBaseModel -from extensions.ext_storage import storage class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): @@ -23,7 +22,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): """ def _invoke(self, model: str, tenant_id: str, credentials: dict, - content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any: + content_text: str, voice: str, user: Optional[str] = None) -> any: 
""" _invoke text2speech model @@ -32,30 +31,23 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): :param credentials: model credentials :param content_text: text content to be translated :param voice: model timbre - :param streaming: output is streaming :param user: unique user id :return: text translated to audio file """ - audio_type = self._get_model_audio_type(model, credentials) if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]: voice = self._get_model_default_voice(model, credentials) - if streaming: - return Response(stream_with_context(self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - tenant_id=tenant_id, - voice=voice)), - status=200, mimetype=f'audio/{audio_type}') - else: - return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice) - def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: + return self._tts_invoke_streaming(model=model, + credentials=credentials, + content_text=content_text, + voice=voice) + + def validate_credentials(self, model: str, credentials: dict) -> None: """ validate credentials text2speech model :param model: model name :param credentials: model credentials - :param user: unique user id :return: text translated to audio file """ try: @@ -82,7 +74,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): word_limit = self._get_model_word_limit(model, credentials) max_workers = self._get_model_workers_limit(model, credentials) try: - sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) + sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit)) audio_bytes_list = [] # Create a thread pool and map the function to the list of sentences @@ -107,34 +99,37 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): except Exception as ex: raise InvokeBadRequestError(str(ex)) - # Todo: To improve the streaming function - def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str, + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model - :param model: model name - :param tenant_id: user tenant id :param credentials: model credentials :param content_text: text content to be translated :param voice: model timbre :return: text translated to audio file """ - # transform credentials to kwargs for model instance - credentials_kwargs = self._to_credential_kwargs(credentials) - if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials): - voice = self._get_model_default_voice(model, credentials) - word_limit = self._get_model_word_limit(model, credentials) - audio_type = self._get_model_audio_type(model, credentials) - tts_file_id = self._get_file_name(content_text) - file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}' try: + # doc: https://platform.openai.com/docs/guides/text-to-speech + credentials_kwargs = self._to_credential_kwargs(credentials) client = AzureOpenAI(**credentials_kwargs) - sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) - for sentence in sentences: - response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip()) - # response.stream_to_file(file_path) - storage.save(file_path, response.read()) + # max 
font is 4096,there is 3500 limit for each request + max_length = 3500 + if len(content_text) > max_length: + sentences = self._split_text_into_sentences(content_text, max_length=max_length) + executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) + futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model, + response_format="mp3", + input=sentences[i], voice=voice) for i in range(len(sentences))] + for index, future in enumerate(futures): + yield from future.result().__enter__().iter_bytes(1024) + + else: + response = client.audio.speech.with_streaming_response.create(model=model, voice=voice, + response_format="mp3", + input=content_text.strip()) + + yield from response.__enter__().iter_bytes(1024) except Exception as ex: raise InvokeBadRequestError(str(ex)) @@ -162,7 +157,7 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): @staticmethod - def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: + def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel | None: for ai_model_entity in TTS_BASE_MODELS: if ai_model_entity.base_model_name == base_model_name: ai_model_entity_copy = copy.deepcopy(ai_model_entity) @@ -170,5 +165,4 @@ class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): ai_model_entity_copy.entity.label.en_US = model ai_model_entity_copy.entity.label.zh_Hans = model return ai_model_entity_copy - return None diff --git a/api/core/model_runtime/model_providers/openai/tts/tts-1-hd.yaml b/api/core/model_runtime/model_providers/openai/tts/tts-1-hd.yaml index 72f15134ea..449c131f9d 100644 --- a/api/core/model_runtime/model_providers/openai/tts/tts-1-hd.yaml +++ b/api/core/model_runtime/model_providers/openai/tts/tts-1-hd.yaml @@ -21,7 +21,7 @@ model_properties: - mode: 'shimmer' name: 'Shimmer' language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ] - word_limit: 120 + word_limit: 3500 audio_type: 'mp3' max_workers: 5 pricing: diff --git a/api/core/model_runtime/model_providers/openai/tts/tts-1.yaml b/api/core/model_runtime/model_providers/openai/tts/tts-1.yaml index 8d222fed64..83969fb2f7 100644 --- a/api/core/model_runtime/model_providers/openai/tts/tts-1.yaml +++ b/api/core/model_runtime/model_providers/openai/tts/tts-1.yaml @@ -21,7 +21,7 @@ model_properties: - mode: 'shimmer' name: 'Shimmer' language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] - word_limit: 120 + word_limit: 3500 audio_type: 'mp3' max_workers: 5 pricing: diff --git a/api/core/model_runtime/model_providers/openai/tts/tts.py b/api/core/model_runtime/model_providers/openai/tts/tts.py index f83c57078a..608ed897e0 100644 --- a/api/core/model_runtime/model_providers/openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/openai/tts/tts.py @@ -3,7 +3,7 @@ from functools import reduce from io import BytesIO from typing import Optional -from flask import Response, stream_with_context +from flask import Response from openai import OpenAI from pydub import AudioSegment @@ -11,7 +11,6 @@ from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.model_providers.openai._common import _CommonOpenAI -from extensions.ext_storage import storage class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): @@ -20,7 +19,7 @@ class 
OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): """ def _invoke(self, model: str, tenant_id: str, credentials: dict, - content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any: + content_text: str, voice: str, user: Optional[str] = None) -> any: """ _invoke text2speech model @@ -29,22 +28,17 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): :param credentials: model credentials :param content_text: text content to be translated :param voice: model timbre - :param streaming: output is streaming :param user: unique user id :return: text translated to audio file """ - audio_type = self._get_model_audio_type(model, credentials) + if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]: voice = self._get_model_default_voice(model, credentials) - if streaming: - return Response(stream_with_context(self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - tenant_id=tenant_id, - voice=voice)), - status=200, mimetype=f'audio/{audio_type}') - else: - return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice) + # if streaming: + return self._tts_invoke_streaming(model=model, + credentials=credentials, + content_text=content_text, + voice=voice) def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: """ @@ -79,7 +73,7 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): word_limit = self._get_model_word_limit(model, credentials) max_workers = self._get_model_workers_limit(model, credentials) try: - sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) + sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit)) audio_bytes_list = [] # Create a thread pool and map the function to the list of sentences @@ -104,34 +98,40 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): except Exception as ex: raise InvokeBadRequestError(str(ex)) - # Todo: To improve the streaming function - def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str, + + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model :param model: model name - :param tenant_id: user tenant id :param credentials: model credentials :param content_text: text content to be translated :param voice: model timbre :return: text translated to audio file """ - # transform credentials to kwargs for model instance - credentials_kwargs = self._to_credential_kwargs(credentials) - if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials): - voice = self._get_model_default_voice(model, credentials) - word_limit = self._get_model_word_limit(model, credentials) - audio_type = self._get_model_audio_type(model, credentials) - tts_file_id = self._get_file_name(content_text) - file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}' try: + # doc: https://platform.openai.com/docs/guides/text-to-speech + credentials_kwargs = self._to_credential_kwargs(credentials) client = OpenAI(**credentials_kwargs) - sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) - for sentence in sentences: - response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip()) - # response.stream_to_file(file_path) - storage.save(file_path, response.read()) + if 
not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials): + voice = self._get_model_default_voice(model, credentials) + word_limit = self._get_model_word_limit(model, credentials) + if len(content_text) > word_limit: + sentences = self._split_text_into_sentences(content_text, max_length=word_limit) + executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) + futures = [executor.submit(client.audio.speech.with_streaming_response.create, model=model, + response_format="mp3", + input=sentences[i], voice=voice) for i in range(len(sentences))] + for index, future in enumerate(futures): + yield from future.result().__enter__().iter_bytes(1024) + + else: + response = client.audio.speech.with_streaming_response.create(model=model, voice=voice, + response_format="mp3", + input=content_text.strip()) + + yield from response.__enter__().iter_bytes(1024) except Exception as ex: raise InvokeBadRequestError(str(ex)) diff --git a/api/core/model_runtime/model_providers/tongyi/tts/tts-1.yaml b/api/core/model_runtime/model_providers/tongyi/tts/tts-1.yaml index e533d5812d..4eaa0ff361 100644 --- a/api/core/model_runtime/model_providers/tongyi/tts/tts-1.yaml +++ b/api/core/model_runtime/model_providers/tongyi/tts/tts-1.yaml @@ -129,7 +129,7 @@ model_properties: - mode: "sambert-waan-v1" name: "Waan(泰语女声)" language: [ "th-TH" ] - word_limit: 120 + word_limit: 7000 audio_type: 'mp3' max_workers: 5 pricing: diff --git a/api/core/model_runtime/model_providers/tongyi/tts/tts.py b/api/core/model_runtime/model_providers/tongyi/tts/tts.py index 7ef053479b..655ed2d1d0 100644 --- a/api/core/model_runtime/model_providers/tongyi/tts/tts.py +++ b/api/core/model_runtime/model_providers/tongyi/tts/tts.py @@ -1,17 +1,21 @@ import concurrent.futures +import threading from functools import reduce from io import BytesIO +from queue import Queue from typing import Optional import dashscope -from flask import Response, stream_with_context +from dashscope import SpeechSynthesizer +from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse +from dashscope.audio.tts import ResultCallback, SpeechSynthesisResult +from flask import Response from pydub import AudioSegment from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.model_providers.tongyi._common import _CommonTongyi -from extensions.ext_storage import storage class TongyiText2SpeechModel(_CommonTongyi, TTSModel): @@ -19,7 +23,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): Model class for Tongyi Speech to text model. 
""" - def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool, + def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None) -> any: """ _invoke text2speech model @@ -29,22 +33,17 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): :param credentials: model credentials :param voice: model timbre :param content_text: text content to be translated - :param streaming: output is streaming :param user: unique user id :return: text translated to audio file """ - audio_type = self._get_model_audio_type(model, credentials) - if not voice or voice not in [d['value'] for d in self.get_tts_model_voices(model=model, credentials=credentials)]: + if not voice or voice not in [d['value'] for d in + self.get_tts_model_voices(model=model, credentials=credentials)]: voice = self._get_model_default_voice(model, credentials) - if streaming: - return Response(stream_with_context(self._tts_invoke_streaming(model=model, - credentials=credentials, - content_text=content_text, - voice=voice, - tenant_id=tenant_id)), - status=200, mimetype=f'audio/{audio_type}') - else: - return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice) + + return self._tts_invoke_streaming(model=model, + credentials=credentials, + content_text=content_text, + voice=voice) def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: """ @@ -79,7 +78,7 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): word_limit = self._get_model_word_limit(model, credentials) max_workers = self._get_model_workers_limit(model, credentials) try: - sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) + sentences = list(self._split_text_into_sentences(org_text=content_text, max_length=word_limit)) audio_bytes_list = [] # Create a thread pool and map the function to the list of sentences @@ -105,14 +104,12 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): except Exception as ex: raise InvokeBadRequestError(str(ex)) - # Todo: To improve the streaming function - def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str, + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> any: """ _tts_invoke_streaming text2speech model :param model: model name - :param tenant_id: user tenant id :param credentials: model credentials :param voice: model timbre :param content_text: text content to be translated @@ -120,18 +117,32 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): """ word_limit = self._get_model_word_limit(model, credentials) audio_type = self._get_model_audio_type(model, credentials) - tts_file_id = self._get_file_name(content_text) - file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}' try: - sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) - for sentence in sentences: - response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice, sample_rate=48000, - api_key=credentials.get('dashscope_api_key'), - text=sentence.strip(), - format=audio_type, word_timestamp_enabled=True, - phoneme_timestamp_enabled=True) - if isinstance(response.get_audio_data(), bytes): - storage.save(file_path, response.get_audio_data()) + audio_queue: Queue = Queue() + callback = Callback(queue=audio_queue) + + def invoke_remote(content, v, api_key, cb, at, wl): + if len(content) < 
word_limit: + sentences = [content] + else: + sentences = list(self._split_text_into_sentences(org_text=content, max_length=wl)) + for sentence in sentences: + SpeechSynthesizer.call(model=v, sample_rate=16000, + api_key=api_key, + text=sentence.strip(), + callback=cb, + format=at, word_timestamp_enabled=True, + phoneme_timestamp_enabled=True) + + threading.Thread(target=invoke_remote, args=( + content_text, voice, credentials.get('dashscope_api_key'), callback, audio_type, word_limit)).start() + + while True: + audio = audio_queue.get() + if audio is None: + break + yield audio + except Exception as ex: raise InvokeBadRequestError(str(ex)) @@ -152,3 +163,29 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): format=audio_type) if isinstance(response.get_audio_data(), bytes): return response.get_audio_data() + + +class Callback(ResultCallback): + + def __init__(self, queue: Queue): + self._queue = queue + + def on_open(self): + pass + + def on_complete(self): + self._queue.put(None) + self._queue.task_done() + + def on_error(self, response: SpeechSynthesisResponse): + self._queue.put(None) + self._queue.task_done() + + def on_close(self): + self._queue.put(None) + self._queue.task_done() + + def on_event(self, result: SpeechSynthesisResult): + ad = result.get_audio_frame() + if ad: + self._queue.put(ad) diff --git a/api/pyproject.toml b/api/pyproject.toml index 90e825ea6c..a5d226f2ce 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -49,7 +49,7 @@ ignore = [ "B006", # mutable-argument-default "B007", # unused-loop-control-variable "B026", # star-arg-unpacking-after-keyword-arg - "B901", # return-in-generator +# "B901", # return-in-generator "B904", # raise-without-from-inside-except "B905", # zip-without-explicit-strict ] diff --git a/api/services/app_service.py b/api/services/app_service.py index 7f5b356772..11af5ef4fb 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -123,6 +123,8 @@ class AppService: app.icon = args['icon'] app.icon_background = args['icon_background'] app.tenant_id = tenant_id + app.api_rph = args.get('api_rph', 0) + app.api_rpm = args.get('api_rpm', 0) db.session.add(app) db.session.flush() diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 965df918d8..58c950816f 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -1,11 +1,12 @@ import io +import logging from typing import Optional from werkzeug.datastructures import FileStorage from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from models.model import App, AppMode, AppModelConfig +from models.model import App, AppMode, AppModelConfig, Message from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, @@ -18,6 +19,8 @@ FILE_SIZE = 30 FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024 ALLOWED_EXTENSIONS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm', 'amr'] +logger = logging.getLogger(__name__) + class AudioService: @classmethod @@ -64,51 +67,74 @@ class AudioService: return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)} @classmethod - def transcript_tts(cls, app_model: App, text: str, streaming: bool, - voice: Optional[str] = None, end_user: Optional[str] = None): - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: - workflow = app_model.workflow - if workflow is None: - raise ValueError("TTS is not enabled") + def transcript_tts(cls, app_model: App, text: Optional[str] = 
None, + voice: Optional[str] = None, end_user: Optional[str] = None, message_id: Optional[str] = None): + from collections.abc import Generator - features_dict = workflow.features_dict - if 'text_to_speech' not in features_dict or not features_dict['text_to_speech'].get('enabled'): - raise ValueError("TTS is not enabled") + from flask import Response, stream_with_context - voice = features_dict['text_to_speech'].get('voice') if voice is None else voice - else: - text_to_speech_dict = app_model.app_model_config.text_to_speech_dict + from app import app + from extensions.ext_database import db - if not text_to_speech_dict.get('enabled'): - raise ValueError("TTS is not enabled") + def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): + with app.app_context(): + if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + workflow = app_model.workflow + if workflow is None: + raise ValueError("TTS is not enabled") - voice = text_to_speech_dict.get('voice') if voice is None else voice + features_dict = workflow.features_dict + if 'text_to_speech' not in features_dict or not features_dict['text_to_speech'].get('enabled'): + raise ValueError("TTS is not enabled") - model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance( - tenant_id=app_model.tenant_id, - model_type=ModelType.TTS - ) - if model_instance is None: - raise ProviderNotSupportTextToSpeechServiceError() - - try: - if not voice: - voices = model_instance.get_tts_voices() - if voices: - voice = voices[0].get('value') + voice = features_dict['text_to_speech'].get('voice') if voice is None else voice else: - raise ValueError("Sorry, no voice available.") + text_to_speech_dict = app_model.app_model_config.text_to_speech_dict - return model_instance.invoke_tts( - content_text=text.strip(), - user=end_user, - streaming=streaming, - tenant_id=app_model.tenant_id, - voice=voice - ) - except Exception as e: - raise e + if not text_to_speech_dict.get('enabled'): + raise ValueError("TTS is not enabled") + + voice = text_to_speech_dict.get('voice') if voice is None else voice + + model_manager = ModelManager() + model_instance = model_manager.get_default_model_instance( + tenant_id=app_model.tenant_id, + model_type=ModelType.TTS + ) + try: + if not voice: + voices = model_instance.get_tts_voices() + if voices: + voice = voices[0].get('value') + else: + raise ValueError("Sorry, no voice available.") + + return model_instance.invoke_tts( + content_text=text_content.strip(), + user=end_user, + tenant_id=app_model.tenant_id, + voice=voice + ) + except Exception as e: + raise e + + if message_id: + message = db.session.query(Message).filter( + Message.id == message_id + ).first() + if message.answer == '' and message.status == 'normal': + return None + + else: + response = invoke_tts(message.answer, app_model=app_model, voice=voice) + if isinstance(response, Generator): + return Response(stream_with_context(response), content_type='audio/mpeg') + return response + else: + response = invoke_tts(text, app_model, voice) + if isinstance(response, Generator): + return Response(stream_with_context(response), content_type='audio/mpeg') + return response @classmethod def transcript_tts_voices(cls, tenant_id: str, language: str): diff --git a/web/app/components/app/configuration/config-voice/param-config-content.tsx b/web/app/components/app/configuration/config-voice/param-config-content.tsx index 10302d6221..d96d073262 100644 --- 
a/web/app/components/app/configuration/config-voice/param-config-content.tsx +++ b/web/app/components/app/configuration/config-voice/param-config-content.tsx @@ -11,11 +11,13 @@ import { usePathname } from 'next/navigation' import { useTranslation } from 'react-i18next' import { Listbox, Transition } from '@headlessui/react' import { CheckIcon, ChevronDownIcon } from '@heroicons/react/20/solid' +import RadioGroup from '@/app/components/app/configuration/config-vision/radio-group' import type { Item } from '@/app/components/base/select' import ConfigContext from '@/context/debug-configuration' import { fetchAppVoices } from '@/service/apps' import Tooltip from '@/app/components/base/tooltip' import { languages } from '@/i18n/language' +import { TtsAutoPlay } from '@/types/app' const VoiceParamConfig: FC = () => { const { t } = useTranslation() const pathname = usePathname() @@ -27,12 +29,16 @@ const VoiceParamConfig: FC = () => { setTextToSpeechConfig, } = useContext(ConfigContext) - const languageItem = languages.find(item => item.value === textToSpeechConfig.language) + let languageItem = languages.find(item => item.value === textToSpeechConfig.language) const localLanguagePlaceholder = languageItem?.name || t('common.placeholder.select') - + if (languages && !languageItem) + languageItem = languages[0] const language = languageItem?.value const voiceItems = useSWR({ appId, language }, fetchAppVoices).data - const voiceItem = voiceItems?.find(item => item.value === textToSpeechConfig.voice) + let voiceItem = voiceItems?.find(item => item.value === textToSpeechConfig.voice) + if (voiceItems && !voiceItem) + voiceItem = voiceItems[0] + const localVoicePlaceholder = voiceItem?.name || t('common.placeholder.select') return ( @@ -42,8 +48,9 @@ const VoiceParamConfig: FC = () => {
-
{t('appDebug.voice.voiceSettings.language')}
- +
{t('appDebug.voice.voiceSettings.language')}
+ {t('appDebug.voice.voiceSettings.resolutionTooltip').split('\n').map(item => (
{item}
))} @@ -61,7 +68,8 @@ const VoiceParamConfig: FC = () => { }} >
- + {languageItem?.name ? t(`common.voice.language.${languageItem?.value.replace('-', '')}`) : localLanguagePlaceholder} @@ -79,7 +87,8 @@ const VoiceParamConfig: FC = () => { leaveTo="opacity-0" > - + {languages.map((item: Item) => ( { 'absolute inset-y-0 right-0 flex items-center pr-4 text-gray-700', )} > -
-
-
{t('appDebug.voice.voiceSettings.voice')}
+
{t('appDebug.voice.voiceSettings.voice')}
{ }} >
- - {voiceItem?.name ?? localVoicePlaceholder} + + {voiceItem?.name ?? localVoicePlaceholder} { leaveTo="opacity-0" > - + {voiceItems?.map((item: Item) => ( { 'absolute inset-y-0 right-0 flex items-center pr-4 text-gray-700', )} > - )} @@ -174,6 +186,30 @@ const VoiceParamConfig: FC = () => {
+
+
{t('appDebug.voice.voiceSettings.autoPlay')}
+ { + setTextToSpeechConfig({ + ...textToSpeechConfig, + autoPlay: value, + }) + }} + /> +
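The voice-settings panel above now falls back to the first available option whenever the saved language or voice is missing from the fetched list, and adds an autoPlay radio group backed by the new TtsAutoPlay type. A minimal TypeScript sketch of that fallback rule follows; the type and function names are illustrative only, since the real component keeps this logic inline around useSWR and ConfigContext:

```ts
// Illustrative sketch of the selection fallback applied to both language and voice.
// `Item` stands in for the select-item type the component imports.
type Item = { value: string; name: string }

// Return the saved selection while it still exists in the option list,
// otherwise fall back to the first option (or undefined while the list loads).
function resolveSelection(items: Item[] | undefined, saved?: string): Item | undefined {
  if (!items || items.length === 0)
    return undefined
  return items.find(item => item.value === saved) ?? items[0]
}

// Example: a voice saved under a different language no longer matches,
// so the first voice of the current language is shown instead.
const voices: Item[] = [
  { value: 'voice-a', name: 'Voice A' },
  { value: 'voice-b', name: 'Voice B' },
]
resolveSelection(voices, 'voice-from-another-language') // -> { value: 'voice-a', ... }
```

Because voices are fetched per language, this keeps the voice select from showing a stale value after the language changes.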
diff --git a/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx b/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx index 4bb66cb635..4c5db22513 100644 --- a/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx +++ b/web/app/components/app/configuration/features/chat-group/text-to-speech/index.tsx @@ -40,7 +40,6 @@ const TextToSpeech: FC = () => { { languageInfo?.example && ( diff --git a/web/app/components/app/text-generate/item/index.tsx b/web/app/components/app/text-generate/item/index.tsx index 79880630e3..f803b06656 100644 --- a/web/app/components/app/text-generate/item/index.tsx +++ b/web/app/components/app/text-generate/item/index.tsx @@ -428,8 +428,7 @@ const GenerationItem: FC = ({ <>
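The two new files below move playback out of the button component and into a single shared streaming player. AudioPlayerManager keeps at most one AudioPlayer alive: it reuses the player while the message id is unchanged, and pauses and discards it otherwise. A hypothetical, simplified restatement of that reuse rule (the names here are illustrative; the real class follows):

```ts
// Simplified sketch of the singleton-per-message-id rule; not the project's class.
interface StoppablePlayer {
  pauseAudio: () => void
}

class SharedPlayerRegistry {
  private static registry?: SharedPlayerRegistry
  private player?: StoppablePlayer
  private msgId?: string

  static getInstance(): SharedPlayerRegistry {
    if (!SharedPlayerRegistry.registry)
      SharedPlayerRegistry.registry = new SharedPlayerRegistry()
    return SharedPlayerRegistry.registry
  }

  // Reuse the current player while the message id matches; otherwise stop the
  // old one and build a fresh player, so two answers never play over each other.
  getPlayer(msgId: string | undefined, create: () => StoppablePlayer): StoppablePlayer {
    if (this.msgId && this.msgId === msgId && this.player)
      return this.player
    this.player?.pauseAudio()
    this.msgId = msgId
    this.player = create()
    return this.player
  }
}
```

The practical effect is that starting playback on a second answer stops the first instead of layering audio on top of it.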
diff --git a/web/app/components/base/audio-btn/audio.player.manager.ts b/web/app/components/base/audio-btn/audio.player.manager.ts new file mode 100644 index 0000000000..03e9e21f93 --- /dev/null +++ b/web/app/components/base/audio-btn/audio.player.manager.ts @@ -0,0 +1,53 @@ +import AudioPlayer from '@/app/components/base/audio-btn/audio' +declare global { + // eslint-disable-next-line @typescript-eslint/consistent-type-definitions + interface AudioPlayerManager { + instance: AudioPlayerManager + } + +} + +export class AudioPlayerManager { + private static instance: AudioPlayerManager + private audioPlayers: AudioPlayer | null = null + private msgId: string | undefined + + private constructor() { + } + + public static getInstance(): AudioPlayerManager { + if (!AudioPlayerManager.instance) { + AudioPlayerManager.instance = new AudioPlayerManager() + this.instance = AudioPlayerManager.instance + } + + return AudioPlayerManager.instance + } + + public getAudioPlayer(url: string, isPublic: boolean, id: string | undefined, msgContent: string | null | undefined, voice: string | undefined, callback: ((event: string) => {}) | null): AudioPlayer { + if (this.msgId && this.msgId === id && this.audioPlayers) { + this.audioPlayers.setCallback(callback) + return this.audioPlayers + } + else { + if (this.audioPlayers) { + try { + this.audioPlayers.pauseAudio() + this.audioPlayers.cacheBuffers = [] + this.audioPlayers.sourceBuffer?.abort() + } + catch (e) { + } + } + + this.msgId = id + this.audioPlayers = new AudioPlayer(url, isPublic, id, msgContent, callback) + return this.audioPlayers + } + } + + public resetMsgId(msgId: string) { + this.msgId = msgId + this.audioPlayers?.resetMsgId(msgId) + } +} diff --git a/web/app/components/base/audio-btn/audio.ts b/web/app/components/base/audio-btn/audio.ts new file mode 100644 index 0000000000..239dfe0342 --- /dev/null +++ b/web/app/components/base/audio-btn/audio.ts @@ -0,0 +1,263 @@ +import Toast from '@/app/components/base/toast' +import { textToAudioStream } from '@/service/share' + +declare global { + // eslint-disable-next-line @typescript-eslint/consistent-type-definitions + interface Window { + ManagedMediaSource: any + } +} + +export default class AudioPlayer { + mediaSource: MediaSource | null + audio: HTMLAudioElement + audioContext: AudioContext + sourceBuffer?: SourceBuffer + cacheBuffers: ArrayBuffer[] = [] + pauseTimer: number | null = null + msgId: string | undefined + msgContent: string | null | undefined = null + voice: string | undefined = undefined + isLoadData = false + url: string + isPublic: boolean + callback: ((event: string) => {}) | null + + constructor(streamUrl: string, isPublic: boolean, msgId: string | undefined, msgContent: string | null | undefined, callback: ((event: string) => {}) | null) { + this.audioContext = new AudioContext() + this.msgId = msgId + this.msgContent = msgContent + this.url = streamUrl + this.isPublic = isPublic + this.callback = callback + + // Compatible with iphone ios17 ManagedMediaSource + const MediaSource = window.MediaSource || window.ManagedMediaSource + if (!MediaSource) { + Toast.notify({ + message: 'Your browser does not support audio streaming, if you are using an iPhone, please update to iOS 17.1 or later.', + type: 'error', + }) + } + this.mediaSource = MediaSource ? new MediaSource() : null + this.audio = new Audio() + this.setCallback(callback) + this.audio.src = this.mediaSource ? 
URL.createObjectURL(this.mediaSource) : '' + this.audio.autoplay = true + + const source = this.audioContext.createMediaElementSource(this.audio) + source.connect(this.audioContext.destination) + this.listenMediaSource('audio/mpeg') + } + + public resetMsgId(msgId: string) { + this.msgId = msgId + } + + private listenMediaSource(contentType: string) { + this.mediaSource?.addEventListener('sourceopen', () => { + if (this.sourceBuffer) + return + + this.sourceBuffer = this.mediaSource?.addSourceBuffer(contentType) + // this.sourceBuffer?.addEventListener('update', () => { + // if (this.cacheBuffers.length && !this.sourceBuffer?.updating) { + // const cacheBuffer = this.cacheBuffers.shift()! + // this.sourceBuffer?.appendBuffer(cacheBuffer) + // } + // // this.pauseAudio() + // }) + // + // this.sourceBuffer?.addEventListener('updateend', () => { + // if (this.cacheBuffers.length && !this.sourceBuffer?.updating) { + // const cacheBuffer = this.cacheBuffers.shift()! + // this.sourceBuffer?.appendBuffer(cacheBuffer) + // } + // // this.pauseAudio() + // }) + }) + } + + public setCallback(callback: ((event: string) => {}) | null) { + this.callback = callback + if (callback) { + this.audio.addEventListener('ended', () => { + callback('ended') + }, false) + this.audio.addEventListener('paused', () => { + callback('paused') + }, true) + this.audio.addEventListener('loaded', () => { + callback('loaded') + }, true) + this.audio.addEventListener('play', () => { + callback('play') + }, true) + this.audio.addEventListener('timeupdate', () => { + callback('timeupdate') + }, true) + this.audio.addEventListener('loadeddate', () => { + callback('loadeddate') + }, true) + this.audio.addEventListener('canplay', () => { + callback('canplay') + }, true) + this.audio.addEventListener('error', () => { + callback('error') + }, true) + } + } + + private async loadAudio() { + try { + const audioResponse: any = await textToAudioStream(this.url, this.isPublic, { content_type: 'audio/mpeg' }, { + message_id: this.msgId, + streaming: true, + voice: this.voice, + text: this.msgContent, + }) + + if (audioResponse.status !== 200) { + this.isLoadData = false + if (this.callback) + this.callback('error') + } + + const reader = audioResponse.body.getReader() + while (true) { + const { value, done } = await reader.read() + + if (done) { + this.receiveAudioData(value) + break + } + + this.receiveAudioData(value) + } + } + catch (error) { + this.isLoadData = false + this.callback && this.callback('error') + } + } + + // play audio + public playAudio() { + if (this.isLoadData) { + if (this.audioContext.state === 'suspended') { + this.audioContext.resume().then((_) => { + this.audio.play() + this.callback && this.callback('play') + }) + } + else if (this.audio.ended) { + this.audio.play() + this.callback && this.callback('play') + } + if (this.callback) + this.callback('play') + } + else { + this.isLoadData = true + this.loadAudio() + } + } + + private theEndOfStream() { + const endTimer = setInterval(() => { + if (!this.sourceBuffer?.updating) { + this.mediaSource?.endOfStream() + clearInterval(endTimer) + } + console.log('finishStream endOfStream endTimer') + }, 10) + } + + private finishStream() { + const timer = setInterval(() => { + if (!this.cacheBuffers.length) { + this.theEndOfStream() + clearInterval(timer) + } + + if (this.cacheBuffers.length && !this.sourceBuffer?.updating) { + const arrayBuffer = this.cacheBuffers.shift()! 
+ this.sourceBuffer?.appendBuffer(arrayBuffer) + } + console.log('finishStream timer') + }, 10) + } + + public async playAudioWithAudio(audio: string, play = true) { + if (!audio || !audio.length) { + this.finishStream() + return + } + + const audioContent = Buffer.from(audio, 'base64') + this.receiveAudioData(new Uint8Array(audioContent)) + if (play) { + this.isLoadData = true + if (this.audio.paused) { + this.audioContext.resume().then((_) => { + this.audio.play() + this.callback && this.callback('play') + }) + } + else if (this.audio.ended) { + this.audio.play() + this.callback && this.callback('play') + } + else if (this.audio.played) { /* empty */ } + + else { + this.audio.play() + this.callback && this.callback('play') + } + } + } + + public pauseAudio() { + this.callback && this.callback('paused') + this.audio.pause() + this.audioContext.suspend() + } + + private cancer() { + + } + + private receiveAudioData(unit8Array: Uint8Array) { + if (!unit8Array) { + this.finishStream() + return + } + const audioData = this.byteArrayToArrayBuffer(unit8Array) + if (!audioData.byteLength) { + if (this.mediaSource?.readyState === 'open') + this.finishStream() + return + } + + if (this.sourceBuffer?.updating) { + this.cacheBuffers.push(audioData) + } + else { + if (this.cacheBuffers.length && !this.sourceBuffer?.updating) { + this.cacheBuffers.push(audioData) + const cacheBuffer = this.cacheBuffers.shift()! + this.sourceBuffer?.appendBuffer(cacheBuffer) + } + else { + this.sourceBuffer?.appendBuffer(audioData) + } + } + } + + private byteArrayToArrayBuffer(byteArray: Uint8Array): ArrayBuffer { + const arrayBuffer = new ArrayBuffer(byteArray.length) + const uint8Array = new Uint8Array(arrayBuffer) + uint8Array.set(byteArray) + return arrayBuffer + } +} diff --git a/web/app/components/base/audio-btn/index.tsx b/web/app/components/base/audio-btn/index.tsx index 0dd8a35edd..48081c170c 100644 --- a/web/app/components/base/audio-btn/index.tsx +++ b/web/app/components/base/audio-btn/index.tsx @@ -1,124 +1,78 @@ 'use client' -import { useEffect, useRef, useState } from 'react' +import { useRef, useState } from 'react' import { t } from 'i18next' import { useParams, usePathname } from 'next/navigation' import s from './style.module.css' import Tooltip from '@/app/components/base/tooltip' import { randomString } from '@/utils' -import { textToAudio } from '@/service/share' import Loading from '@/app/components/base/loading' +import { AudioPlayerManager } from '@/app/components/base/audio-btn/audio.player.manager' type AudioBtnProps = { - value: string + id?: string voice?: string + value?: string className?: string isAudition?: boolean - noCache: boolean + noCache?: boolean } type AudioState = 'initial' | 'loading' | 'playing' | 'paused' | 'ended' const AudioBtn = ({ - value, + id, voice, + value, className, isAudition, - noCache, }: AudioBtnProps) => { - const audioRef = useRef(null) const [audioState, setAudioState] = useState('initial') const selector = useRef(`play-tooltip-${randomString(4)}`) const params = useParams() const pathname = usePathname() - const removeCodeBlocks = (inputText: any) => { - const codeBlockRegex = /```[\s\S]*?```/g - if (inputText) - return inputText.replace(codeBlockRegex, '') - return '' - } - - const loadAudio = async () => { - const formData = new FormData() - formData.append('text', removeCodeBlocks(value)) - formData.append('voice', removeCodeBlocks(voice)) - - if (value !== '') { - setAudioState('loading') - - let url = '' - let isPublic = false - - if (params.token) { - 
url = '/text-to-audio' - isPublic = true - } - else if (params.appId) { - if (pathname.search('explore/installed') > -1) - url = `/installed-apps/${params.appId}/text-to-audio` - else - url = `/apps/${params.appId}/text-to-audio` - } - - try { - const audioResponse = await textToAudio(url, isPublic, formData) - const blob_bytes = Buffer.from(audioResponse.data, 'latin1') - const blob = new Blob([blob_bytes], { type: 'audio/wav' }) - const audioUrl = URL.createObjectURL(blob) - audioRef.current!.src = audioUrl - } - catch (error) { - setAudioState('initial') - console.error('Error playing audio:', error) - } - } - } - - const handleToggle = async () => { - if (audioState === 'initial' || noCache) { - await loadAudio() - } - else if (audioRef.current) { - if (audioState === 'playing') { - audioRef.current.pause() - setAudioState('paused') - } - else { - audioRef.current.play() + const audio_finished_call = (event: string): any => { + switch (event) { + case 'ended': + setAudioState('ended') + break + case 'paused': + setAudioState('ended') + break + case 'loaded': + setAudioState('loading') + break + case 'play': setAudioState('playing') - } + break + case 'error': + setAudioState('ended') + break } } + let url = '' + let isPublic = false - useEffect(() => { - const currentAudio = audioRef.current - - const handleLoading = () => { + if (params.token) { + url = '/text-to-audio' + isPublic = true + } + else if (params.appId) { + if (pathname.search('explore/installed') > -1) + url = `/installed-apps/${params.appId}/text-to-audio` + else + url = `/apps/${params.appId}/text-to-audio` + } + const handleToggle = async () => { + if (audioState === 'playing' || audioState === 'loading') { + setAudioState('paused') + AudioPlayerManager.getInstance().getAudioPlayer(url, isPublic, id, value, voice, audio_finished_call).pauseAudio() + } + else { setAudioState('loading') + AudioPlayerManager.getInstance().getAudioPlayer(url, isPublic, id, value, voice, audio_finished_call).playAudio() } - - const handlePlay = () => { - currentAudio?.play() - setAudioState('playing') - } - - const handleEnded = () => { - setAudioState('ended') - } - - currentAudio?.addEventListener('progress', handleLoading) - currentAudio?.addEventListener('canplaythrough', handlePlay) - currentAudio?.addEventListener('ended', handleEnded) - - return () => { - currentAudio?.removeEventListener('progress', handleLoading) - currentAudio?.removeEventListener('canplaythrough', handlePlay) - currentAudio?.removeEventListener('ended', handleEnded) - URL.revokeObjectURL(currentAudio?.src || '') - currentAudio?.pause() - currentAudio?.setAttribute('src', '') - } - }, []) + } const tooltipContent = { initial: t('appApi.play'), @@ -151,7 +105,6 @@ const AudioBtn = ({ )} -
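The reworked AudioBtn above no longer builds a Blob from a base64 response: it resolves the endpoint from the route (a public '/text-to-audio' URL when a share token is present, otherwise the installed-app or console app URL) and delegates play, pause, and state changes to the shared player via callbacks. For readers unfamiliar with the Media Source Extensions flow that the new audio.ts relies on, here is a compact, self-contained sketch of the same technique. The raw fetch call, JSON payload shape, and 20 ms flush interval are assumptions for illustration; the real player goes through textToAudioStream from '@/service/share' and sends message_id, text, voice, and streaming: true:

```ts
// Sketch of the streaming pattern: read the TTS response body chunk by chunk and
// append each chunk to a MediaSource SourceBuffer, queueing while the buffer is busy.
function streamTtsToAudio(url: string, payload: Record<string, unknown>): HTMLAudioElement {
  const mediaSource = new MediaSource()
  const audio = new Audio(URL.createObjectURL(mediaSource))
  const pending: BufferSource[] = []

  mediaSource.addEventListener('sourceopen', async () => {
    const sourceBuffer = mediaSource.addSourceBuffer('audio/mpeg')

    // Drain the queue whenever the previous append finishes.
    sourceBuffer.addEventListener('updateend', () => {
      if (pending.length && !sourceBuffer.updating)
        sourceBuffer.appendBuffer(pending.shift()!)
    })

    const response = await fetch(url, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(payload),
    })
    const reader = response.body!.getReader()

    while (true) {
      const { value, done } = await reader.read()
      if (done || !value)
        break
      if (sourceBuffer.updating || pending.length)
        pending.push(value)
      else
        sourceBuffer.appendBuffer(value)
    }

    // Close the stream only after every queued chunk has been flushed.
    const finish = setInterval(() => {
      if (!pending.length && !sourceBuffer.updating) {
        mediaSource.endOfStream()
        clearInterval(finish)
      }
    }, 20)
  })

  // Autoplay may still require a prior user gesture, as with the real player.
  audio.autoplay = true
  return audio
}
```

The interval-based flush mirrors the finishStream timer in audio.ts; an event-driven alternative would be to call endOfStream from the updateend handler once the reader reports done and the queue is empty.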