From dff3f41ef6ea38b5482ca9f91a113d35ac569881 Mon Sep 17 00:00:00 2001 From: chenxu9741 Date: Sun, 4 Aug 2024 14:28:56 +0800 Subject: [PATCH] Workflow TTS playback node filtering issue. (#6877) --- .../app_generator_tts_publisher.py | 9 ++++++++- .../advanced_chat/generate_task_pipeline.py | 7 ++++++- api/core/workflow/nodes/base_node.py | 3 +++ api/core/workflow/workflow_engine_manager.py | 20 +++++++++++++++++++ 4 files changed, 37 insertions(+), 2 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py index 8325994608..0caff4a2e3 100644 --- a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py +++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py @@ -5,7 +5,12 @@ import queue import re import threading -from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueTextChunkEvent +from core.app.entities.queue_entities import ( + QueueAgentMessageEvent, + QueueLLMChunkEvent, + QueueNodeSucceededEvent, + QueueTextChunkEvent, +) from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType @@ -88,6 +93,8 @@ class AppGeneratorTTSPublisher: self.msg_text += message.event.chunk.delta.message.content elif isinstance(message.event, QueueTextChunkEvent): self.msg_text += message.event.text + elif isinstance(message.event, QueueNodeSucceededEvent): + self.msg_text += message.event.outputs.get('output', '') self.last_message = message sentence_arr, text_tmp = self._extract_sentence(self.msg_text) if len(sentence_arr) >= min(self.MAX_SENTENCE, 7): diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index be72d89c1e..a042d30e00 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -244,7 +244,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc :return: """ for message in self._queue_manager.listen(): - if publisher: + if hasattr(message.event, 'metadata') and message.event.metadata.get('is_answer_previous_node', False) and publisher: + publisher.publish(message=message) + elif (hasattr(message.event, 'execution_metadata') + and message.event.execution_metadata + and message.event.execution_metadata.get('is_answer_previous_node', False) + and publisher): publisher.publish(message=message) event = message.event diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index f42cee4ccd..d8c812e7ef 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -49,6 +49,8 @@ class BaseNode(ABC): callbacks: Sequence[WorkflowCallback] + is_answer_previous_node: bool = False + def __init__(self, tenant_id: str, app_id: str, workflow_id: str, @@ -110,6 +112,7 @@ class BaseNode(ABC): text=text, metadata={ "node_type": self.node_type, + "is_answer_previous_node": self.is_answer_previous_node, "value_selector": value_selector } ) diff --git a/api/core/workflow/workflow_engine_manager.py b/api/core/workflow/workflow_engine_manager.py index 32f0dbba06..bd2b3eafa7 100644 --- a/api/core/workflow/workflow_engine_manager.py +++ b/api/core/workflow/workflow_engine_manager.py @@ -177,6 +177,19 @@ class WorkflowEngineManager: graph = workflow.graph_dict try: + answer_prov_node_ids = [] + for node in graph.get('nodes', []): + if node.get('id', '') == 'answer': + try: + answer_prov_node_ids.append(node.get('data', {}) + .get('answer', '') + .replace('#', '') + .replace('.text', '') + .replace('{{', '') + .replace('}}', '').split('.')[0]) + except Exception as e: + logger.error(e) + predecessor_node: BaseNode | None = None current_iteration_node: BaseIterationNode | None = None has_entry_node = False @@ -301,6 +314,9 @@ class WorkflowEngineManager: else: next_node = self._get_node(workflow_run_state=workflow_run_state, graph=graph, node_id=next_node_id, callbacks=callbacks) + if next_node and next_node.node_id in answer_prov_node_ids: + next_node.is_answer_previous_node = True + # run workflow, run multiple target nodes in the future self._run_workflow_node( workflow_run_state=workflow_run_state, @@ -854,6 +870,10 @@ class WorkflowEngineManager: raise ValueError(f"Node {node.node_data.title} run failed: {node_run_result.error}") + if node.is_answer_previous_node and not isinstance(node, LLMNode): + if not node_run_result.metadata: + node_run_result.metadata = {} + node_run_result.metadata["is_answer_previous_node"]=True workflow_nodes_and_result.result = node_run_result # node run success