feat: Parallel Execution of Nodes in Workflows (#8192)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Yi <yxiaoisme@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
takatost 2024-09-10 15:23:16 +08:00 committed by GitHub
parent 5da0182800
commit dabfd74622
156 changed files with 11158 additions and 5605 deletions

View File

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.7.3",
default="0.8.0",
)
COMMIT_SHA: str = Field(

View File

@@ -4,12 +4,10 @@ import os
import threading
import uuid
from collections.abc import Generator
from typing import Literal, Union, overload
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
import contexts
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -20,20 +18,15 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable, Workflow
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@@ -60,13 +53,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
) -> dict: ...
def generate(
self, app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
):
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@@ -154,7 +148,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
node_id: str,
user: Account,
args: dict,
stream: bool = True):
stream: bool = True) \
-> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@@ -171,16 +166,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
if args.get('inputs') is None:
raise ValueError('inputs is required')
extras = {
"auto_generate_conversation_name": False
}
# get conversation
conversation = None
conversation_id = args.get('conversation_id')
if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
@@ -191,14 +176,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
conversation_id=conversation.id if conversation else None,
conversation_id=None,
inputs={},
query='',
files=[],
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras=extras,
extras={
"auto_generate_conversation_name": False
},
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id,
inputs=args['inputs']
@@ -211,17 +198,28 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
conversation=conversation,
conversation=None,
stream=stream
)
def _generate(self, *,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Conversation | None = None,
stream: bool = True):
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None,
stream: bool = True) \
-> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
:param workflow: Workflow
:param user: account or end user
:param invoke_from: invoke from source
:param application_generate_entity: application generate entity
:param conversation: conversation
:param stream: is stream
"""
is_first_conversation = False
if not conversation:
is_first_conversation = True
@@ -236,7 +234,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# update conversation features
conversation.override_model_configs = workflow.features
db.session.commit()
# db.session.refresh(conversation)
db.session.refresh(conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@@ -248,67 +246,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id=message.id
)
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
session.commit()
# Increment dialogue count.
conversation.dialogue_count += 1
conversation_id = conversation.id
conversation_dialogue_count = conversation.dialogue_count
db.session.commit()
db.session.refresh(conversation)
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
# Create a variable pool.
system_inputs = {
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: conversation_id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
contexts.workflow_variable_pool.set(variable_pool)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'flask_app': current_app._get_current_object(), # type: ignore
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
'context': contextvars.copy_context(),
})
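Note: the parent thread snapshots its context with contextvars.copy_context() and the worker replays it on startup (the for var, val in context.items(): var.set(val) loop in _generate_worker just below). A minimal, runnable sketch of that hand-off, using a hypothetical request_id variable:

    import contextvars
    import threading

    request_id: contextvars.ContextVar[str] = contextvars.ContextVar("request_id")

    def worker(context: contextvars.Context) -> None:
        # Replay the parent's context in this thread, as _generate_worker does;
        # a new thread starts with an empty context, so this restores visibility.
        for var, val in context.items():
            var.set(val)
        print(request_id.get())  # prints "abc123"

    request_id.set("abc123")
    t = threading.Thread(target=worker, kwargs={"context": contextvars.copy_context()})
    t.start()
    t.join()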
@@ -334,6 +277,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
def _generate_worker(self, flask_app: Flask,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
context: contextvars.Context) -> None:
"""
@@ -349,28 +293,19 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
var.set(val)
with flask_app.app_context():
try:
runner = AdvancedChatAppRunner()
if application_generate_entity.single_iteration_run:
single_iteration_run = application_generate_entity.single_iteration_run
runner.single_iteration_run(
app_id=application_generate_entity.app_config.app_id,
workflow_id=application_generate_entity.app_config.workflow_id,
queue_manager=queue_manager,
inputs=single_iteration_run.inputs,
node_id=single_iteration_run.node_id,
user_id=application_generate_entity.user_id
)
else:
# get message
message = self._get_message(message_id)
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
# chatbot app
runner = AdvancedChatAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message=message
)
# chatbot app
runner = AdvancedChatAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
)
runner.run()
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
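Note: the worker now builds the runner with all of its collaborators up front and calls a no-argument run(); the single-iteration branch moves inside the runner itself (see app_runner.py in the next file). A small sketch of the shape of that refactor, with stand-in classes rather than the real ones:

    class StubQueueManager:
        def publish(self, event: str) -> None:
            print("published:", event)

    class StubRunner:
        # Collaborators are injected once at construction, mirroring the new
        # AdvancedChatAppRunner.__init__ in this commit.
        def __init__(self, queue_manager: StubQueueManager, single_iteration_run: bool) -> None:
            self.queue_manager = queue_manager
            self.single_iteration_run = single_iteration_run

        def run(self) -> None:
            if self.single_iteration_run:
                self.queue_manager.publish("single iteration run")
            else:
                self.queue_manager.publish("full graph run")

    StubRunner(StubQueueManager(), single_iteration_run=False).run()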

View File

@@ -1,49 +1,67 @@
import logging
import os
import time
from collections.abc import Mapping
from typing import Any, Optional, cast
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,
QueueStopEvent,
QueueTextChunkEvent,
)
from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models import App, Message, Workflow
from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable, WorkflowType
logger = logging.getLogger(__name__)
class AdvancedChatAppRunner(AppRunner):
class AdvancedChatAppRunner(WorkflowBasedAppRunner):
"""
AdvancedChat Application Runner
"""
def run(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
message: Message,
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message
) -> None:
"""
Run application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param conversation: conversation
:param message: message
"""
super().__init__(queue_manager)
self.application_generate_entity = application_generate_entity
self.conversation = conversation
self.message = message
def run(self) -> None:
"""
Run application
:return:
"""
app_config = application_generate_entity.app_config
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
@@ -54,101 +72,133 @@ class AdvancedChatAppRunner(AppRunner):
if not workflow:
raise ValueError('Workflow not initialized')
inputs = application_generate_entity.inputs
query = application_generate_entity.query
user_id = None
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
# moderation
if self.handle_input_moderation(
queue_manager=queue_manager,
app_record=app_record,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id,
):
return
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=message,
query=query,
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
):
return
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs
)
else:
inputs = self.application_generate_entity.inputs
query = self.application_generate_entity.query
files = self.application_generate_entity.files
# moderation
if self.handle_input_moderation(
app_record=app_record,
app_generate_entity=self.application_generate_entity,
inputs=inputs,
query=query,
message_id=self.message.id
):
return
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=self.message,
query=query,
app_generate_entity=self.application_generate_entity
):
return
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
session.commit()
# Increment dialogue count.
self.conversation.dialogue_count += 1
conversation_dialogue_count = self.conversation.dialogue_count
db.session.commit()
# Create a variable pool.
system_inputs = {
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: self.conversation.id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
]
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
workflow=workflow,
user_id=application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
)
generator = workflow_entry.run(
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth,
)
def single_iteration_run(
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
) -> None:
"""
Single iteration run
"""
app_record = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError('App not found')
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
for event in generator:
self._handle_event(workflow_entry, event)
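Note: run() finishes by draining the generator returned from WorkflowEntry.run() and forwarding each event to _handle_event. The sketch below shows one plausible way parallel branches can be merged into such a single event stream; the threading and queue mechanics here are an illustrative assumption, not the engine's actual implementation:

    import queue
    import threading
    from collections.abc import Generator

    def run_branch(branch_id: str, out: queue.Queue) -> None:
        # Each branch reports its lifecycle as events, echoing the new
        # QueueParallelBranchRun* entities added in this commit.
        out.put(f"parallel_branch_run_started:{branch_id}")
        out.put(f"parallel_branch_run_succeeded:{branch_id}")

    def run() -> Generator[str, None, None]:
        out: queue.Queue = queue.Queue()
        branches = ["branch-a", "branch-b"]
        for branch_id in branches:
            threading.Thread(target=run_branch, args=(branch_id, out)).start()
        finished = 0
        while finished < len(branches):
            event = out.get()
            if event.startswith("parallel_branch_run_succeeded"):
                finished += 1
            yield event

    for event in run():
        print("handled:", event)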
def handle_input_moderation(
self,
queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
self,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str
) -> bool:
"""
Handle input moderation
:param queue_manager: application queue manager
:param app_record: app record
:param app_generate_entity: application generate entity
:param inputs: inputs
@@ -167,30 +217,23 @@ class AdvancedChatAppRunner(AppRunner):
message_id=message_id,
)
except ModerationException as e:
self._stream_output(
queue_manager=queue_manager,
self._complete_with_stream_output(
text=str(e),
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION,
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
)
return True
return False
def handle_annotation_reply(
self,
app_record: App,
message: Message,
query: str,
queue_manager: AppQueueManager,
app_generate_entity: AdvancedChatAppGenerateEntity,
) -> bool:
def handle_annotation_reply(self, app_record: App,
message: Message,
query: str,
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
"""
Handle annotation reply
:param app_record: app record
:param message: message
:param query: query
:param queue_manager: application queue manager
:param app_generate_entity: application generate entity
"""
# annotation reply
@@ -203,37 +246,32 @@ class AdvancedChatAppRunner(AppRunner):
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER
self._publish_event(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)
)
self._stream_output(
queue_manager=queue_manager,
self._complete_with_stream_output(
text=annotation_reply.content,
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
)
return True
return False
def _stream_output(
self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy
) -> None:
def _complete_with_stream_output(self,
text: str,
stopped_by: QueueStopEvent.StopBy) -> None:
"""
Direct output
:param queue_manager: application queue manager
:param text: text
:param stream: stream
:return:
"""
if stream:
index = 0
for token in text:
queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER)
index += 1
time.sleep(0.01)
else:
queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER)
self._publish_event(
QueueTextChunkEvent(
text=text
)
)
queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER)
self._publish_event(
QueueStopEvent(stopped_by=stopped_by)
)
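Note: _complete_with_stream_output replaces the old _stream_output. Rather than publishing one QueueTextChunkEvent per character with a 10 ms sleep, it publishes the whole text as a single chunk followed by a stop event. A stand-in sketch of that two-event sequence:

    from dataclasses import dataclass

    @dataclass
    class TextChunkEvent:  # stand-in for QueueTextChunkEvent
        text: str

    @dataclass
    class StopEvent:  # stand-in for QueueStopEvent
        stopped_by: str

    def complete_with_stream_output(publish, text: str, stopped_by: str) -> None:
        # One chunk carrying the full text, then an explicit stop.
        publish(TextChunkEvent(text=text))
        publish(StopEvent(stopped_by=stopped_by))

    complete_with_stream_output(print, "Your input was flagged by moderation.", "input-moderation")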

View File

@@ -2,9 +2,8 @@ import json
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union, cast
from typing import Any, Optional, Union
import contexts
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -22,6 +21,9 @@ from core.app.entities.queue_entities import (
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
@@ -31,34 +33,28 @@ from core.app.entities.queue_entities import (
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
AdvancedChatTaskState,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ChatflowStreamGenerateRoute,
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
StreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.file.file_obj import FileVar
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from events.message_event import message_was_created
from extensions.ext_database import db
from models.account import Account
from models.model import Conversation, EndUser, Message
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowRunStatus,
)
@@ -69,16 +65,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
AdvancedChatAppGenerateTaskPipeline is a class that generates stream output and manages state for the application.
"""
_task_state: AdvancedChatTaskState
_task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
_user: Union[Account, EndUser]
# Deprecated
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(
self, application_generate_entity: AdvancedChatAppGenerateEntity,
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
@@ -106,7 +101,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._workflow = workflow
self._conversation = conversation
self._message = message
# Deprecated
self._workflow_system_variables = {
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
@@ -114,12 +108,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
SystemVariableKey.USER_ID: user_id,
}
self._task_state = AdvancedChatTaskState(
usage=LLMUsage.empty_usage()
)
self._task_state = WorkflowTaskState()
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
self._stream_generate_routes = self._get_stream_generate_routes()
self._conversation_name_generate_thread = None
def process(self):
@@ -140,6 +130,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
if self._stream:
return self._to_stream_response(generator)
else:
@@ -199,17 +190,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
publisher = None
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@@ -220,9 +212,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# timeout
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not publisher:
if not tts_publisher:
break
audio_trunk = publisher.checkAndGetAudio()
audio_trunk = tts_publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
@@ -240,34 +232,34 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
def _process_stream_response(
self,
publisher: AppGeneratorTTSPublisher,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
if (message.event
and getattr(message.event, 'metadata', None)
and message.event.metadata.get('is_answer_previous_node', False)
and publisher):
publisher.publish(message=message)
elif (hasattr(message.event, 'execution_metadata')
and message.event.execution_metadata
and message.event.execution_metadata.get('is_answer_previous_node', False)
and publisher):
publisher.publish(message=message)
event = message.event
# init fake graph runtime state
graph_runtime_state = None
workflow_run = None
if isinstance(event, QueueErrorEvent):
for queue_message in self._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event, self._message)
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
workflow_run = self._handle_workflow_start()
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
# init workflow run
workflow_run = self._handle_workflow_run_start()
self._refetch_message()
self._message.workflow_run_id = workflow_run.id
db.session.commit()
@@ -279,133 +271,242 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
workflow_run=workflow_run
)
elif isinstance(event, QueueNodeStartedEvent):
workflow_node_execution = self._handle_node_start(event)
if not workflow_run:
raise Exception('Workflow run not initialized.')
# search stream_generate_routes if node id is answer start at node
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
# reset current route position to 0
self._task_state.current_stream_generate_state.current_route_position = 0
workflow_node_execution = self._handle_node_execution_start(
workflow_run=workflow_run,
event=event
)
# generate stream outputs when node started
yield from self._generate_stream_outputs_when_node_started()
yield self._workflow_node_start_to_stream_response(
response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
workflow_node_execution = self._handle_node_finished(event)
# stream outputs when node finished
generator = self._generate_stream_outputs_when_node_finished()
if generator:
yield from generator
if response:
yield response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
yield self._workflow_node_finish_to_stream_response(
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if isinstance(event, QueueNodeFailedEvent):
yield from self._handle_iteration_exception(
task_id=self._application_generate_entity.task_id,
error=f'Child node failed: {event.error}'
)
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
if isinstance(event, QueueIterationNextEvent):
# clear ran node execution infos of current iteration
iteration_relations = self._iteration_nested_relations.get(event.node_id)
if iteration_relations:
for node_id in iteration_relations:
self._task_state.ran_node_execution_infos.pop(node_id, None)
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(
event, conversation_id=self._conversation.id, trace_manager=trace_manager
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if workflow_run:
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs) if event.outputs else None,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED,
error=event.error,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
elif isinstance(event, QueueStopEvent):
if workflow_run and graph_runtime_state:
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.STOPPED,
error=event.get_stop_reason(),
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
if workflow_run.status == WorkflowRunStatus.FAILED.value:
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
if isinstance(event, QueueStopEvent):
# Save message
self._save_message()
yield self._message_end_to_stream_response()
break
else:
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
# Save message
self._save_message()
self._save_message(graph_runtime_state=graph_runtime_state)
yield self._message_end_to_stream_response()
break
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
self._refetch_message()
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
db.session.commit()
db.session.refresh(self._message)
db.session.close()
elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event)
self._refetch_message()
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
db.session.commit()
db.session.refresh(self._message)
db.session.close()
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
continue
if not self._is_stream_out_support(
event=event
):
continue
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
if should_direct_answer:
continue
# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)
self._task_state.answer += delta_text
yield self._message_to_stream_response(delta_text, self._message.id)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
# Save message
self._save_message(graph_runtime_state=graph_runtime_state)
yield self._message_end_to_stream_response()
else:
continue
if publisher:
publisher.publish(None)
# publish None when task finished
if tts_publisher:
tts_publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self) -> None:
def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
"""
Save message.
:return:
"""
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._refetch_message()
self._message.answer = self._task_state.answer
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
if self._task_state.metadata and self._task_state.metadata.get('usage'):
usage = LLMUsage(**self._task_state.metadata['usage'])
if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage
self._message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit
@@ -432,7 +533,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
extras = {}
if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata
extras['metadata'] = self._task_state.metadata.copy()
if 'annotation_reply' in extras['metadata']:
del extras['metadata']['annotation_reply']
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
@@ -440,323 +544,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
**extras
)
def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
"""
Get stream generate routes.
:return:
"""
# find all answer nodes
graph = self._workflow.graph_dict
answer_node_configs = [
node for node in graph['nodes']
if node.get('data', {}).get('type') == NodeType.ANSWER.value
]
# parse stream output node value selectors of answer nodes
stream_generate_routes = {}
for node_config in answer_node_configs:
# get generate route for stream output
answer_node_id = node_config['id']
generate_route = AnswerNode.extract_generate_route_selectors(node_config)
start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id)
if not start_node_ids:
continue
for start_node_id in start_node_ids:
stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute(
answer_node_id=answer_node_id,
generate_route=generate_route
)
return stream_generate_routes
def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
"""
Get answer start at node id.
:param graph: graph
:param target_node_id: target node ID
:return:
"""
nodes = graph.get('nodes')
edges = graph.get('edges')
# fetch all ingoing edges from source node
ingoing_edges = []
for edge in edges:
if edge.get('target') == target_node_id:
ingoing_edges.append(edge)
if not ingoing_edges:
# check if it's the first node in the iteration
target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
if not target_node:
return []
node_iteration_id = target_node.get('data', {}).get('iteration_id')
# get iteration start node id
for node in nodes:
if node.get('id') == node_iteration_id:
if node.get('data', {}).get('start_node_id') == target_node_id:
return [target_node_id]
return []
start_node_ids = []
for ingoing_edge in ingoing_edges:
source_node_id = ingoing_edge.get('source')
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
if not source_node:
continue
node_type = source_node.get('data', {}).get('type')
node_iteration_id = source_node.get('data', {}).get('iteration_id')
iteration_start_node_id = None
if node_iteration_id:
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
if node_type in [
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value,
NodeType.ITERATION.value,
NodeType.LOOP.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id)
if sub_start_node_ids:
start_node_ids.extend(sub_start_node_ids)
return start_node_ids
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
Get iteration nested relations.
:param graph: graph
:return:
"""
nodes = graph.get('nodes')
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
]]
return {
iteration_id: [
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}
def _generate_stream_outputs_when_node_started(self) -> Generator:
"""
Generate stream outputs.
:return:
"""
if self._task_state.current_stream_generate_state:
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:
]
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text)
if should_direct_answer:
continue
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
break
self._task_state.current_stream_generate_state.current_route_position += 1
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route
):
self._task_state.current_stream_generate_state = None
def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]:
"""
Generate stream outputs.
:return:
"""
if not self._task_state.current_stream_generate_state:
return
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:]
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
value = None
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
if not value_selector:
self._task_state.current_stream_generate_state.current_route_position += 1
continue
route_chunk_node_id = value_selector[0]
if route_chunk_node_id == 'sys':
# system variable
value = contexts.workflow_variable_pool.get().get(value_selector)
if value:
value = value.text
elif route_chunk_node_id in self._iteration_nested_relations:
# it's a iteration variable
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
continue
iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
iterator = iteration_state.inputs
if not iterator:
continue
iterator_selector = iterator.get('iterator_selector', [])
if value_selector[1] == 'index':
value = iteration_state.current_index
elif value_selector[1] == 'item':
value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
iterator_selector
) else None
else:
# check chunk node id is before current node id or equal to current node id
if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
break
latest_node_execution_info = self._task_state.latest_node_execution_info
# get route chunk node execution info
route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
if (route_chunk_node_execution_info.node_type == NodeType.LLM
and latest_node_execution_info.node_type == NodeType.LLM):
# only LLM support chunk stream output
self._task_state.current_stream_generate_state.current_route_position += 1
continue
# get route chunk node execution
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id
).first()
outputs = route_chunk_node_execution.outputs_dict
# get value from outputs
value = None
for key in value_selector[1:]:
if not value:
value = outputs.get(key) if outputs else None
else:
value = value.get(key)
if value is not None:
text = ''
if isinstance(value, str | int | float):
text = str(value)
elif isinstance(value, FileVar):
# convert file to markdown
text = value.to_markdown()
elif isinstance(value, dict):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
if file_vars:
file_var = file_vars[0]
try:
file_var_obj = FileVar(**file_var)
# convert file to markdown
text = file_var_obj.to_markdown()
except Exception as e:
logger.error(f'Error creating file var: {e}')
if not text:
# other types
text = json.dumps(value, ensure_ascii=False)
elif isinstance(value, list):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
for file_var in file_vars:
try:
file_var_obj = FileVar(**file_var)
except Exception as e:
logger.error(f'Error creating file var: {e}')
continue
# convert file to markdown
text = file_var_obj.to_markdown() + ' '
text = text.strip()
if not text and value:
# other types
text = json.dumps(value, ensure_ascii=False)
if text:
self._task_state.answer += text
yield self._message_to_stream_response(text, self._message.id)
self._task_state.current_stream_generate_state.current_route_position += 1
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route
):
self._task_state.current_stream_generate_state = None
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
"""
Is stream out support
:param event: queue text chunk event
:return:
"""
if not event.metadata:
return True
if 'node_id' not in event.metadata:
return True
node_type = event.metadata.get('node_type')
stream_output_value_selector = event.metadata.get('value_selector')
if not stream_output_value_selector:
return False
if not self._task_state.current_stream_generate_state:
return False
route_chunk = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position]
if route_chunk.type != 'var':
return False
if node_type != NodeType.LLM:
# only LLM support chunk stream output
return False
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
# check chunk node id is before current node id or equal to current node id
if value_selector != stream_output_value_selector:
return False
return True
def _handle_output_moderation_chunk(self, text: str) -> bool:
"""
Handle output moderation chunk.
@@ -782,3 +569,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._output_moderation_handler.append_new_token(text)
return False
def _refetch_message(self) -> None:
"""
Refetch message.
:return:
"""
message = db.session.query(Message).filter(Message.id == self._message.id).first()
if message:
self._message = message
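Note: the pipeline above now recognizes the parallel-branch queue events and converts each into a stream response, always guarding on the workflow run having been initialized first. A simplified sketch of that mapping, with dict payloads standing in for the real task entities:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ParallelBranchEvent:  # stand-in for the QueueParallelBranchRun* events
        parallel_id: str
        status: str  # "started", "succeeded", or "failed"

    def branch_to_stream_response(task_id: str,
                                  workflow_run_id: Optional[str],
                                  event: ParallelBranchEvent) -> dict:
        if workflow_run_id is None:
            # mirrors the "Workflow run not initialized." guard in the pipeline
            raise Exception("Workflow run not initialized.")
        return {
            "task_id": task_id,
            "workflow_run_id": workflow_run_id,
            "event": f"parallel_branch_run_{event.status}",
            "parallel_id": event.parallel_id,
        }

    print(branch_to_stream_response("task-1", "run-1", ParallelBranchEvent("p-1", "started")))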

View File

@@ -1,203 +0,0 @@
from typing import Any, Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow
class WorkflowEventTriggerCallback(WorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self._queue_manager.publish(
QueueWorkflowStartedEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self._queue_manager.publish(
QueueWorkflowSucceededEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
self._queue_manager.publish(
QueueWorkflowFailedEvent(
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
"""
Workflow node execute started
"""
self._queue_manager.publish(
QueueNodeStartedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
node_run_index=node_run_index,
predecessor_node_id=predecessor_node_id
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
"""
Workflow node execute succeeded
"""
self._queue_manager.publish(
QueueNodeSucceededEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
process_data=process_data,
outputs=outputs,
execution_metadata=execution_metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
process_data: Optional[dict] = None) -> None:
"""
Workflow node execute failed
"""
self._queue_manager.publish(
QueueNodeFailedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
outputs=outputs,
process_data=process_data,
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""
self._queue_manager.publish(
QueueTextChunkEvent(
text=text,
metadata={
"node_id": node_id,
**metadata
}
), PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
self._queue_manager.publish(
QueueIterationStartEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
node_data=node_data,
inputs=inputs,
predecessor_node_id=predecessor_node_id,
metadata=metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[Any]) -> None:
"""
Publish iteration next
"""
self._queue_manager._publish(
QueueIterationNextEvent(
node_id=node_id,
node_type=node_type,
index=index,
node_run_index=node_run_index,
output=output
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
self._queue_manager._publish(
QueueIterationCompletedEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
outputs=outputs
),
PublishFrom.APPLICATION_MANAGER
)
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
self._queue_manager.publish(
event,
PublishFrom.APPLICATION_MANAGER
)

View File

@@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC):
def convert(cls, response: Union[
AppBlockingResponse,
Generator[AppStreamResponse, Any, None]
], invoke_from: InvokeFrom):
], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]:
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
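Note: several entry points in this commit gain the explicit return annotation dict[str, Any] | Generator[str, Any, None], and the converter is where that contract becomes concrete: a blocking response converts to one full dict, a streaming response to a generator of serialized chunks. A minimal sketch under that assumption (the chunk framing here is illustrative, not the converter's exact output):

    from collections.abc import Generator
    from typing import Any, Union

    def convert(response: Union[dict, Generator[str, Any, None]]) -> "dict[str, Any] | Generator[str, Any, None]":
        if isinstance(response, dict):
            return response  # blocking: one full payload
        return (f"data: {chunk}\n\n" for chunk in response)  # streaming: pass chunks through

    print(convert({"answer": "hi"}))
    print(list(convert(iter(["a", "b"]))))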

View File

@@ -1,6 +1,6 @@
import time
from collections.abc import Generator
from typing import TYPE_CHECKING, Optional, Union
from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -347,7 +347,7 @@ class AppRunner:
self, app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: dict,
inputs: Mapping[str, Any],
query: str,
message_id: str,
) -> tuple[bool, dict, str]:

View File

@@ -4,7 +4,7 @@ import os
import threading
import uuid
from collections.abc import Generator
from typing import Literal, Union, overload
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -40,6 +40,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
args: dict,
invoke_from: InvokeFrom,
stream: Literal[True] = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
) -> Generator[str, None, None]: ...
@overload
@@ -50,16 +52,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
) -> dict: ...
def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
):
"""
Generate App response.
@@ -71,6 +77,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
:param call_depth: call depth
:param workflow_thread_pool_id: workflow thread pool id
"""
inputs = args['inputs']
@@ -118,16 +125,19 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
stream=stream,
workflow_thread_pool_id=workflow_thread_pool_id
)
def _generate(
self, app_model: App,
self, *,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[str, None, None]]:
workflow_thread_pool_id: Optional[str] = None
) -> dict[str, Any] | Generator[str, None, None]:
"""
Generate App response.
@@ -137,6 +147,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param stream: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
# init queue manager
queue_manager = WorkflowAppQueueManager(
@@ -148,10 +159,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'flask_app': current_app._get_current_object(), # type: ignore
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'context': contextvars.copy_context()
'context': contextvars.copy_context(),
'workflow_thread_pool_id': workflow_thread_pool_id
})
worker_thread.start()
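# --- Editor's sketch (illustrative, not Dify's code) of the thread hand-off
# above: copy the caller's contextvars into the worker and re-enter the Flask
# app context there, so request-scoped state survives the thread hop.
import contextvars
import threading

from flask import Flask

request_id: contextvars.ContextVar[str] = contextvars.ContextVar("request_id", default="-")

def _worker(flask_app: Flask, context: contextvars.Context) -> None:
    for var, val in context.items():  # restore each captured context variable
        var.set(val)
    with flask_app.app_context():     # re-enter the app context in this thread
        print("worker sees request_id =", request_id.get())

demo_app = Flask(__name__)
request_id.set("req-42")
t = threading.Thread(target=_worker, kwargs={
    "flask_app": demo_app,
    "context": contextvars.copy_context(),
})
t.start()
t.join()  # prints: worker sees request_id = req-42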
@ -175,7 +187,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
node_id: str,
user: Account,
args: dict,
stream: bool = True):
stream: bool = True) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@ -192,10 +204,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
if args.get('inputs') is None:
raise ValueError('inputs is required')
extras = {
"auto_generate_conversation_name": False
}
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
@ -211,7 +219,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras=extras,
extras={
"auto_generate_conversation_name": False
},
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id,
inputs=args['inputs']
@ -231,12 +241,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
def _generate_worker(self, flask_app: Flask,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context) -> None:
context: contextvars.Context,
workflow_thread_pool_id: Optional[str] = None) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
for var, val in context.items():
@ -244,22 +256,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
with flask_app.app_context():
try:
# workflow app
runner = WorkflowAppRunner()
if application_generate_entity.single_iteration_run:
single_iteration_run = application_generate_entity.single_iteration_run
runner.single_iteration_run(
app_id=application_generate_entity.app_config.app_id,
workflow_id=application_generate_entity.app_config.workflow_id,
queue_manager=queue_manager,
inputs=single_iteration_run.inputs,
node_id=single_iteration_run.node_id,
user_id=application_generate_entity.user_id
)
else:
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager
)
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id
)
runner.run()
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
@ -271,14 +274,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true':
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
db.session.close()
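# --- Editor's sketch (stand-in names) of the refactor above: per-run
# collaborators move into the runner's constructor, so `run()` takes no
# arguments and the worker body shrinks to construct-then-run.
from typing import Optional

class MiniRunner:
    def __init__(self, entity: dict, queue: list, pool_id: Optional[str] = None) -> None:
        self.entity = entity
        self.queue = queue
        self.pool_id = pool_id

    def run(self) -> None:
        # a real runner would execute the workflow; we just record the call
        self.queue.append((self.entity["app_id"], self.pool_id))

events: list = []
MiniRunner({"app_id": "app-1"}, events, pool_id="shared-pool").run()
print(events)  # [('app-1', 'shared-pool')]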
def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,

View File

@ -4,46 +4,61 @@ from typing import Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App, EndUser
from models.workflow import Workflow
from models.workflow import WorkflowType
logger = logging.getLogger(__name__)
class WorkflowAppRunner:
class WorkflowAppRunner(WorkflowBasedAppRunner):
"""
Workflow Application Runner
"""
def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None:
def __init__(
self,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
workflow_thread_pool_id: Optional[str] = None
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
self.application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id
def run(self) -> None:
"""
Run application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:return:
"""
app_config = application_generate_entity.app_config
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
user_id = self.application_generate_entity.user_id
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
@ -53,80 +68,64 @@ class WorkflowAppRunner:
if not workflow:
raise ValueError('Workflow not initialized')
inputs = application_generate_entity.inputs
files = application_generate_entity.files
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
]
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# Create a variable pool.
system_inputs = {
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
)
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = {
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
)
# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)
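# --- Editor's sketch of a variable pool like the one built above: separate
# namespaces (system / user / environment) looked up by scope and key.
# MiniVariablePool is an illustration, not Dify's VariablePool API.
from typing import Any, Optional

class MiniVariablePool:
    def __init__(self, system: dict[str, Any], user: dict[str, Any], env: dict[str, Any]):
        self._scopes = {"sys": system, "user": user, "env": env}

    def get(self, scope: str, key: str) -> Optional[Any]:
        return self._scopes.get(scope, {}).get(key)

pool = MiniVariablePool({"user_id": "u-1", "files": []}, {"query": "hi"}, {})
print(pool.get("sys", "user_id"), pool.get("user", "query"))  # u-1 hi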
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
workflow=workflow,
user_id=application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth,
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id
)
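# --- Editor's sketch of what `thread_pool_id` enables: a nested workflow run
# can reuse an existing executor instead of allocating a new one. The registry
# and function names are assumptions for illustration, not Dify internals.
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

_POOLS: dict[str, ThreadPoolExecutor] = {}

def get_or_create_pool(pool_id: Optional[str], max_workers: int = 10) -> tuple[str, ThreadPoolExecutor]:
    if pool_id is not None and pool_id in _POOLS:
        return pool_id, _POOLS[pool_id]  # reuse the shared pool
    new_id = pool_id or str(uuid.uuid4())
    _POOLS[new_id] = ThreadPoolExecutor(max_workers=max_workers)
    return new_id, _POOLS[new_id]

pid, pool = get_or_create_pool(None)
assert get_or_create_pool(pid)[1] is pool  # same id -> same executor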
def single_iteration_run(
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
) -> None:
"""
Single iteration run
"""
app_record = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError('App not found')
if not app_record.workflow_id:
raise ValueError('Workflow not initialized')
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
generator = workflow_entry.run(
callbacks=workflow_callbacks
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
for event in generator:
self._handle_event(workflow_entry, event)
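# --- Editor's sketch (self-contained, stand-in event types) of the new run
# loop: the engine yields typed events and the runner handles each one as it
# arrives, which is what lets parallel branches interleave in one stream.
from collections.abc import Generator
from dataclasses import dataclass

@dataclass
class NodeStarted:
    node_id: str

@dataclass
class NodeSucceeded:
    node_id: str
    outputs: dict

def engine_run() -> Generator[object, None, None]:
    # two branches could interleave here; the order below is an illustration
    yield NodeStarted("branch_a")
    yield NodeStarted("branch_b")
    yield NodeSucceeded("branch_a", {"text": "A"})
    yield NodeSucceeded("branch_b", {"text": "B"})

for ev in engine_run():
    print(type(ev).__name__, ev.node_id)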

View File

@ -1,3 +1,4 @@
import json
import logging
import time
from collections.abc import Generator
@ -15,10 +16,12 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueStopEvent,
QueueTextChunkEvent,
@ -32,19 +35,16 @@ from core.app.entities.task_entities import (
MessageAudioStreamResponse,
StreamResponse,
TextChunkStreamResponse,
TextReplaceStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStreamGenerateNodes,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
@ -52,8 +52,8 @@ from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowRun,
WorkflowRunStatus,
)
logger = logging.getLogger(__name__)
@ -68,7 +68,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
@ -96,11 +95,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
SystemVariableKey.USER_ID: user_id
}
self._task_state = WorkflowTaskState(
iteration_nested_node_ids=[]
)
self._stream_generate_nodes = self._get_stream_generate_nodes()
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
self._task_state = WorkflowTaskState()
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@ -129,23 +124,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, WorkflowFinishStreamResponse):
workflow_run = db.session.query(WorkflowRun).filter(
WorkflowRun.id == self._task_state.workflow_run_id).first()
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=workflow_run.id,
workflow_run_id=stream_response.data.id,
data=WorkflowAppBlockingResponse.Data(
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
status=workflow_run.status,
outputs=workflow_run.outputs_dict,
error=workflow_run.error,
elapsed_time=workflow_run.elapsed_time,
total_tokens=workflow_run.total_tokens,
total_steps=workflow_run.total_steps,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp())
id=stream_response.data.id,
workflow_id=stream_response.data.workflow_id,
status=stream_response.data.status,
outputs=stream_response.data.outputs,
error=stream_response.data.error,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=int(stream_response.data.created_at),
finished_at=int(stream_response.data.finished_at)
)
)
@ -161,9 +153,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
To stream response.
:return:
"""
workflow_run_id = None
for stream_response in generator:
if isinstance(stream_response, WorkflowStartStreamResponse):
workflow_run_id = stream_response.workflow_run_id
yield WorkflowAppStreamResponse(
workflow_run_id=self._task_state.workflow_run_id,
workflow_run_id=workflow_run_id,
stream_response=stream_response
)
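# --- Editor's sketch (assumed dict shapes) of the change above: the pipeline
# no longer reads workflow_run_id from task state; it remembers the id from
# the first "workflow started" stream response and stamps it on what follows.
from collections.abc import Generator, Iterable
from typing import Optional

def wrap_stream(responses: Iterable[dict]) -> Generator[dict, None, None]:
    workflow_run_id: Optional[str] = None
    for resp in responses:
        if resp.get("event") == "workflow_started":
            workflow_run_id = resp.get("workflow_run_id")
        yield {"workflow_run_id": workflow_run_id, "stream_response": resp}

for out in wrap_stream([{"event": "workflow_started", "workflow_run_id": "run-1"},
                        {"event": "node_started"}]):
    print(out)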
@ -178,17 +174,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
publisher = None
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@ -198,9 +195,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
start_listener_time = time.time()
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not publisher:
if not tts_publisher:
break
audio_trunk = publisher.checkAndGetAudio()
audio_trunk = tts_publisher.checkAndGetAudio()
if audio_trunk is None:
# release CPU
# sleep 20 ms (40 ms => 1280-byte audio file, 20 ms => 640-byte audio file)
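# --- Editor's sketch of the drain loop above with stand-in types: poll a
# queue of audio chunks until a timeout, sleeping briefly when idle so the
# loop stays cheap. The timeout value is an assumption for this sketch.
import time
from collections.abc import Generator
from queue import Empty, Queue

AUTO_PLAY_TIMEOUT = 0.1  # seconds; kept short for the demo

def drain_audio(q: "Queue[bytes]") -> Generator[bytes, None, None]:
    start = time.time()
    while time.time() - start < AUTO_PLAY_TIMEOUT:
        try:
            yield q.get_nowait()
        except Empty:
            time.sleep(0.02)  # release the CPU for ~20 ms between polls

audio_q: "Queue[bytes]" = Queue()
audio_q.put(b"\x00" * 640)
print(sum(len(c) for c in drain_audio(audio_q)))  # 640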
@ -218,69 +215,159 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
def _process_stream_response(
self,
publisher: AppGeneratorTTSPublisher,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
if publisher:
publisher.publish(message=message)
event = message.event
graph_runtime_state = None
workflow_run = None
if isinstance(event, QueueErrorEvent):
for queue_message in self._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event)
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
workflow_run = self._handle_workflow_start()
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
# init workflow run
workflow_run = self._handle_workflow_run_start()
yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
elif isinstance(event, QueueNodeStartedEvent):
workflow_node_execution = self._handle_node_start(event)
if not workflow_run:
raise Exception('Workflow run not initialized.')
# search stream_generate_routes if node id is answer start at node
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes:
self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id]
workflow_node_execution = self._handle_node_execution_start(
workflow_run=workflow_run,
event=event
)
# generate stream outputs when node started
yield from self._generate_stream_outputs_when_node_started()
yield self._workflow_node_start_to_stream_response(
response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
workflow_node_execution = self._handle_node_finished(event)
yield self._workflow_node_finish_to_stream_response(
if response:
yield response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if isinstance(event, QueueNodeFailedEvent):
yield from self._handle_iteration_exception(
task_id=self._application_generate_entity.task_id,
error=f'Child node failed: {event.error}'
)
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
if isinstance(event, QueueIterationNextEvent):
# clear ran node execution infos of current iteration
iteration_relations = self._iteration_nested_relations.get(event.node_id)
if iteration_relations:
for node_id in iteration_relations:
self._task_state.ran_node_execution_infos.pop(node_id, None)
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(
event, trace_manager=trace_manager
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event
)
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None,
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run:
raise Exception('Workflow run not initialized.')
if not graph_runtime_state:
raise Exception('Graph runtime state not initialized.')
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
@ -295,22 +382,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if delta_text is None:
continue
if not self._is_stream_out_support(
event=event
):
continue
# only publish TTS messages during text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)
self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(delta_text)
elif isinstance(event, QueueMessageReplaceEvent):
yield self._text_replace_to_stream_response(event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
else:
continue
if publisher:
publisher.publish(None)
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
@ -329,15 +411,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# not save log for debugging
return
workflow_app_log = WorkflowAppLog(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
workflow_run_id=workflow_run.id,
created_from=created_from.value,
created_by_role=('account' if isinstance(self._user, Account) else 'end_user'),
created_by=self._user.id,
)
workflow_app_log = WorkflowAppLog()
workflow_app_log.tenant_id = workflow_run.tenant_id
workflow_app_log.app_id = workflow_run.app_id
workflow_app_log.workflow_id = workflow_run.workflow_id
workflow_app_log.workflow_run_id = workflow_run.id
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user'
workflow_app_log.created_by = self._user.id
db.session.add(workflow_app_log)
db.session.commit()
db.session.close()
@ -354,180 +436,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
)
return response
def _text_replace_to_stream_response(self, text: str) -> TextReplaceStreamResponse:
"""
Text replace to stream response.
:param text: text
:return:
"""
return TextReplaceStreamResponse(
task_id=self._application_generate_entity.task_id,
text=TextReplaceStreamResponse.Data(text=text)
)
def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]:
"""
Get stream generate nodes.
:return:
"""
# find all answer nodes
graph = self._workflow.graph_dict
end_node_configs = [
node for node in graph['nodes']
if node.get('data', {}).get('type') == NodeType.END.value
]
# parse stream output node value selectors of end nodes
stream_generate_routes = {}
for node_config in end_node_configs:
# get generate route for stream output
end_node_id = node_config['id']
generate_nodes = EndNode.extract_generate_nodes(graph, node_config)
start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id)
if not start_node_ids:
continue
for start_node_id in start_node_ids:
stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes(
end_node_id=end_node_id,
stream_node_ids=generate_nodes
)
return stream_generate_routes
def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
"""
Get the node IDs at which streaming for the given end node starts.
:param graph: graph
:param target_node_id: target node ID
:return:
"""
nodes = graph.get('nodes')
edges = graph.get('edges')
# fetch all incoming edges of the target node
ingoing_edges = []
for edge in edges:
if edge.get('target') == target_node_id:
ingoing_edges.append(edge)
if not ingoing_edges:
return []
start_node_ids = []
for ingoing_edge in ingoing_edges:
source_node_id = ingoing_edge.get('source')
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
if not source_node:
continue
node_type = source_node.get('data', {}).get('type')
node_iteration_id = source_node.get('data', {}).get('iteration_id')
iteration_start_node_id = None
if node_iteration_id:
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
if node_type in [
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id)
if sub_start_node_ids:
start_node_ids.extend(sub_start_node_ids)
return start_node_ids
def _generate_stream_outputs_when_node_started(self) -> Generator:
"""
Generate stream outputs.
:return:
"""
if self._task_state.current_stream_generate_state:
stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids
for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items():
if node_id not in stream_node_ids:
continue
node_execution_info = self._task_state.ran_node_execution_infos[node_id]
# get chunk node execution
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first()
if not route_chunk_node_execution:
continue
outputs = route_chunk_node_execution.outputs_dict
if not outputs:
continue
# get value from outputs
text = outputs.get('text')
if text:
self._task_state.answer += text
yield self._text_chunk_to_stream_response(text)
db.session.close()
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
"""
Is stream out support
:param event: queue text chunk event
:return:
"""
if not event.metadata:
return False
if 'node_id' not in event.metadata:
return False
node_id = event.metadata.get('node_id')
node_type = event.metadata.get('node_type')
stream_output_value_selector = event.metadata.get('value_selector')
if not stream_output_value_selector:
return False
if not self._task_state.current_stream_generate_state:
return False
if node_id not in self._task_state.current_stream_generate_state.stream_node_ids:
return False
if node_type != NodeType.LLM:
# only LLM support chunk stream output
return False
return True
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
Get iteration nested relations.
:param graph: graph
:return:
"""
nodes = graph.get('nodes')
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
]]
return {
iteration_id: [
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}

View File

@ -1,200 +0,0 @@
from typing import Any, Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow
class WorkflowEventTriggerCallback(WorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self._queue_manager.publish(
QueueWorkflowStartedEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self._queue_manager.publish(
QueueWorkflowSucceededEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
self._queue_manager.publish(
QueueWorkflowFailedEvent(
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
"""
Workflow node execute started
"""
self._queue_manager.publish(
QueueNodeStartedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
node_run_index=node_run_index,
predecessor_node_id=predecessor_node_id
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
"""
Workflow node execute succeeded
"""
self._queue_manager.publish(
QueueNodeSucceededEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
process_data=process_data,
outputs=outputs,
execution_metadata=execution_metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
process_data: Optional[dict] = None) -> None:
"""
Workflow node execute failed
"""
self._queue_manager.publish(
QueueNodeFailedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
outputs=outputs,
process_data=process_data,
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""
self._queue_manager.publish(
QueueTextChunkEvent(
text=text,
metadata={
"node_id": node_id,
**metadata
}
), PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
self._queue_manager.publish(
QueueIterationStartEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
node_data=node_data,
inputs=inputs,
predecessor_node_id=predecessor_node_id,
metadata=metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[Any]) -> None:
"""
Publish iteration next
"""
self._queue_manager.publish(
QueueIterationNextEvent(
node_id=node_id,
node_type=node_type,
index=index,
node_run_index=node_run_index,
output=output
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
self._queue_manager.publish(
QueueIterationCompletedEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
outputs=outputs
),
PublishFrom.APPLICATION_MANAGER
)
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
pass

View File

@ -0,0 +1,379 @@
from collections.abc import Mapping
from typing import Any, Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.nodes.node_mapping import node_classes
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
from models.workflow import Workflow
class WorkflowBasedAppRunner(AppRunner):
def __init__(self, queue_manager: AppQueueManager):
self.queue_manager = queue_manager
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
"""
Init graph
"""
if 'nodes' not in graph_config or 'edges' not in graph_config:
raise ValueError('nodes or edges not found in workflow graph')
if not isinstance(graph_config.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list')
if not isinstance(graph_config.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
# init graph
graph = Graph.init(
graph_config=graph_config
)
if not graph:
raise ValueError('graph not found in workflow')
return graph
def _get_graph_and_variable_pool_of_single_iteration(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict,
) -> tuple[Graph, VariablePool]:
"""
Get the graph and variable pool for a single iteration run
"""
# fetch workflow graph
graph_config = workflow.graph_dict
if not graph_config:
raise ValueError('workflow graph not found')
graph_config = cast(dict[str, Any], graph_config)
if 'nodes' not in graph_config or 'edges' not in graph_config:
raise ValueError('nodes or edges not found in workflow graph')
if not isinstance(graph_config.get('nodes'), list):
raise ValueError('nodes in workflow graph must be a list')
if not isinstance(graph_config.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
# filter nodes only in iteration
node_configs = [
node for node in graph_config.get('nodes', [])
if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id
]
graph_config['nodes'] = node_configs
node_ids = [node.get('id') for node in node_configs]
# filter edges only in iteration
edge_configs = [
edge for edge in graph_config.get('edges', [])
if (edge.get('source') is None or edge.get('source') in node_ids)
and (edge.get('target') is None or edge.get('target') in node_ids)
]
graph_config['edges'] = edge_configs
# init graph
graph = Graph.init(
graph_config=graph_config,
root_node_id=node_id
)
if not graph:
raise ValueError('graph not found in workflow')
# fetch node config from node id
iteration_node_config = None
for node in node_configs:
if node.get('id') == node_id:
iteration_node_config = node
break
if not iteration_node_config:
raise ValueError('iteration node id not found in workflow graph')
# Get node class
node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
node_cls = cast(type[BaseNode], node_cls)
# init variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
environment_variables=workflow.environment_variables,
)
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict,
config=iteration_node_config
)
except NotImplementedError:
variable_mapping = {}
WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
node_type=node_type,
node_data=IterationNodeData(**iteration_node_config.get('data', {}))
)
return graph, variable_pool
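# --- Editor's self-contained sketch of the sub-graph filtering above: keep
# the iteration node plus its children, then keep only the edges whose
# endpoints survive. Names are stand-ins for illustration.
from typing import Any

def filter_iteration_subgraph(graph: dict[str, Any], node_id: str) -> dict[str, Any]:
    nodes = [
        n for n in graph.get('nodes', [])
        if n.get('id') == node_id or n.get('data', {}).get('iteration_id') == node_id
    ]
    node_ids = {n.get('id') for n in nodes}
    edges = [
        e for e in graph.get('edges', [])
        if (e.get('source') is None or e.get('source') in node_ids)
        and (e.get('target') is None or e.get('target') in node_ids)
    ]
    return {'nodes': nodes, 'edges': edges}

g = {'nodes': [{'id': 'iter'}, {'id': 'a', 'data': {'iteration_id': 'iter'}}, {'id': 'x'}],
     'edges': [{'source': 'iter', 'target': 'a'}, {'source': 'a', 'target': 'x'}]}
print(filter_iteration_subgraph(g, 'iter'))  # keeps 'iter', 'a', and the edge between them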
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
"""
Handle event
:param workflow_entry: workflow entry
:param event: event
"""
if isinstance(event, GraphRunStartedEvent):
self._publish_event(
QueueWorkflowStartedEvent(
graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state
)
)
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(
QueueWorkflowSucceededEvent(outputs=event.outputs)
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(
QueueWorkflowFailedEvent(error=event.error)
)
elif isinstance(event, NodeRunStartedEvent):
self._publish_event(
QueueNodeStartedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, NodeRunSucceededEvent):
self._publish_event(
QueueNodeSucceededEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result else {},
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, NodeRunFailedEvent):
self._publish_event(
QueueNodeFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result
and event.route_node_state.node_run_result.error
else "Unknown error",
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
self._publish_event(
QueueTextChunkEvent(
text=event.chunk_content,
from_variable_selector=event.from_variable_selector,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, NodeRunRetrieverResourceEvent):
self._publish_event(
QueueRetrieverResourcesEvent(
retriever_resources=event.retriever_resources,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, ParallelBranchRunStartedEvent):
self._publish_event(
QueueParallelBranchRunStartedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, ParallelBranchRunSucceededEvent):
self._publish_event(
QueueParallelBranchRunSucceededEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, ParallelBranchRunFailedEvent):
self._publish_event(
QueueParallelBranchRunFailedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
error=event.error
)
)
elif isinstance(event, IterationRunStartedEvent):
self._publish_event(
QueueIterationStartEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id,
metadata=event.metadata
)
)
elif isinstance(event, IterationRunNextEvent):
self._publish_event(
QueueIterationNextEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
)
)
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
self._publish_event(
QueueIterationCompletedEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, IterationRunFailedEvent) else None
)
)
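# --- Editor's sketch of the adapter pattern above: translate internal engine
# events into queue events for the task pipeline. The classes here are
# stand-ins, not Dify's event types.
from dataclasses import dataclass
from typing import Optional

@dataclass
class EngineRunFailed:  # stand-in for the engine-side event
    error: str

@dataclass
class QueueRunFailed:  # stand-in for the queue-side event
    error: str

def to_queue_event(event: object) -> Optional[object]:
    if isinstance(event, EngineRunFailed):
        return QueueRunFailed(error=event.error)
    return None  # engine events without a queue counterpart are ignored

print(to_queue_event(EngineRunFailed("node timeout")))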
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
def _publish_event(self, event: AppQueueEvent) -> None:
self.queue_manager.publish(
event,
PublishFrom.APPLICATION_MANAGER
)

View File

@ -1,10 +1,24 @@
from typing import Optional
from core.app.entities.queue_entities import AppQueueEvent
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
@ -20,127 +34,203 @@ class WorkflowLoggingCallback(WorkflowCallback):
def __init__(self) -> None:
self.current_node_id = None
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self.print_text("\n[on_workflow_run_started]", color='pink')
def on_event(
self,
event: GraphEngineEvent
) -> None:
if isinstance(event, GraphRunStartedEvent):
self.print_text("\n[GraphRunStartedEvent]", color='pink')
elif isinstance(event, GraphRunSucceededEvent):
self.print_text("\n[GraphRunSucceededEvent]", color='green')
elif isinstance(event, GraphRunFailedEvent):
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red')
elif isinstance(event, NodeRunStartedEvent):
self.on_workflow_node_execute_started(
event=event
)
elif isinstance(event, NodeRunSucceededEvent):
self.on_workflow_node_execute_succeeded(
event=event
)
elif isinstance(event, NodeRunFailedEvent):
self.on_workflow_node_execute_failed(
event=event
)
elif isinstance(event, NodeRunStreamChunkEvent):
self.on_node_text_chunk(
event=event
)
elif isinstance(event, ParallelBranchRunStartedEvent):
self.on_workflow_parallel_started(
event=event
)
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
self.on_workflow_parallel_completed(
event=event
)
elif isinstance(event, IterationRunStartedEvent):
self.on_workflow_iteration_started(
event=event
)
elif isinstance(event, IterationRunNextEvent):
self.on_workflow_iteration_next(
event=event
)
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
self.on_workflow_iteration_completed(
event=event
)
else:
self.print_text(f"\n[{event.__class__.__name__}]", color='blue')
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self.print_text("\n[on_workflow_run_succeeded]", color='green')
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
self.print_text("\n[on_workflow_run_failed]", color='red')
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
def on_workflow_node_execute_started(
self,
event: NodeRunStartedEvent
) -> None:
"""
Workflow node execute started
"""
self.print_text("\n[on_workflow_node_execute_started]", color='yellow')
self.print_text(f"Node ID: {node_id}", color='yellow')
self.print_text(f"Type: {node_type.value}", color='yellow')
self.print_text(f"Index: {node_run_index}", color='yellow')
if predecessor_node_id:
self.print_text(f"Predecessor Node ID: {predecessor_node_id}", color='yellow')
self.print_text("\n[NodeRunStartedEvent]", color='yellow')
self.print_text(f"Node ID: {event.node_id}", color='yellow')
self.print_text(f"Node Title: {event.node_data.title}", color='yellow')
self.print_text(f"Type: {event.node_type.value}", color='yellow')
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
def on_workflow_node_execute_succeeded(
self,
event: NodeRunSucceededEvent
) -> None:
"""
Workflow node execute succeeded
"""
self.print_text("\n[on_workflow_node_execute_succeeded]", color='green')
self.print_text(f"Node ID: {node_id}", color='green')
self.print_text(f"Type: {node_type.value}", color='green')
self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='green')
self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='green')
self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='green')
self.print_text(f"Metadata: {jsonable_encoder(execution_metadata) if execution_metadata else ''}",
color='green')
route_node_state = event.route_node_state
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
process_data: Optional[dict] = None) -> None:
self.print_text("\n[NodeRunSucceededEvent]", color='green')
self.print_text(f"Node ID: {event.node_id}", color='green')
self.print_text(f"Node Title: {event.node_data.title}", color='green')
self.print_text(f"Type: {event.node_type.value}", color='green')
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color='green')
self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color='green')
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color='green')
self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
color='green')
def on_workflow_node_execute_failed(
self,
event: NodeRunFailedEvent
) -> None:
"""
Workflow node execute failed
"""
self.print_text("\n[on_workflow_node_execute_failed]", color='red')
self.print_text(f"Node ID: {node_id}", color='red')
self.print_text(f"Type: {node_type.value}", color='red')
self.print_text(f"Error: {error}", color='red')
self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='red')
self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='red')
self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='red')
route_node_state = event.route_node_state
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
self.print_text("\n[NodeRunFailedEvent]", color='red')
self.print_text(f"Node ID: {event.node_id}", color='red')
self.print_text(f"Node Title: {event.node_data.title}", color='red')
self.print_text(f"Type: {event.node_type.value}", color='red')
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(f"Error: {node_run_result.error}", color='red')
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
color='red')
self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color='red')
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color='red')
def on_node_text_chunk(
self,
event: NodeRunStreamChunkEvent
) -> None:
"""
Publish text chunk
"""
if not self.current_node_id or self.current_node_id != node_id:
self.current_node_id = node_id
self.print_text('\n[on_node_text_chunk]')
self.print_text(f"Node ID: {node_id}")
self.print_text(f"Metadata: {jsonable_encoder(metadata) if metadata else ''}")
route_node_state = event.route_node_state
if not self.current_node_id or self.current_node_id != route_node_state.node_id:
self.current_node_id = route_node_state.node_id
self.print_text('\n[NodeRunStreamChunkEvent]')
self.print_text(f"Node ID: {route_node_state.node_id}")
self.print_text(text, color="pink", end="")
node_run_result = route_node_state.node_run_result
if node_run_result:
self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}")
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
self.print_text(event.chunk_content, color="pink", end="")
def on_workflow_parallel_started(
self,
event: ParallelBranchRunStartedEvent
) -> None:
"""
Publish parallel started
"""
self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue')
self.print_text(f"Parallel ID: {event.parallel_id}", color='blue')
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue')
if event.in_iteration_id:
self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue')
def on_workflow_parallel_completed(
self,
event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
) -> None:
"""
Publish parallel completed
"""
if isinstance(event, ParallelBranchRunSucceededEvent):
color = 'blue'
elif isinstance(event, ParallelBranchRunFailedEvent):
color = 'red'
self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color)
self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
if event.in_iteration_id:
self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
if isinstance(event, ParallelBranchRunFailedEvent):
self.print_text(f"Error: {event.error}", color=color)
def on_workflow_iteration_started(
self,
event: IterationRunStartedEvent
) -> None:
"""
Publish iteration started
"""
self.print_text("\n[on_workflow_iteration_started]", color='blue')
self.print_text(f"Node ID: {node_id}", color='blue')
self.print_text("\n[IterationRunStartedEvent]", color='blue')
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[dict]) -> None:
def on_workflow_iteration_next(
self,
event: IterationRunNextEvent
) -> None:
"""
Publish iteration next
"""
self.print_text("\n[on_workflow_iteration_next]", color='blue')
self.print_text("\n[IterationRunNextEvent]", color='blue')
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
self.print_text(f"Iteration Index: {event.index}", color='blue')
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
def on_workflow_iteration_completed(
self,
event: IterationRunSucceededEvent | IterationRunFailedEvent
) -> None:
"""
Publish iteration completed
"""
self.print_text("\n[on_workflow_iteration_completed]", color='blue')
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
self.print_text("\n[on_workflow_event]", color='blue')
self.print_text(f"Event: {jsonable_encoder(event)}", color='blue')
self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue')
self.print_text(f"Node ID: {event.iteration_id}", color='blue')
def print_text(
self, text: str, color: Optional[str] = None, end: str = "\n"

View File

@ -1,3 +1,4 @@
from datetime import datetime
from enum import Enum
from typing import Any, Optional
@ -5,7 +6,8 @@ from pydantic import BaseModel, field_validator
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
class QueueEvent(str, Enum):
@ -31,6 +33,9 @@ class QueueEvent(str, Enum):
ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought"
MESSAGE_FILE = "message_file"
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
ERROR = "error"
PING = "ping"
STOP = "stop"
@ -38,7 +43,7 @@ class QueueEvent(str, Enum):
class AppQueueEvent(BaseModel):
"""
QueueEvent entity
QueueEvent abstract entity
"""
event: QueueEvent
@ -46,6 +51,7 @@ class AppQueueEvent(BaseModel):
class QueueLLMChunkEvent(AppQueueEvent):
"""
QueueLLMChunkEvent entity
Only for basic mode apps
"""
event: QueueEvent = QueueEvent.LLM_CHUNK
chunk: LLMResultChunk
@ -55,14 +61,24 @@ class QueueIterationStartEvent(AppQueueEvent):
QueueIterationStartEvent entity
"""
event: QueueEvent = QueueEvent.ITERATION_START
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
inputs: dict = None
inputs: Optional[dict[str, Any]] = None
predecessor_node_id: Optional[str] = None
metadata: Optional[dict] = None
metadata: Optional[dict[str, Any]] = None
class QueueIterationNextEvent(AppQueueEvent):
"""
@ -71,8 +87,18 @@ class QueueIterationNextEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.ITERATION_NEXT
index: int
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
node_run_index: int
output: Optional[Any] = None # output for the current iteration
@ -93,13 +119,30 @@ class QueueIterationCompletedEvent(AppQueueEvent):
"""
QueueIterationCompletedEvent entity
"""
event:QueueEvent = QueueEvent.ITERATION_COMPLETED
event: QueueEvent = QueueEvent.ITERATION_COMPLETED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
outputs: dict
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
steps: int = 0
error: Optional[str] = None
class QueueTextChunkEvent(AppQueueEvent):
"""
@ -107,7 +150,10 @@ class QueueTextChunkEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.TEXT_CHUNK
text: str
metadata: Optional[dict] = None
from_variable_selector: Optional[list[str]] = None
"""from variable selector"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class QueueAgentMessageEvent(AppQueueEvent):
@ -132,6 +178,8 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: list[dict]
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class QueueAnnotationReplyEvent(AppQueueEvent):
@ -162,6 +210,7 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
QueueWorkflowStartedEvent entity
"""
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
graph_runtime_state: GraphRuntimeState
class QueueWorkflowSucceededEvent(AppQueueEvent):
@ -169,6 +218,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
QueueWorkflowSucceededEvent entity
"""
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
outputs: Optional[dict[str, Any]] = None
class QueueWorkflowFailedEvent(AppQueueEvent):
@ -185,11 +235,23 @@ class QueueNodeStartedEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.NODE_STARTED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
node_run_index: int = 1
predecessor_node_id: Optional[str] = None
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
class QueueNodeSucceededEvent(AppQueueEvent):
@ -198,14 +260,26 @@ class QueueNodeSucceededEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.NODE_SUCCEEDED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict] = None
process_data: Optional[dict] = None
outputs: Optional[dict] = None
execution_metadata: Optional[dict] = None
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: Optional[str] = None
@ -216,13 +290,25 @@ class QueueNodeFailedEvent(AppQueueEvent):
"""
event: QueueEvent = QueueEvent.NODE_FAILED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict] = None
outputs: Optional[dict] = None
process_data: Optional[dict] = None
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
error: str
@ -274,10 +360,23 @@ class QueueStopEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.STOP
stopped_by: StopBy
def get_stop_reason(self) -> str:
"""
To stop reason
"""
reason_mapping = {
QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.',
QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.',
QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.',
QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.'
}
return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.')
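A quick usage note: the mapping turns the StopBy member carried on the event into a user-facing message, falling back to a generic reason for any member not listed. For example:
stop_event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)
assert stop_event.get_stop_reason() == 'Stopped by user.'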
class QueueMessage(BaseModel):
"""
QueueMessage entity
QueueMessage abstract entity
"""
task_id: str
app_mode: str
@ -297,3 +396,52 @@ class WorkflowQueueMessage(QueueMessage):
WorkflowQueueMessage entity
"""
pass
class QueueParallelBranchRunStartedEvent(AppQueueEvent):
"""
QueueParallelBranchRunStartedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
"""
QueueParallelBranchRunSucceededEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
"""
QueueParallelBranchRunFailedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
error: str
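Taken together, the three events describe one branch's lifecycle: a started event, then exactly one succeeded or failed event carrying the same parallel_id and parallel_start_node_id. A hedged sketch of how a graph engine might emit them; the queue_manager.publish(...) signature and PublishFrom.APPLICATION_MANAGER value are assumptions, not confirmed by this diff:
queue_manager.publish(
    QueueParallelBranchRunStartedEvent(
        parallel_id='parallel-1',
        parallel_start_node_id='node-a',
    ),
    PublishFrom.APPLICATION_MANAGER,
)
try:
    run_branch()  # hypothetical execution of the branch's nodes
except Exception as e:
    queue_manager.publish(
        QueueParallelBranchRunFailedEvent(
            parallel_id='parallel-1',
            parallel_start_node_id='node-a',
            error=str(e),
        ),
        PublishFrom.APPLICATION_MANAGER,
    )
else:
    queue_manager.publish(
        QueueParallelBranchRunSucceededEvent(
            parallel_id='parallel-1',
            parallel_start_node_id='node-a',
        ),
        PublishFrom.APPLICATION_MANAGER,
    )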

View File

@ -3,40 +3,11 @@ from typing import Any, Optional
from pydantic import BaseModel, ConfigDict
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import GenerateRouteChunk
from models.workflow import WorkflowNodeExecutionStatus
class WorkflowStreamGenerateNodes(BaseModel):
"""
WorkflowStreamGenerateNodes entity
"""
end_node_id: str
stream_node_ids: list[str]
class ChatflowStreamGenerateRoute(BaseModel):
"""
ChatflowStreamGenerateRoute entity
"""
answer_node_id: str
generate_route: list[GenerateRouteChunk]
current_route_position: int = 0
class NodeExecutionInfo(BaseModel):
"""
NodeExecutionInfo entity
"""
workflow_node_execution_id: str
node_type: NodeType
start_at: float
class TaskState(BaseModel):
"""
TaskState entity
@ -57,27 +28,6 @@ class WorkflowTaskState(TaskState):
"""
answer: str = ""
workflow_run_id: Optional[str] = None
start_at: Optional[float] = None
total_tokens: int = 0
total_steps: int = 0
ran_node_execution_infos: dict[str, NodeExecutionInfo] = {}
latest_node_execution_info: Optional[NodeExecutionInfo] = None
current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None
iteration_nested_node_ids: list[str] = None
class AdvancedChatTaskState(WorkflowTaskState):
"""
AdvancedChatTaskState entity
"""
usage: LLMUsage
current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None
class StreamEvent(Enum):
"""
@ -97,6 +47,8 @@ class StreamEvent(Enum):
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
ITERATION_STARTED = "iteration_started"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
@ -267,6 +219,11 @@ class NodeStartStreamResponse(StreamResponse):
inputs: Optional[dict] = None
created_at: int
extras: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str
@ -286,7 +243,12 @@ class NodeStartStreamResponse(StreamResponse):
"predecessor_node_id": self.data.predecessor_node_id,
"inputs": None,
"created_at": self.data.created_at,
"extras": {}
"extras": {},
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
}
}
@ -316,6 +278,11 @@ class NodeFinishStreamResponse(StreamResponse):
created_at: int
finished_at: int
files: Optional[list[dict]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
event: StreamEvent = StreamEvent.NODE_FINISHED
workflow_run_id: str
@ -342,9 +309,58 @@ class NodeFinishStreamResponse(StreamResponse):
"execution_metadata": None,
"created_at": self.data.created_at,
"finished_at": self.data.finished_at,
"files": []
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
}
}
class ParallelBranchStartStreamResponse(StreamResponse):
"""
ParallelBranchStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
parallel_id: str
parallel_branch_id: str
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
created_at: int
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
workflow_run_id: str
data: Data
class ParallelBranchFinishedStreamResponse(StreamResponse):
"""
ParallelBranchFinishedStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
parallel_id: str
parallel_branch_id: str
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
status: str
error: Optional[str] = None
created_at: int
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
workflow_run_id: str
data: Data
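Assuming these responses are serialized with jsonable_encoder like the other StreamResponse subclasses, a client reading the SSE stream would receive payloads shaped roughly like this Python dict (field values are illustrative, not taken from this diff):
{
    'event': 'parallel_branch_finished',
    'task_id': 'task-uuid',
    'workflow_run_id': 'run-uuid',
    'data': {
        'parallel_id': 'parallel-1',
        'parallel_branch_id': 'node-a',
        'parent_parallel_id': None,
        'parent_parallel_start_node_id': None,
        'iteration_id': None,
        'status': 'failed',
        'error': 'HTTP node timed out',
        'created_at': 1725950000,
    },
}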
class IterationNodeStartStreamResponse(StreamResponse):
@ -364,6 +380,8 @@ class IterationNodeStartStreamResponse(StreamResponse):
extras: dict = {}
metadata: dict = {}
inputs: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_STARTED
workflow_run_id: str
@ -387,6 +405,8 @@ class IterationNodeNextStreamResponse(StreamResponse):
created_at: int
pre_iteration_output: Optional[Any] = None
extras: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
@ -408,8 +428,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
title: str
outputs: Optional[dict] = None
created_at: int
extras: dict = None
inputs: dict = None
extras: Optional[dict] = None
inputs: Optional[dict] = None
status: WorkflowNodeExecutionStatus
error: Optional[str] = None
elapsed_time: float
@ -417,6 +437,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
execution_metadata: Optional[dict] = None
finished_at: int
steps: int
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
workflow_run_id: str
@ -488,7 +510,7 @@ class WorkflowAppStreamResponse(AppStreamResponse):
"""
WorkflowAppStreamResponse entity
"""
workflow_run_id: str
workflow_run_id: Optional[str] = None
class AppBlockingResponse(BaseModel):
@ -562,25 +584,3 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
workflow_run_id: str
data: Data
class WorkflowIterationState(BaseModel):
"""
WorkflowIterationState entity
"""
class Data(BaseModel):
"""
Data entity
"""
parent_iteration_id: Optional[str] = None
iteration_id: str
current_index: int
iteration_steps_boundary: list[int] = None
node_execution_id: str
started_at: float
inputs: dict = None
total_tokens: int = 0
node_data: BaseNodeData
current_iterations: dict[str, Data] = None

View File

@ -68,16 +68,18 @@ class BasedGenerateTaskPipeline:
err = Exception(e.description if getattr(e, 'description', None) is not None else str(e))
if message:
message = db.session.query(Message).filter(Message.id == message.id).first()
err_desc = self._error_to_desc(err)
message.status = 'error'
message.error = err_desc
refetch_message = db.session.query(Message).filter(Message.id == message.id).first()
db.session.commit()
if refetch_message:
err_desc = self._error_to_desc(err)
refetch_message.status = 'error'
refetch_message.error = err_desc
db.session.commit()
return err
def _error_to_desc(cls, e: Exception) -> str:
def _error_to_desc(self, e: Exception) -> str:
"""
Error to desc.
:param e: exception

View File

@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,
@ -16,11 +15,11 @@ from core.app.entities.queue_entities import (
QueueRetrieverResourcesEvent,
)
from core.app.entities.task_entities import (
AdvancedChatTaskState,
EasyUITaskState,
MessageFileStreamResponse,
MessageReplaceStreamResponse,
MessageStreamResponse,
WorkflowTaskState,
)
from core.llm_generator.llm_generator import LLMGenerator
from core.tools.tool_file_manager import ToolFileManager
@ -36,7 +35,7 @@ class MessageCycleManage:
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
]
_task_state: Union[EasyUITaskState, AdvancedChatTaskState]
_task_state: Union[EasyUITaskState, WorkflowTaskState]
def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]:
"""
@ -45,6 +44,9 @@ class MessageCycleManage:
:param query: query
:return: thread
"""
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
return None
is_first_message = self._application_generate_entity.conversation_id is None
extras = self._application_generate_entity.extras
auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True)
@ -52,7 +54,7 @@ class MessageCycleManage:
if auto_generate_conversation_name and is_first_message:
# start generate thread
thread = Thread(target=self._generate_conversation_name_worker, kwargs={
'flask_app': current_app._get_current_object(),
'flask_app': current_app._get_current_object(), # type: ignore
'conversation_id': conversation.id,
'query': query
})
@ -75,6 +77,9 @@ class MessageCycleManage:
.first()
)
if not conversation:
return
if conversation.mode != AppMode.COMPLETION.value:
app_model = conversation.app
if not app_model:
@ -121,34 +126,13 @@ class MessageCycleManage:
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata['retriever_resources'] = event.retriever_resources
def _get_response_metadata(self) -> dict:
"""
Get response metadata by invoke from.
:return:
"""
metadata = {}
# show_retrieve_source
if 'retriever_resources' in self._task_state.metadata:
metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
# show annotation reply
if 'annotation_reply' in self._task_state.metadata:
metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']
# show usage
if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
metadata['usage'] = self._task_state.metadata['usage']
return metadata
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
"""
Message file to stream response.
:param event: event
:return:
"""
message_file: MessageFile = (
message_file = (
db.session.query(MessageFile)
.filter(MessageFile.id == event.message_file_id)
.first()

View File

@ -1,33 +1,41 @@
import json
import time
from datetime import datetime, timezone
from typing import Optional, Union, cast
from typing import Any, Optional, Union, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
)
from core.app.entities.task_entities import (
NodeExecutionInfo,
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage
from core.file.file_obj import FileVar
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
@ -41,54 +49,56 @@ from models.workflow import (
WorkflowRunStatus,
WorkflowRunTriggeredFrom,
)
from services.workflow_service import WorkflowService
class WorkflowCycleManage(WorkflowIterationCycleManage):
def _init_workflow_run(self, workflow: Workflow,
triggered_from: WorkflowRunTriggeredFrom,
user: Union[Account, EndUser],
user_inputs: dict,
system_inputs: Optional[dict] = None) -> WorkflowRun:
"""
Init workflow run
:param workflow: Workflow instance
:param triggered_from: triggered from
:param user: account or end user
:param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files
:return:
"""
max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \
.filter(WorkflowRun.tenant_id == workflow.tenant_id) \
.filter(WorkflowRun.app_id == workflow.app_id) \
.scalar() or 0
class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_workflow_system_variables: dict[SystemVariableKey, Any]
def _handle_workflow_run_start(self) -> WorkflowRun:
max_sequence = (
db.session.query(db.func.max(WorkflowRun.sequence_number))
.filter(WorkflowRun.tenant_id == self._workflow.tenant_id)
.filter(WorkflowRun.app_id == self._workflow.app_id)
.scalar()
or 0
)
new_sequence_number = max_sequence + 1
inputs = {**user_inputs}
for key, value in (system_inputs or {}).items():
inputs = {**self._application_generate_entity.inputs}
for key, value in (self._workflow_system_variables or {}).items():
if key.value == 'conversation':
continue
inputs[f'sys.{key.value}'] = value
inputs = WorkflowEngineManager.handle_special_values(inputs)
inputs = WorkflowEntry.handle_special_values(inputs)
triggered_from = (
WorkflowRunTriggeredFrom.DEBUGGING
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
else WorkflowRunTriggeredFrom.APP_RUN
)
# init workflow run
workflow_run = WorkflowRun(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
sequence_number=new_sequence_number,
workflow_id=workflow.id,
type=workflow.type,
triggered_from=triggered_from.value,
version=workflow.version,
graph=workflow.graph,
inputs=json.dumps(inputs),
status=WorkflowRunStatus.RUNNING.value,
created_by_role=(CreatedByRole.ACCOUNT.value
if isinstance(user, Account) else CreatedByRole.END_USER.value),
created_by=user.id
workflow_run = WorkflowRun()
workflow_run.tenant_id = self._workflow.tenant_id
workflow_run.app_id = self._workflow.app_id
workflow_run.sequence_number = new_sequence_number
workflow_run.workflow_id = self._workflow.id
workflow_run.type = self._workflow.type
workflow_run.triggered_from = triggered_from.value
workflow_run.version = self._workflow.version
workflow_run.graph = self._workflow.graph
workflow_run.inputs = json.dumps(inputs)
workflow_run.status = WorkflowRunStatus.RUNNING.value
workflow_run.created_by_role = (
CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value
)
workflow_run.created_by = self._user.id
db.session.add(workflow_run)
db.session.commit()
@ -97,33 +107,37 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return workflow_run
def _workflow_run_success(
self, workflow_run: WorkflowRun,
def _handle_workflow_run_success(
self,
workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Optional[str] = None,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
"""
Workflow run success
:param workflow_run: workflow run
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param outputs: outputs
:param conversation_id: conversation id
:return:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
workflow_run.outputs = outputs
workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id)
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
db.session.refresh(workflow_run)
db.session.close()
if trace_manager:
trace_manager.add_trace_task(
@ -135,34 +149,58 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
)
)
db.session.close()
return workflow_run
def _workflow_run_failed(
self, workflow_run: WorkflowRun,
def _handle_workflow_run_failed(
self,
workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
total_steps: int,
status: WorkflowRunStatus,
error: str,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
"""
Workflow run failed
:param workflow_run: workflow run
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param status: status
:param error: error message
:return:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
workflow_run.status = status.value
workflow_run.error = error
workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id)
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value
).all()
for workflow_node_execution in running_workflow_node_executions:
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds()
db.session.commit()
db.session.refresh(workflow_run)
db.session.close()
@ -178,39 +216,24 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return workflow_run
def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun,
node_id: str,
node_type: NodeType,
node_title: str,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution:
"""
Init workflow node execution from workflow run
:param workflow_run: workflow run
:param node_id: node id
:param node_type: node type
:param node_title: node title
:param node_run_index: run index
:param predecessor_node_id: predecessor node id if exists
:return:
"""
def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
# init workflow node execution
workflow_node_execution = WorkflowNodeExecution(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
workflow_run_id=workflow_run.id,
predecessor_node_id=predecessor_node_id,
index=node_run_index,
node_id=node_id,
node_type=node_type.value,
title=node_title,
status=WorkflowNodeExecutionStatus.RUNNING.value,
created_by_role=workflow_run.created_by_role,
created_by=workflow_run.created_by,
created_at=datetime.now(timezone.utc).replace(tzinfo=None)
)
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
workflow_node_execution.index = event.node_run_index
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.title = event.node_data.title
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.add(workflow_node_execution)
db.session.commit()
@ -219,33 +242,26 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return workflow_node_execution
def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution,
start_at: float,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution:
def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
"""
Workflow node execution success
:param workflow_node_execution: workflow node execution
:param start_at: start time
:param inputs: inputs
:param process_data: process data
:param outputs: outputs
:param execution_metadata: execution metadata
:param event: queue node succeeded event
:return:
"""
inputs = WorkflowEngineManager.handle_special_values(inputs)
outputs = WorkflowEngineManager.handle_special_values(outputs)
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
if execution_metadata else None
workflow_node_execution.execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
db.session.commit()
db.session.refresh(workflow_node_execution)
@ -253,33 +269,24 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return workflow_node_execution
def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution,
start_at: float,
error: str,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None
) -> WorkflowNodeExecution:
def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param workflow_node_execution: workflow node execution
:param start_at: start time
:param error: error message
:param event: queue node failed event
:return:
"""
inputs = WorkflowEngineManager.handle_special_values(inputs)
outputs = WorkflowEngineManager.handle_special_values(outputs)
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.error = event.error
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
if execution_metadata else None
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
db.session.commit()
db.session.refresh(workflow_node_execution)
@ -287,8 +294,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return workflow_node_execution
def _workflow_start_to_stream_response(self, task_id: str,
workflow_run: WorkflowRun) -> WorkflowStartStreamResponse:
#################################################
# to stream responses #
#################################################
def _workflow_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun
) -> WorkflowStartStreamResponse:
"""
Workflow start to stream response.
:param task_id: task id
@ -302,13 +314,14 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
inputs=workflow_run.inputs_dict,
created_at=int(workflow_run.created_at.timestamp())
)
inputs=workflow_run.inputs_dict or {},
created_at=int(workflow_run.created_at.timestamp()),
),
)
def _workflow_finish_to_stream_response(self, task_id: str,
workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse:
def _workflow_finish_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun
) -> WorkflowFinishStreamResponse:
"""
Workflow finish to stream response.
:param task_id: task id
@ -320,16 +333,16 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
created_by_account = workflow_run.created_by_account
if created_by_account:
created_by = {
"id": created_by_account.id,
"name": created_by_account.name,
"email": created_by_account.email,
'id': created_by_account.id,
'name': created_by_account.name,
'email': created_by_account.email,
}
else:
created_by_end_user = workflow_run.created_by_end_user
if created_by_end_user:
created_by = {
"id": created_by_end_user.id,
"user": created_by_end_user.session_id,
'id': created_by_end_user.id,
'user': created_by_end_user.session_id,
}
return WorkflowFinishStreamResponse(
@ -348,14 +361,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict)
)
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}),
),
)
def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution) \
-> NodeStartStreamResponse:
def _workflow_node_start_to_stream_response(
self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution
) -> Optional[NodeStartStreamResponse]:
"""
Workflow node start to stream response.
:param event: queue node started event
@ -363,6 +375,9 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
return None
response = NodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
@ -374,8 +389,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
index=workflow_node_execution.index,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs_dict,
created_at=int(workflow_node_execution.created_at.timestamp())
)
created_at=int(workflow_node_execution.created_at.timestamp()),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
),
)
# extras logic
@ -384,19 +404,27 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
response.data.extras['icon'] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=node_data.provider_type,
provider_id=node_data.provider_id
provider_id=node_data.provider_id,
)
return response
def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \
-> NodeFinishStreamResponse:
def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution
) -> Optional[NodeFinishStreamResponse]:
"""
Workflow node finish to stream response.
:param event: queue node succeeded or failed event
:param task_id: task id
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
return None
return NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
@ -416,181 +444,155 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
execution_metadata=workflow_node_execution.execution_metadata_dict,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict)
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
),
)
def _workflow_parallel_branch_start_to_stream_response(
self,
task_id: str,
workflow_run: WorkflowRun,
event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:
"""
Workflow parallel branch start to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: parallel branch run started event
:return:
"""
return ParallelBranchStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=ParallelBranchStartStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
created_at=int(time.time()),
)
)
def _workflow_parallel_branch_finished_to_stream_response(
self,
task_id: str,
workflow_run: WorkflowRun,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent
) -> ParallelBranchFinishedStreamResponse:
"""
Workflow parallel branch finished to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: parallel branch run succeeded or failed event
:return:
"""
return ParallelBranchFinishedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=ParallelBranchFinishedStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed',
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
created_at=int(time.time()),
)
)
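A hedged sketch of where these two builders might be called from: the task pipeline's event loop (not shown in this diff) would branch on the queue event type and yield the matching stream response; the self._application_generate_entity.task_id attribute is an assumption:
if isinstance(event, QueueParallelBranchRunStartedEvent):
    yield self._workflow_parallel_branch_start_to_stream_response(
        task_id=self._application_generate_entity.task_id,
        workflow_run=workflow_run,
        event=event,
    )
elif isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
    yield self._workflow_parallel_branch_finished_to_stream_response(
        task_id=self._application_generate_entity.task_id,
        workflow_run=workflow_run,
        event=event,
    )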
def _handle_workflow_start(self) -> WorkflowRun:
self._task_state.start_at = time.perf_counter()
workflow_run = self._init_workflow_run(
workflow=self._workflow,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
else WorkflowRunTriggeredFrom.APP_RUN,
user=self._user,
user_inputs=self._application_generate_entity.inputs,
system_inputs=self._workflow_system_variables
def _workflow_iteration_start_to_stream_response(
self,
task_id: str,
workflow_run: WorkflowRun,
event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse:
"""
Workflow iteration start to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: iteration start event
:return:
"""
return IterationNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
)
self._task_state.workflow_run_id = workflow_run.id
db.session.close()
return workflow_run
def _handle_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
workflow_node_execution = self._init_node_execution_from_workflow_run(
workflow_run=workflow_run,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_data.title,
node_run_index=event.node_run_index,
predecessor_node_id=event.predecessor_node_id
def _workflow_iteration_next_to_stream_response(
    self,
    task_id: str,
    workflow_run: WorkflowRun,
    event: QueueIterationNextEvent
) -> IterationNodeNextStreamResponse:
"""
Workflow iteration next to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: iteration next event
:return:
"""
return IterationNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
)
latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=event.node_type,
start_at=time.perf_counter()
)
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.latest_node_execution_info = latest_node_execution_info
self._task_state.total_steps += 1
db.session.close()
return workflow_node_execution
def _handle_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None
if self._iteration_state and self._iteration_state.current_iterations:
if not execution_metadata:
execution_metadata = {}
current_iteration_data = None
for iteration_node_id in self._iteration_state.current_iterations:
data = self._iteration_state.current_iterations[iteration_node_id]
if data.parent_iteration_id is None:
current_iteration_data = data
break
if current_iteration_data:
execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id
execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index
if isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._workflow_node_execution_success(
workflow_node_execution=workflow_node_execution,
start_at=current_node_execution.start_at,
inputs=event.inputs,
process_data=event.process_data,
def _workflow_iteration_completed_to_stream_response(
    self,
    task_id: str,
    workflow_run: WorkflowRun,
    event: QueueIterationCompletedEvent
) -> IterationNodeCompletedStreamResponse:
"""
Workflow iteration completed to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: iteration completed event
:return:
"""
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
outputs=event.outputs,
execution_metadata=execution_metadata
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
status=WorkflowNodeExecutionStatus.SUCCEEDED,
error=None,
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0,
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
self._task_state.total_tokens += (
int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
if self._iteration_state:
for iteration_node_id in self._iteration_state.current_iterations:
data = self._iteration_state.current_iterations[iteration_node_id]
if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))
if workflow_node_execution.node_type == NodeType.LLM.value:
outputs = workflow_node_execution.outputs_dict
usage_dict = outputs.get('usage', {})
self._task_state.metadata['usage'] = usage_dict
else:
workflow_node_execution = self._workflow_node_execution_failed(
workflow_node_execution=workflow_node_execution,
start_at=current_node_execution.start_at,
error=event.error,
inputs=event.inputs,
process_data=event.process_data,
outputs=event.outputs,
execution_metadata=execution_metadata
)
db.session.close()
return workflow_node_execution
def _handle_workflow_finished(
self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None
) -> Optional[WorkflowRun]:
workflow_run = db.session.query(WorkflowRun).filter(
WorkflowRun.id == self._task_state.workflow_run_id).first()
if not workflow_run:
return None
if conversation_id is None:
conversation_id = self._application_generate_entity.inputs.get('sys.conversation_id')
if isinstance(event, QueueStopEvent):
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
status=WorkflowRunStatus.STOPPED,
error='Workflow stopped.',
conversation_id=conversation_id,
trace_manager=trace_manager
)
latest_node_execution_info = self._task_state.latest_node_execution_info
if latest_node_execution_info:
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == latest_node_execution_info.workflow_node_execution_id).first()
if (workflow_node_execution
and workflow_node_execution.status == WorkflowNodeExecutionStatus.RUNNING.value):
self._workflow_node_execution_failed(
workflow_node_execution=workflow_node_execution,
start_at=latest_node_execution_info.start_at,
error='Workflow stopped.'
)
elif isinstance(event, QueueWorkflowFailedEvent):
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
status=WorkflowRunStatus.FAILED,
error=event.error,
conversation_id=conversation_id,
trace_manager=trace_manager
)
else:
if self._task_state.latest_node_execution_info:
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first()
outputs = workflow_node_execution.outputs
else:
outputs = None
workflow_run = self._workflow_run_success(
workflow_run=workflow_run,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
outputs=outputs,
conversation_id=conversation_id,
trace_manager=trace_manager
)
self._task_state.workflow_run_id = workflow_run.id
db.session.close()
return workflow_run
)
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
"""
@ -647,3 +649,40 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return value.to_dict()
return None
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run
:param workflow_run_id: workflow run id
:return:
"""
workflow_run = db.session.query(WorkflowRun).filter(
WorkflowRun.id == workflow_run_id).first()
if not workflow_run:
raise Exception(f'Workflow run not found: {workflow_run_id}')
return workflow_run
def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
"""
Refetch workflow node execution
:param node_execution_id: workflow node execution id
:return:
"""
workflow_node_execution = (
db.session.query(WorkflowNodeExecution)
.filter(
WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id,
WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id,
WorkflowNodeExecution.workflow_id == self._workflow.id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.node_execution_id == node_execution_id,
)
.first()
)
if not workflow_node_execution:
raise Exception(f'Workflow node execution not found: {node_execution_id}')
return workflow_node_execution

View File

@ -1,16 +0,0 @@
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
from core.workflow.enums import SystemVariableKey
from models.account import Account
from models.model import EndUser
from models.workflow import Workflow
class WorkflowCycleStateManager:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
_workflow_system_variables: dict[SystemVariableKey, Any]

View File

@ -1,290 +0,0 @@
import json
import time
from collections.abc import Generator
from datetime import datetime, timezone
from typing import Optional, Union
from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
)
from core.app.entities.task_entities import (
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
NodeExecutionInfo,
WorkflowIterationState,
)
from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.workflow import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
)
class WorkflowIterationCycleManage(WorkflowCycleStateManager):
_iteration_state: WorkflowIterationState = None
def _init_iteration_state(self) -> WorkflowIterationState:
if not self._iteration_state:
self._iteration_state = WorkflowIterationState(
current_iterations={}
)
def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \
-> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]:
"""
Handle iteration to stream response
:param task_id: task id
:param event: iteration event
:return:
"""
if isinstance(event, QueueIterationStartEvent):
return IterationNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs,
metadata=event.metadata
)
)
elif isinstance(event, QueueIterationNextEvent):
current_iteration = self._iteration_state.current_iterations[event.node_id]
return IterationNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=current_iteration.node_data.title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={}
)
)
elif isinstance(event, QueueIterationCompletedEvent):
current_iteration = self._iteration_state.current_iterations[event.node_id]
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=current_iteration.node_data.title,
outputs=event.outputs,
created_at=int(time.time()),
extras={},
inputs=current_iteration.inputs,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
error=None,
elapsed_time=time.perf_counter() - current_iteration.started_at,
total_tokens=current_iteration.total_tokens,
execution_metadata={
'total_tokens': current_iteration.total_tokens,
},
finished_at=int(time.time()),
steps=current_iteration.current_index
)
)
def _init_iteration_execution_from_workflow_run(self,
workflow_run: WorkflowRun,
node_id: str,
node_type: NodeType,
node_title: str,
node_run_index: int = 1,
inputs: Optional[dict] = None,
predecessor_node_id: Optional[str] = None
) -> WorkflowNodeExecution:
workflow_node_execution = WorkflowNodeExecution(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
workflow_run_id=workflow_run.id,
predecessor_node_id=predecessor_node_id,
index=node_run_index,
node_id=node_id,
node_type=node_type.value,
inputs=json.dumps(inputs) if inputs else None,
title=node_title,
status=WorkflowNodeExecutionStatus.RUNNING.value,
created_by_role=workflow_run.created_by_role,
created_by=workflow_run.created_by,
execution_metadata=json.dumps({
'started_run_index': node_run_index + 1,
'current_index': 0,
'steps_boundary': [],
}),
created_at=datetime.now(timezone.utc).replace(tzinfo=None)
)
db.session.add(workflow_node_execution)
db.session.commit()
db.session.refresh(workflow_node_execution)
db.session.close()
return workflow_node_execution
def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> Optional[WorkflowNodeExecution]:
if isinstance(event, QueueIterationStartEvent):
return self._handle_iteration_started(event)
elif isinstance(event, QueueIterationNextEvent):
return self._handle_iteration_next(event)
elif isinstance(event, QueueIterationCompletedEvent):
return self._handle_iteration_completed(event)
def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution:
self._init_iteration_state()
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
workflow_node_execution = self._init_iteration_execution_from_workflow_run(
workflow_run=workflow_run,
node_id=event.node_id,
node_type=NodeType.ITERATION,
node_title=event.node_data.title,
node_run_index=event.node_run_index,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id
)
latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=NodeType.ITERATION,
start_at=time.perf_counter()
)
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.latest_node_execution_info = latest_node_execution_info
self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data(
parent_iteration_id=None,
iteration_id=event.node_id,
current_index=0,
iteration_steps_boundary=[],
node_execution_id=workflow_node_execution.id,
started_at=time.perf_counter(),
inputs=event.inputs,
total_tokens=0,
node_data=event.node_data
)
db.session.close()
return workflow_node_execution
def _handle_iteration_next(self, event: QueueIterationNextEvent) -> Optional[WorkflowNodeExecution]:
if event.node_id not in self._iteration_state.current_iterations:
return
current_iteration = self._iteration_state.current_iterations[event.node_id]
current_iteration.current_index = event.index
current_iteration.iteration_steps_boundary.append(event.node_run_index)
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
if original_node_execution_metadata:
original_node_execution_metadata['current_index'] = event.index
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
db.session.commit()
db.session.close()
return workflow_node_execution
def _handle_iteration_completed(self, event: QueueIterationCompletedEvent) -> Optional[WorkflowNodeExecution]:
if event.node_id not in self._iteration_state.current_iterations:
return
current_iteration = self._iteration_state.current_iterations[event.node_id]
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.outputs = json.dumps(WorkflowEngineManager.handle_special_values(event.outputs)) if event.outputs else None
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
if original_node_execution_metadata:
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
db.session.commit()
# remove current iteration
self._iteration_state.current_iterations.pop(event.node_id, None)
# set latest node execution info
latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=NodeType.ITERATION,
start_at=time.perf_counter()
)
self._task_state.latest_node_execution_info = latest_node_execution_info
db.session.close()
return workflow_node_execution
def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]:
"""
Handle iteration exception
"""
if not self._iteration_state or not self._iteration_state.current_iterations:
return
for node_id, current_iteration in self._iteration_state.current_iterations.items():
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
db.session.commit()
db.session.close()
yield IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeCompletedStreamResponse.Data(
id=node_id,
node_id=node_id,
node_type=NodeType.ITERATION.value,
title=current_iteration.node_data.title,
outputs={},
created_at=int(time.time()),
extras={},
inputs=current_iteration.inputs,
status=WorkflowNodeExecutionStatus.FAILED,
error=error,
elapsed_time=time.perf_counter() - current_iteration.started_at,
total_tokens=current_iteration.total_tokens,
execution_metadata={
'total_tokens': current_iteration.total_tokens,
},
finished_at=int(time.time()),
steps=current_iteration.current_index
)
)
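For orientation, a minimal sketch (hypothetical indices) of how the execution_metadata JSON stored on an iteration's WorkflowNodeExecution evolves as the handlers above process the queue events:

# after QueueIterationStartEvent with node_run_index=3:
#   {"started_run_index": 4, "current_index": 0, "steps_boundary": []}
# after QueueIterationNextEvent(index=1, node_run_index=5), then (index=2, node_run_index=7):
#   {"started_run_index": 4, "current_index": 2, "steps_boundary": [5, 7], "total_tokens": <accumulated>}
# after QueueIterationCompletedEvent, the final steps_boundary and total_tokens are written
# and the execution status is set to SUCCEEDED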

View File

@ -63,6 +63,39 @@ class LLMUsage(ModelUsage):
latency=0.0
)
def plus(self, other: 'LLMUsage') -> 'LLMUsage':
"""
Add two LLMUsage instances together.
:param other: Another LLMUsage instance to add
:return: A new LLMUsage instance with summed values
"""
if self.total_tokens == 0:
return other
else:
return LLMUsage(
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
prompt_unit_price=other.prompt_unit_price,
prompt_price_unit=other.prompt_price_unit,
prompt_price=self.prompt_price + other.prompt_price,
completion_tokens=self.completion_tokens + other.completion_tokens,
completion_unit_price=other.completion_unit_price,
completion_price_unit=other.completion_price_unit,
completion_price=self.completion_price + other.completion_price,
total_tokens=self.total_tokens + other.total_tokens,
total_price=self.total_price + other.total_price,
currency=other.currency,
latency=self.latency + other.latency
)
def __add__(self, other: 'LLMUsage') -> 'LLMUsage':
"""
Overload the + operator to add two LLMUsage instances.
:param other: Another LLMUsage instance to add
:return: A new LLMUsage instance with summed values
"""
return self.plus(other)
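Since plus treats an empty instance (total_tokens == 0) as the identity, usage can be accumulated with plain + starting from empty_usage(); a minimal sketch, where the per-call usages are hypothetical LLMUsage values:

total = LLMUsage.empty_usage()
for call_usage in (first_call_usage, second_call_usage):  # hypothetical per-call usages
    total = total + call_usage  # equivalent to total.plus(call_usage)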
class LLMResult(BaseModel):
"""

View File

@ -34,13 +34,13 @@ class OutputModeration(BaseModel):
final_output: Optional[str] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
def should_direct_output(self):
def should_direct_output(self) -> bool:
return self.final_output is not None
def get_final_output(self):
return self.final_output
def get_final_output(self) -> str:
return self.final_output or ""
def append_new_token(self, token: str):
def append_new_token(self, token: str) -> None:
self.buffer += token
if not self.thread:

View File

@ -1,7 +1,7 @@
import json
import logging
from copy import deepcopy
from typing import Any, Union
from typing import Any, Optional, Union
from core.file.file_obj import FileTransferMethod, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
@ -18,6 +18,7 @@ class WorkflowTool(Tool):
version: str
workflow_entities: dict[str, Any]
workflow_call_depth: int
thread_pool_id: Optional[str] = None
label: str
@ -57,6 +58,7 @@ class WorkflowTool(Tool):
invoke_from=self.runtime.invoke_from,
stream=False,
call_depth=self.workflow_call_depth + 1,
workflow_thread_pool_id=self.thread_pool_id
)
data = result.get('data', {})
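The net effect of threading thread_pool_id through WorkflowTool is that a workflow invoked as a tool forwards the caller's pool id (via workflow_thread_pool_id), so nested runs share one bounded GraphEngineThreadPool rather than spawning a new pool per call.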

View File

@ -128,6 +128,7 @@ class ToolEngine:
user_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler,
workflow_call_depth: int,
thread_pool_id: Optional[str] = None
) -> list[ToolInvokeMessage]:
"""
Workflow invokes the tool with the given arguments.
@ -141,6 +142,7 @@ class ToolEngine:
if isinstance(tool, WorkflowTool):
tool.workflow_call_depth = workflow_call_depth + 1
tool.thread_pool_id = thread_pool_id
if tool.runtime and tool.runtime.runtime_parameters:
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}

View File

@ -25,7 +25,6 @@ from core.tools.tool.tool import Tool
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from core.workflow.nodes.tool.entities import ToolEntity
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
@ -249,7 +248,7 @@ class ToolManager:
return tool_entity
@classmethod
def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:
def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: "ToolEntity", invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool:
"""
get the workflow tool runtime
"""

View File

@ -7,6 +7,7 @@ from core.tools.tool_file_manager import ToolFileManager
logger = logging.getLogger(__name__)
class ToolFileMessageTransformer:
@classmethod
def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage],

View File

@ -1,116 +1,15 @@
from abc import ABC, abstractmethod
from typing import Any, Optional
from core.app.entities.queue_entities import AppQueueEvent
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.event import GraphEngineEvent
class WorkflowCallback(ABC):
@abstractmethod
def on_workflow_run_started(self) -> None:
def on_event(
self,
event: GraphEngineEvent
) -> None:
"""
Workflow run started
"""
raise NotImplementedError
@abstractmethod
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
raise NotImplementedError
@abstractmethod
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
raise NotImplementedError
@abstractmethod
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
"""
Workflow node execute started
"""
raise NotImplementedError
@abstractmethod
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
"""
Workflow node execute succeeded
"""
raise NotImplementedError
@abstractmethod
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
process_data: Optional[dict] = None) -> None:
"""
Workflow node execute failed
"""
raise NotImplementedError
@abstractmethod
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""
raise NotImplementedError
@abstractmethod
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: Optional[dict] = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
raise NotImplementedError
@abstractmethod
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[Any],
) -> None:
"""
Publish iteration next
"""
raise NotImplementedError
@abstractmethod
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
raise NotImplementedError
@abstractmethod
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
Handle a published event
"""
raise NotImplementedError

View File

@ -9,7 +9,7 @@ class BaseNodeData(ABC, BaseModel):
desc: Optional[str] = None
class BaseIterationNodeData(BaseNodeData):
start_node_id: str
start_node_id: Optional[str] = None
class BaseIterationState(BaseModel):
iteration_node_id: str

View File

@ -1,9 +1,9 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMUsage
from models import WorkflowNodeExecutionStatus
@ -28,6 +28,7 @@ class NodeType(Enum):
VARIABLE_ASSIGNER = 'variable-assigner'
LOOP = 'loop'
ITERATION = 'iteration'
ITERATION_START = 'iteration-start' # fake start node for iteration
PARAMETER_EXTRACTOR = 'parameter-extractor'
CONVERSATION_VARIABLE_ASSIGNER = 'assigner'
@ -56,6 +57,10 @@ class NodeRunMetadataKey(Enum):
TOOL_INFO = 'tool_info'
ITERATION_ID = 'iteration_id'
ITERATION_INDEX = 'iteration_index'
PARALLEL_ID = 'parallel_id'
PARALLEL_START_NODE_ID = 'parallel_start_node_id'
PARENT_PARALLEL_ID = 'parent_parallel_id'
PARENT_PARALLEL_START_NODE_ID = 'parent_parallel_start_node_id'
class NodeRunResult(BaseModel):
@ -65,11 +70,32 @@ class NodeRunResult(BaseModel):
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING
inputs: Optional[Mapping[str, Any]] = None # node inputs
process_data: Optional[dict] = None # process data
outputs: Optional[Mapping[str, Any]] = None # node outputs
inputs: Optional[dict[str, Any]] = None # node inputs
process_data: Optional[dict[str, Any]] = None # process data
outputs: Optional[dict[str, Any]] = None # node outputs
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
llm_usage: Optional[LLMUsage] = None # llm usage
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
error: Optional[str] = None # error message if status is failed
class UserFrom(Enum):
"""
User from
"""
ACCOUNT = "account"
END_USER = "end-user"
@classmethod
def value_of(cls, value: str) -> "UserFrom":
"""
Value of
:param value: value
:return:
"""
for item in cls:
if item.value == value:
return item
raise ValueError(f"Invalid value: {value}")

View File

@ -2,6 +2,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Union
from pydantic import BaseModel, Field, model_validator
from typing_extensions import deprecated
from core.app.segments import Segment, Variable, factory
@ -16,43 +17,52 @@ ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
class VariablePool:
def __init__(
self,
system_variables: Mapping[SystemVariableKey, Any],
user_inputs: Mapping[str, Any],
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable] | None = None,
) -> None:
# system variables
# for example:
# {
# 'query': 'abc',
# 'files': []
# }
class VariablePool(BaseModel):
# Variable dictionary is a dictionary for looking up variables by their selector.
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
variable_dictionary: dict[str, dict[int, Segment]] = Field(
description='Variables mapping',
default=defaultdict(dict)
)
# Variable dictionary is a dictionary for looking up variables by their selector.
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict)
# TODO: These user inputs are not used by the pool.
user_inputs: Mapping[str, Any] = Field(
description='User inputs',
)
# TODO: These user inputs are not used by the pool.
self.user_inputs = user_inputs
system_variables: Mapping[SystemVariableKey, Any] = Field(
description='System variables',
)
environment_variables: Sequence[Variable] = Field(
description="Environment variables.",
default_factory=list
)
conversation_variables: Sequence[Variable] | None = None
@model_validator(mode="after")
def val_model_after(self):
"""
Append system variables
:return:
"""
# Add system variables to the variable pool
self.system_variables = system_variables
for key, value in system_variables.items():
for key, value in self.system_variables.items():
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
# Add environment variables to the variable pool
for var in environment_variables:
for var in self.environment_variables or []:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
# Add conversation variables to the variable pool
for var in conversation_variables or []:
for var in self.conversation_variables or []:
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
return self
def add(self, selector: Sequence[str], value: Any, /) -> None:
"""
Adds a variable to the variable pool.
@ -79,7 +89,7 @@ class VariablePool:
v = factory.build_segment(value)
hash_key = hash(tuple(selector[1:]))
self._variable_dictionary[selector[0]][hash_key] = v
self.variable_dictionary[selector[0]][hash_key] = v
def get(self, selector: Sequence[str], /) -> Segment | None:
"""
@ -97,7 +107,7 @@ class VariablePool:
if len(selector) < 2:
raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key)
value = self.variable_dictionary[selector[0]].get(hash_key)
return value
@ -118,7 +128,7 @@ class VariablePool:
if len(selector) < 2:
raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key)
value = self.variable_dictionary[selector[0]].get(hash_key)
return value.to_object() if value else None
def remove(self, selector: Sequence[str], /):
@ -134,7 +144,19 @@ class VariablePool:
if not selector:
return
if len(selector) == 1:
self._variable_dictionary[selector[0]] = {}
self.variable_dictionary[selector[0]] = {}
return
hash_key = hash(tuple(selector[1:]))
self._variable_dictionary[selector[0]].pop(hash_key, None)
self.variable_dictionary[selector[0]].pop(hash_key, None)
def remove_node(self, node_id: str, /):
"""
Remove all variables associated with a given node id.
Args:
node_id (str): The node id to remove.
Returns:
None
"""
self.variable_dictionary.pop(node_id, None)
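As a usage sketch (node id and values are hypothetical, and SystemVariableKey.QUERY is assumed to exist): a selector's first element is the node id, and the remaining elements are hashed into the second-level key, as described above.

pool = VariablePool(
    system_variables={SystemVariableKey.QUERY: 'hello'},
    user_inputs={},
)
pool.add(('node_a', 'output', 'text'), 'some value')
segment = pool.get(('node_a', 'output', 'text'))  # Segment wrapping 'some value'
pool.remove_node('node_a')  # drops every variable stored under 'node_a'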

View File

@ -66,8 +66,7 @@ class WorkflowRunState:
self.variable_pool = variable_pool
self.total_tokens = 0
self.workflow_nodes_and_results = []
self.current_iteration_state = None
self.workflow_node_steps = 1
self.workflow_node_runs = []
self.workflow_node_runs = []
self.current_iteration_state = None

View File

@ -1,10 +1,8 @@
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.base_node import BaseNode
class WorkflowNodeRunFailedError(Exception):
def __init__(self, node_id: str, node_type: NodeType, node_title: str, error: str):
self.node_id = node_id
self.node_type = node_type
self.node_title = node_title
def __init__(self, node_instance: BaseNode, error: str):
self.node_instance = node_instance
self.error = error
super().__init__(f"Node {node_title} run failed: {error}")
super().__init__(f"Node {node_instance.node_data.title} run failed: {error}")

View File

@ -0,0 +1,31 @@
from abc import ABC, abstractmethod
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class RunConditionHandler(ABC):
def __init__(self,
init_params: GraphInitParams,
graph: Graph,
condition: RunCondition):
self.init_params = init_params
self.graph = graph
self.condition = condition
@abstractmethod
def check(self,
graph_runtime_state: GraphRuntimeState,
previous_route_node_state: RouteNodeState
) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param previous_route_node_state: previous route node state
:return: bool
"""
raise NotImplementedError

View File

@ -0,0 +1,28 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class BranchIdentifyRunConditionHandler(RunConditionHandler):
def check(self,
graph_runtime_state: GraphRuntimeState,
previous_route_node_state: RouteNodeState) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param previous_route_node_state: previous route node state
:return: bool
"""
if not self.condition.branch_identify:
raise Exception("Branch identify is required")
run_result = previous_route_node_state.node_run_result
if not run_result:
return False
if not run_result.edge_source_handle:
return False
return self.condition.branch_identify == run_result.edge_source_handle

View File

@ -0,0 +1,32 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.utils.condition.processor import ConditionProcessor
class ConditionRunConditionHandlerHandler(RunConditionHandler):
def check(self,
graph_runtime_state: GraphRuntimeState,
previous_route_node_state: RouteNodeState
) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param previous_route_node_state: previous route node state
:return: bool
"""
if not self.condition.conditions:
return True
# process condition
condition_processor = ConditionProcessor()
input_conditions, group_result = condition_processor.process_conditions(
variable_pool=graph_runtime_state.variable_pool,
conditions=self.condition.conditions
)
# Apply the logical operator for the current case
compare_result = all(group_result)
return compare_result

View File

@ -0,0 +1,35 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.run_condition import RunCondition
class ConditionManager:
@staticmethod
def get_condition_handler(
init_params: GraphInitParams,
graph: Graph,
run_condition: RunCondition
) -> RunConditionHandler:
"""
Get condition handler
:param init_params: init params
:param graph: graph
:param run_condition: run condition
:return: condition handler
"""
if run_condition.type == "branch_identify":
return BranchIdentifyRunConditionHandler(
init_params=init_params,
graph=graph,
condition=run_condition
)
else:
return ConditionRunConditionHandlerHandler(
init_params=init_params,
graph=graph,
condition=run_condition
)
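A short sketch of how a handler is obtained and consulted when deciding whether an edge may be followed (init_params, graph, and the runtime states are assumed to exist from an engine run):

handler = ConditionManager.get_condition_handler(
    init_params=init_params,
    graph=graph,
    run_condition=RunCondition(type='branch_identify', branch_identify='true'),
)
if handler.check(graph_runtime_state=graph_runtime_state, previous_route_node_state=previous_state):
    ...  # follow the edge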

View File

@ -0,0 +1,163 @@
from datetime import datetime
from typing import Any, Optional
from pydantic import BaseModel, Field
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class GraphEngineEvent(BaseModel):
pass
###########################################
# Graph Events
###########################################
class BaseGraphEvent(GraphEngineEvent):
pass
class GraphRunStartedEvent(BaseGraphEvent):
pass
class GraphRunSucceededEvent(BaseGraphEvent):
outputs: Optional[dict[str, Any]] = None
"""outputs"""
class GraphRunFailedEvent(BaseGraphEvent):
error: str = Field(..., description="failed reason")
###########################################
# Node Events
###########################################
class BaseNodeEvent(GraphEngineEvent):
id: str = Field(..., description="node execution id")
node_id: str = Field(..., description="node id")
node_type: NodeType = Field(..., description="node type")
node_data: BaseNodeData = Field(..., description="node data")
route_node_state: RouteNodeState = Field(..., description="route node state")
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class NodeRunStartedEvent(BaseNodeEvent):
predecessor_node_id: Optional[str] = None
"""predecessor node id"""
class NodeRunStreamChunkEvent(BaseNodeEvent):
chunk_content: str = Field(..., description="chunk content")
from_variable_selector: Optional[list[str]] = None
"""from variable selector"""
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
retriever_resources: list[dict] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
class NodeRunSucceededEvent(BaseNodeEvent):
pass
class NodeRunFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
###########################################
# Parallel Branch Events
###########################################
class BaseParallelBranchEvent(GraphEngineEvent):
parallel_id: str = Field(..., description="parallel id")
"""parallel id"""
parallel_start_node_id: str = Field(..., description="parallel start node id")
"""parallel start node id"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
pass
class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
pass
class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
error: str = Field(..., description="failed reason")
###########################################
# Iteration Events
###########################################
class BaseIterationEvent(GraphEngineEvent):
iteration_id: str = Field(..., description="iteration node execution id")
iteration_node_id: str = Field(..., description="iteration node id")
iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
iteration_node_data: BaseNodeData = Field(..., description="node data")
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
class IterationRunStartedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
predecessor_node_id: Optional[str] = None
class IterationRunNextEvent(BaseIterationEvent):
index: int = Field(..., description="index")
pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output")
class IterationRunSucceededEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
steps: int = 0
class IterationRunFailedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
steps: int = 0
error: str = Field(..., description="failed reason")
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent
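Downstream consumers iterate the engine's event stream and branch on these classes; a minimal sketch (graph_engine is assumed to be an initialized GraphEngine):

for event in graph_engine.run():
    if isinstance(event, NodeRunSucceededEvent):
        ...  # e.g. persist event.route_node_state.node_run_result
    elif isinstance(event, ParallelBranchRunFailedEvent):
        ...  # e.g. surface event.error for branch event.parallel_id
    elif isinstance(event, GraphRunFailedEvent):
        break  # the run is over; event.error carries the reason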

View File

@ -0,0 +1,692 @@
import uuid
from collections.abc import Mapping
from typing import Any, Optional, cast
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter
from core.workflow.nodes.end.entities import EndStreamParam
class GraphEdge(BaseModel):
source_node_id: str = Field(..., description="source node id")
target_node_id: str = Field(..., description="target node id")
run_condition: Optional[RunCondition] = None
"""run condition"""
class GraphParallel(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
start_from_node_id: str = Field(..., description="start from node id")
parent_parallel_id: Optional[str] = None
"""parent parallel id"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id"""
end_to_node_id: Optional[str] = None
"""end to node id"""
class Graph(BaseModel):
root_node_id: str = Field(..., description="root node id of the graph")
node_ids: list[str] = Field(default_factory=list, description="graph node ids")
node_id_config_mapping: dict[str, dict] = Field(
default_factory=dict,
description="node configs mapping (node id: node config)"
)
edge_mapping: dict[str, list[GraphEdge]] = Field(
default_factory=dict,
description="graph edge mapping (source node id: edges)"
)
reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
default_factory=dict,
description="reverse graph edge mapping (target node id: edges)"
)
parallel_mapping: dict[str, GraphParallel] = Field(
default_factory=dict,
description="graph parallel mapping (parallel id: parallel)"
)
node_parallel_mapping: dict[str, str] = Field(
default_factory=dict,
description="graph node parallel mapping (node id: parallel id)"
)
answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(
...,
description="answer stream generate routes"
)
end_stream_param: EndStreamParam = Field(
...,
description="end stream param"
)
@classmethod
def init(cls,
graph_config: Mapping[str, Any],
root_node_id: Optional[str] = None) -> "Graph":
"""
Init graph
:param graph_config: graph config
:param root_node_id: root node id
:return: graph
"""
# edge configs
edge_configs = graph_config.get('edges')
if edge_configs is None:
edge_configs = []
edge_configs = cast(list, edge_configs)
# reorganize edges mapping
edge_mapping: dict[str, list[GraphEdge]] = {}
reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
target_edge_ids = set()
for edge_config in edge_configs:
source_node_id = edge_config.get('source')
if not source_node_id:
continue
if source_node_id not in edge_mapping:
edge_mapping[source_node_id] = []
target_node_id = edge_config.get('target')
if not target_node_id:
continue
if target_node_id not in reverse_edge_mapping:
reverse_edge_mapping[target_node_id] = []
# is target node id in source node id edge mapping
if any(graph_edge.target_node_id == target_node_id for graph_edge in edge_mapping[source_node_id]):
continue
target_edge_ids.add(target_node_id)
# parse run condition
run_condition = None
if edge_config.get('sourceHandle') and edge_config.get('sourceHandle') != 'source':
run_condition = RunCondition(
type='branch_identify',
branch_identify=edge_config.get('sourceHandle')
)
graph_edge = GraphEdge(
source_node_id=source_node_id,
target_node_id=target_node_id,
run_condition=run_condition
)
edge_mapping[source_node_id].append(graph_edge)
reverse_edge_mapping[target_node_id].append(graph_edge)
# node configs
node_configs = graph_config.get('nodes')
if not node_configs:
raise ValueError("Graph must have at least one node")
node_configs = cast(list, node_configs)
# fetch nodes that have no predecessor node
root_node_configs = []
all_node_id_config_mapping: dict[str, dict] = {}
for node_config in node_configs:
node_id = node_config.get('id')
if not node_id:
continue
if node_id not in target_edge_ids:
root_node_configs.append(node_config)
all_node_id_config_mapping[node_id] = node_config
root_node_ids = [node_config.get('id') for node_config in root_node_configs]
# fetch root node
if not root_node_id:
# if no root node id, use the START type node as root node
root_node_id = next((node_config.get("id") for node_config in root_node_configs
if node_config.get('data', {}).get('type', '') == NodeType.START.value), None)
if not root_node_id or root_node_id not in root_node_ids:
raise ValueError(f"Root node id {root_node_id} not found in the graph")
# Check whether it is connected to the previous node
cls._check_connected_to_previous_node(
route=[root_node_id],
edge_mapping=edge_mapping
)
# fetch all node ids from root node
node_ids = [root_node_id]
cls._recursively_add_node_ids(
node_ids=node_ids,
edge_mapping=edge_mapping,
node_id=root_node_id
)
node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids}
# init parallel mapping
parallel_mapping: dict[str, GraphParallel] = {}
node_parallel_mapping: dict[str, str] = {}
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=root_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping
)
# Check if it exceeds N layers of parallel
for parallel in parallel_mapping.values():
if parallel.parent_parallel_id:
cls._check_exceed_parallel_limit(
parallel_mapping=parallel_mapping,
level_limit=3,
parent_parallel_id=parallel.parent_parallel_id
)
# init answer stream generate routes
answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping
)
# init end stream param
end_stream_param = EndStreamGeneratorRouter.init(
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
node_parallel_mapping=node_parallel_mapping
)
# init graph
graph = cls(
root_node_id=root_node_id,
node_ids=node_ids,
node_id_config_mapping=node_id_config_mapping,
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
answer_stream_generate_routes=answer_stream_generate_routes,
end_stream_param=end_stream_param
)
return graph
def add_extra_edge(self, source_node_id: str,
target_node_id: str,
run_condition: Optional[RunCondition] = None) -> None:
"""
Add extra edge to the graph
:param source_node_id: source node id
:param target_node_id: target node id
:param run_condition: run condition
"""
if source_node_id not in self.node_ids or target_node_id not in self.node_ids:
return
if source_node_id not in self.edge_mapping:
self.edge_mapping[source_node_id] = []
if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]:
return
graph_edge = GraphEdge(
source_node_id=source_node_id,
target_node_id=target_node_id,
run_condition=run_condition
)
self.edge_mapping[source_node_id].append(graph_edge)
def get_leaf_node_ids(self) -> list[str]:
"""
Get leaf node ids of the graph
:return: leaf node ids
"""
leaf_node_ids = []
for node_id in self.node_ids:
if node_id not in self.edge_mapping:
leaf_node_ids.append(node_id)
elif (len(self.edge_mapping[node_id]) == 1
and self.edge_mapping[node_id][0].target_node_id == self.root_node_id):
leaf_node_ids.append(node_id)
return leaf_node_ids
@classmethod
def _recursively_add_node_ids(cls,
node_ids: list[str],
edge_mapping: dict[str, list[GraphEdge]],
node_id: str) -> None:
"""
Recursively add node ids
:param node_ids: node ids
:param edge_mapping: edge mapping
:param node_id: node id
"""
for graph_edge in edge_mapping.get(node_id, []):
if graph_edge.target_node_id in node_ids:
continue
node_ids.append(graph_edge.target_node_id)
cls._recursively_add_node_ids(
node_ids=node_ids,
edge_mapping=edge_mapping,
node_id=graph_edge.target_node_id
)
@classmethod
def _check_connected_to_previous_node(
cls,
route: list[str],
edge_mapping: dict[str, list[GraphEdge]]
) -> None:
"""
Check whether it is connected to the previous node
"""
last_node_id = route[-1]
for graph_edge in edge_mapping.get(last_node_id, []):
if not graph_edge.target_node_id:
continue
if graph_edge.target_node_id in route:
raise ValueError(f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph.")
new_route = route[:]
new_route.append(graph_edge.target_node_id)
cls._check_connected_to_previous_node(
route=new_route,
edge_mapping=edge_mapping,
)
@classmethod
def _recursively_add_parallels(
cls,
edge_mapping: dict[str, list[GraphEdge]],
reverse_edge_mapping: dict[str, list[GraphEdge]],
start_node_id: str,
parallel_mapping: dict[str, GraphParallel],
node_parallel_mapping: dict[str, str],
parent_parallel: Optional[GraphParallel] = None
) -> None:
"""
Recursively add parallel ids
:param edge_mapping: edge mapping
:param start_node_id: start from node id
:param parallel_mapping: parallel mapping
:param node_parallel_mapping: node parallel mapping
:param parent_parallel: parent parallel
"""
target_node_edges = edge_mapping.get(start_node_id, [])
parallel = None
if len(target_node_edges) > 1:
# fetch all node ids in current parallels
parallel_branch_node_ids = []
condition_edge_mappings = {}
for graph_edge in target_node_edges:
if graph_edge.run_condition is None:
parallel_branch_node_ids.append(graph_edge.target_node_id)
else:
condition_hash = graph_edge.run_condition.hash
if condition_hash not in condition_edge_mappings:
condition_edge_mappings[condition_hash] = []
condition_edge_mappings[condition_hash].append(graph_edge)
for _, graph_edges in condition_edge_mappings.items():
if len(graph_edges) > 1:
for graph_edge in graph_edges:
parallel_branch_node_ids.append(graph_edge.target_node_id)
# any target node id in node_parallel_mapping
if parallel_branch_node_ids:
parent_parallel_id = parent_parallel.id if parent_parallel else None
parallel = GraphParallel(
start_from_node_id=start_node_id,
parent_parallel_id=parent_parallel.id if parent_parallel else None,
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None
)
parallel_mapping[parallel.id] = parallel
in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
parallel_branch_node_ids=parallel_branch_node_ids
)
# collect all branches node ids
parallel_node_ids = []
for _, node_ids in in_branch_node_ids.items():
for node_id in node_ids:
in_parent_parallel = True
if parent_parallel_id:
in_parent_parallel = False
for parallel_node_id, parallel_id in node_parallel_mapping.items():
if parallel_id == parent_parallel_id and parallel_node_id == node_id:
in_parent_parallel = True
break
if in_parent_parallel:
parallel_node_ids.append(node_id)
node_parallel_mapping[node_id] = parallel.id
outside_parallel_target_node_ids = set()
for node_id in parallel_node_ids:
if node_id == parallel.start_from_node_id:
continue
node_edges = edge_mapping.get(node_id)
if not node_edges:
continue
if len(node_edges) > 1:
continue
target_node_id = node_edges[0].target_node_id
if target_node_id in parallel_node_ids:
continue
if parent_parallel_id:
parent_parallel = parallel_mapping.get(parent_parallel_id)
if not parent_parallel:
continue
if (
(node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id)
or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id)
or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
):
outside_parallel_target_node_ids.add(target_node_id)
if len(outside_parallel_target_node_ids) == 1:
if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id:
parallel.end_to_node_id = None
else:
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
for graph_edge in target_node_edges:
current_parallel = None
if parallel:
current_parallel = parallel
elif parent_parallel:
if not parent_parallel.end_to_node_id or (parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id):
current_parallel = parent_parallel
else:
# fetch parent parallel's parent parallel
parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
if parent_parallel_parent_parallel_id:
parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
if (
parent_parallel_parent_parallel
and (
not parent_parallel_parent_parallel.end_to_node_id
or (parent_parallel_parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id)
)
):
current_parallel = parent_parallel_parent_parallel
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=graph_edge.target_node_id,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping,
parent_parallel=current_parallel
)
@classmethod
def _check_exceed_parallel_limit(
cls,
parallel_mapping: dict[str, GraphParallel],
level_limit: int,
parent_parallel_id: str,
current_level: int = 1
) -> None:
"""
Check if it exceeds N layers of parallel
"""
parent_parallel = parallel_mapping.get(parent_parallel_id)
if not parent_parallel:
return
current_level += 1
if current_level > level_limit:
raise ValueError(f"Exceeds {level_limit} layers of parallel")
if parent_parallel.parent_parallel_id:
cls._check_exceed_parallel_limit(
parallel_mapping=parallel_mapping,
level_limit=level_limit,
parent_parallel_id=parent_parallel.parent_parallel_id,
current_level=current_level
)
@classmethod
def _recursively_add_parallel_node_ids(cls,
branch_node_ids: list[str],
edge_mapping: dict[str, list[GraphEdge]],
merge_node_id: str,
start_node_id: str) -> None:
"""
Recursively add node ids
:param branch_node_ids: in branch node ids
:param edge_mapping: edge mapping
:param merge_node_id: merge node id
:param start_node_id: start node id
"""
for graph_edge in edge_mapping.get(start_node_id, []):
if (graph_edge.target_node_id != merge_node_id
and graph_edge.target_node_id not in branch_node_ids):
branch_node_ids.append(graph_edge.target_node_id)
cls._recursively_add_parallel_node_ids(
branch_node_ids=branch_node_ids,
edge_mapping=edge_mapping,
merge_node_id=merge_node_id,
start_node_id=graph_edge.target_node_id
)
@classmethod
def _fetch_all_node_ids_in_parallels(cls,
edge_mapping: dict[str, list[GraphEdge]],
reverse_edge_mapping: dict[str, list[GraphEdge]],
parallel_branch_node_ids: list[str]) -> dict[str, list[str]]:
"""
Fetch all node ids in parallels
"""
routes_node_ids: dict[str, list[str]] = {}
for parallel_branch_node_id in parallel_branch_node_ids:
routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id]
# fetch routes node ids
cls._recursively_fetch_routes(
edge_mapping=edge_mapping,
start_node_id=parallel_branch_node_id,
routes_node_ids=routes_node_ids[parallel_branch_node_id]
)
# fetch leaf node ids from routes node ids
leaf_node_ids: dict[str, list[str]] = {}
merge_branch_node_ids: dict[str, list[str]] = {}
for branch_node_id, node_ids in routes_node_ids.items():
for node_id in node_ids:
if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0:
if branch_node_id not in leaf_node_ids:
leaf_node_ids[branch_node_id] = []
leaf_node_ids[branch_node_id].append(node_id)
for branch_node_id2, inner_route2 in routes_node_ids.items():
if (
branch_node_id != branch_node_id2
and node_id in inner_route2
and len(reverse_edge_mapping.get(node_id, [])) > 1
and cls._is_node_in_routes(
reverse_edge_mapping=reverse_edge_mapping,
start_node_id=node_id,
routes_node_ids=routes_node_ids
)
):
if node_id not in merge_branch_node_ids:
merge_branch_node_ids[node_id] = []
if branch_node_id2 not in merge_branch_node_ids[node_id]:
merge_branch_node_ids[node_id].append(branch_node_id2)
# sorted merge_branch_node_ids by branch_node_ids length desc
merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))
duplicate_end_node_ids = {}
for node_id, branch_node_ids in merge_branch_node_ids.items():
for node_id2, branch_node_ids2 in merge_branch_node_ids.items():
if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2):
if (node_id, node_id2) not in duplicate_end_node_ids and (node_id2, node_id) not in duplicate_end_node_ids:
duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids
for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
# check which node is after
if cls._is_node2_after_node1(
node1_id=node_id,
node2_id=node_id2,
edge_mapping=edge_mapping
):
if node_id in merge_branch_node_ids:
del merge_branch_node_ids[node_id2]
elif cls._is_node2_after_node1(
node1_id=node_id2,
node2_id=node_id,
edge_mapping=edge_mapping
):
if node_id2 in merge_branch_node_ids:
del merge_branch_node_ids[node_id]
branches_merge_node_ids: dict[str, str] = {}
for node_id, branch_node_ids in merge_branch_node_ids.items():
if len(branch_node_ids) <= 1:
continue
for branch_node_id in branch_node_ids:
if branch_node_id in branches_merge_node_ids:
continue
branches_merge_node_ids[branch_node_id] = node_id
in_branch_node_ids: dict[str, list[str]] = {}
for branch_node_id, node_ids in routes_node_ids.items():
in_branch_node_ids[branch_node_id] = []
if branch_node_id not in branches_merge_node_ids:
# all node ids in current branch is in this thread
in_branch_node_ids[branch_node_id].append(branch_node_id)
in_branch_node_ids[branch_node_id].extend(node_ids)
else:
merge_node_id = branches_merge_node_ids[branch_node_id]
if merge_node_id != branch_node_id:
in_branch_node_ids[branch_node_id].append(branch_node_id)
# fetch all node ids from branch_node_id and merge_node_id
cls._recursively_add_parallel_node_ids(
branch_node_ids=in_branch_node_ids[branch_node_id],
edge_mapping=edge_mapping,
merge_node_id=merge_node_id,
start_node_id=branch_node_id
)
return in_branch_node_ids
@classmethod
def _recursively_fetch_routes(cls,
edge_mapping: dict[str, list[GraphEdge]],
start_node_id: str,
routes_node_ids: list[str]) -> None:
"""
Recursively fetch route
"""
if start_node_id not in edge_mapping:
return
for graph_edge in edge_mapping[start_node_id]:
# find next node ids
if graph_edge.target_node_id not in routes_node_ids:
routes_node_ids.append(graph_edge.target_node_id)
cls._recursively_fetch_routes(
edge_mapping=edge_mapping,
start_node_id=graph_edge.target_node_id,
routes_node_ids=routes_node_ids
)
@classmethod
def _is_node_in_routes(cls,
reverse_edge_mapping: dict[str, list[GraphEdge]],
start_node_id: str,
routes_node_ids: dict[str, list[str]]) -> bool:
"""
Recursively check if the node is in the routes
"""
if start_node_id not in reverse_edge_mapping:
return False
all_routes_node_ids = set()
parallel_start_node_ids: dict[str, list[str]] = {}
for branch_node_id, node_ids in routes_node_ids.items():
for node_id in node_ids:
all_routes_node_ids.add(node_id)
if branch_node_id in reverse_edge_mapping:
for graph_edge in reverse_edge_mapping[branch_node_id]:
if graph_edge.source_node_id not in parallel_start_node_ids:
parallel_start_node_ids[graph_edge.source_node_id] = []
parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)
parallel_start_node_id = None
for p_start_node_id, branch_node_ids in parallel_start_node_ids.items():
if set(branch_node_ids) == set(routes_node_ids.keys()):
parallel_start_node_id = p_start_node_id
break
if not parallel_start_node_id:
raise Exception("Parallel start node id not found")
for graph_edge in reverse_edge_mapping[start_node_id]:
if graph_edge.source_node_id not in all_routes_node_ids or graph_edge.source_node_id != parallel_start_node_id:
return False
return True
@classmethod
def _is_node2_after_node1(
cls,
node1_id: str,
node2_id: str,
edge_mapping: dict[str, list[GraphEdge]]
) -> bool:
"""
is node2 after node1
"""
if node1_id not in edge_mapping:
return False
for graph_edge in edge_mapping[node1_id]:
if graph_edge.target_node_id == node2_id:
return True
if cls._is_node2_after_node1(
node1_id=graph_edge.target_node_id,
node2_id=node2_id,
edge_mapping=edge_mapping
):
return True
return False
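A minimal sketch of building a graph whose start node fans out into two branches that later merge (ids and node types are hypothetical; the config shape mirrors what Graph.init reads above, and the stream-router init is assumed to tolerate this skeletal config):

graph = Graph.init(graph_config={
    'edges': [
        {'source': 'start', 'target': 'llm_a', 'sourceHandle': 'source'},
        {'source': 'start', 'target': 'llm_b', 'sourceHandle': 'source'},
        {'source': 'llm_a', 'target': 'end', 'sourceHandle': 'source'},
        {'source': 'llm_b', 'target': 'end', 'sourceHandle': 'source'},
    ],
    'nodes': [
        {'id': 'start', 'data': {'type': 'start'}},
        {'id': 'llm_a', 'data': {'type': 'llm'}},
        {'id': 'llm_b', 'data': {'type': 'llm'}},
        {'id': 'end', 'data': {'type': 'end'}},
    ],
})
# 'llm_a' and 'llm_b' now share one GraphParallel in graph.parallel_mapping,
# and graph.node_parallel_mapping maps each of them to that parallel's id.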

View File

@ -0,0 +1,21 @@
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import UserFrom
from models.workflow import WorkflowType
class GraphInitParams(BaseModel):
# init params
tenant_id: str = Field(..., description="tenant / workspace id")
app_id: str = Field(..., description="app id")
workflow_type: WorkflowType = Field(..., description="workflow type")
workflow_id: str = Field(..., description="workflow id")
graph_config: Mapping[str, Any] = Field(..., description="graph config")
user_id: str = Field(..., description="user id")
user_from: UserFrom = Field(..., description="user from, account or end-user")
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
call_depth: int = Field(..., description="call depth")

View File

@ -0,0 +1,27 @@
from typing import Any
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
class GraphRuntimeState(BaseModel):
variable_pool: VariablePool = Field(..., description="variable pool")
"""variable pool"""
start_at: float = Field(..., description="start time")
"""start time"""
total_tokens: int = 0
"""total tokens"""
llm_usage: LLMUsage = LLMUsage.empty_usage()
"""llm usage info"""
outputs: dict[str, Any] = {}
"""outputs"""
node_run_steps: int = 0
"""node run steps"""
node_run_state: RuntimeRouteState = RuntimeRouteState()
"""node run state"""

View File

@ -0,0 +1,13 @@
from typing import Optional
from pydantic import BaseModel
from core.workflow.graph_engine.entities.graph import GraphParallel
class NextGraphNode(BaseModel):
node_id: str
"""next node id"""
parallel: Optional[GraphParallel] = None
"""parallel"""

View File

@ -0,0 +1,21 @@
import hashlib
from typing import Literal, Optional
from pydantic import BaseModel
from core.workflow.utils.condition.entities import Condition
class RunCondition(BaseModel):
type: Literal["branch_identify", "condition"]
"""condition type"""
branch_identify: Optional[str] = None
"""branch identify like: sourceHandle, required when type is branch_identify"""
conditions: Optional[list[Condition]] = None
"""conditions to run the node, required when type is condition"""
@property
def hash(self) -> str:
return hashlib.sha256(self.model_dump_json().encode()).hexdigest()
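Because the digest is computed from the serialized model, identical conditions collapse to the same key; this is what lets _recursively_add_parallels group conditional edges by condition_hash. A minimal sketch:

a = RunCondition(type='branch_identify', branch_identify='true')
b = RunCondition(type='branch_identify', branch_identify='true')
assert a.hash == b.hash  # same content, same sha256 digest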

View File

@ -0,0 +1,111 @@
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult
from models.workflow import WorkflowNodeExecutionStatus
class RouteNodeState(BaseModel):
class Status(Enum):
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
PAUSED = "paused"
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""node state id"""
node_id: str
"""node id"""
node_run_result: Optional[NodeRunResult] = None
"""node run result"""
status: Status = Status.RUNNING
"""node status"""
start_at: datetime
"""start time"""
paused_at: Optional[datetime] = None
"""paused time"""
finished_at: Optional[datetime] = None
"""finished time"""
failed_reason: Optional[str] = None
"""failed reason"""
paused_by: Optional[str] = None
"""paused by"""
index: int = 1
def set_finished(self, run_result: NodeRunResult) -> None:
"""
Node finished
:param run_result: run result
"""
if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]:
raise Exception(f"Route state {self.id} already finished")
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
self.status = RouteNodeState.Status.SUCCESS
elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
self.status = RouteNodeState.Status.FAILED
self.failed_reason = run_result.error
else:
raise Exception(f"Invalid route status {run_result.status}")
self.node_run_result = run_result
self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
class RuntimeRouteState(BaseModel):
routes: dict[str, list[str]] = Field(
default_factory=dict,
description="graph state routes (source_node_state_id: target_node_state_id)"
)
node_state_mapping: dict[str, RouteNodeState] = Field(
default_factory=dict,
description="node state mapping (route_node_state_id: route_node_state)"
)
def create_node_state(self, node_id: str) -> RouteNodeState:
"""
Create node state
:param node_id: node id
"""
state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None))
self.node_state_mapping[state.id] = state
return state
def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
"""
Add route to the graph state
:param source_node_state_id: source node state id
:param target_node_state_id: target node state id
"""
if source_node_state_id not in self.routes:
self.routes[source_node_state_id] = []
self.routes[source_node_state_id].append(target_node_state_id)
def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) \
-> list[RouteNodeState]:
"""
Get routes with node state by source node id
:param source_node_state_id: source node state id
:return: routes with node state
"""
return [self.node_state_mapping[target_state_id]
for target_state_id in self.routes.get(source_node_state_id, [])]
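A brief sketch of the state lifecycle during a run (node ids are hypothetical):

state = RuntimeRouteState()
start_state = state.create_node_state(node_id='start')
next_state = state.create_node_state(node_id='llm_a')
state.add_route(source_node_state_id=start_state.id, target_node_state_id=next_state.id)
downstream = state.get_routes_with_node_state_by_source_node_state_id(start_state.id)
# downstream == [next_state]; each RouteNodeState is later closed with set_finished(run_result)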

View File

@ -0,0 +1,716 @@
import logging
import queue
import time
import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Optional
from flask import Flask, current_app
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
NodeType,
UserFrom,
)
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
BaseIterationEvent,
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph, GraphEdge
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import node_classes
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
logger = logging.getLogger(__name__)
class GraphEngineThreadPool(ThreadPoolExecutor):
def __init__(self, max_workers=None, thread_name_prefix='',
initializer=None, initargs=(), max_submit_count=100) -> None:
super().__init__(max_workers, thread_name_prefix, initializer, initargs)
self.max_submit_count = max_submit_count
self.submit_count = 0
def submit(self, fn, *args, **kwargs):
self.submit_count += 1
self.check_is_full()
return super().submit(fn, *args, **kwargs)
def check_is_full(self) -> None:
logger.debug(f"submit_count: {self.submit_count}, max_submit_count: {self.max_submit_count}")
if self.submit_count > self.max_submit_count:
raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
class GraphEngine:
workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {}
def __init__(
self,
tenant_id: str,
app_id: str,
workflow_type: WorkflowType,
workflow_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
call_depth: int,
graph: Graph,
graph_config: Mapping[str, Any],
variable_pool: VariablePool,
max_execution_steps: int,
max_execution_time: int,
thread_pool_id: Optional[str] = None
) -> None:
thread_pool_max_submit_count = 100
thread_pool_max_workers = 10
## init thread pool
if thread_pool_id:
if thread_pool_id not in GraphEngine.workflow_thread_pool_mapping:
raise ValueError(f"Workflow thread pool {thread_pool_id} not found.")
self.thread_pool_id = thread_pool_id
self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id]
self.is_main_thread_pool = False
else:
self.thread_pool = GraphEngineThreadPool(max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count)
self.thread_pool_id = str(uuid.uuid4())
self.is_main_thread_pool = True
GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool
self.graph = graph
self.init_params = GraphInitParams(
tenant_id=tenant_id,
app_id=app_id,
workflow_type=workflow_type,
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
call_depth=call_depth
)
self.graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter()
)
self.max_execution_steps = max_execution_steps
self.max_execution_time = max_execution_time
def run(self) -> Generator[GraphEngineEvent, None, None]:
# trigger graph run start event
yield GraphRunStartedEvent()
try:
stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor]
if self.init_params.workflow_type == WorkflowType.CHAT:
stream_processor_cls = AnswerStreamProcessor
else:
stream_processor_cls = EndStreamProcessor
stream_processor = stream_processor_cls(
graph=self.graph,
variable_pool=self.graph_runtime_state.variable_pool
)
# run graph
generator = stream_processor.process(
self._run(start_node_id=self.graph.root_node_id)
)
for item in generator:
try:
yield item
if isinstance(item, NodeRunFailedEvent):
yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.')
return
elif isinstance(item, NodeRunSucceededEvent):
if item.node_type == NodeType.END:
self.graph_runtime_state.outputs = (item.route_node_state.node_run_result.outputs
if item.route_node_state.node_run_result
and item.route_node_state.node_run_result.outputs
else {})
elif item.node_type == NodeType.ANSWER:
if "answer" not in self.graph_runtime_state.outputs:
self.graph_runtime_state.outputs["answer"] = ""
self.graph_runtime_state.outputs["answer"] += "\n" + (item.route_node_state.node_run_result.outputs.get("answer", "")
if item.route_node_state.node_run_result
and item.route_node_state.node_run_result.outputs
else "")
self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs["answer"].strip()
except Exception as e:
logger.exception(f"Graph run failed: {str(e)}")
yield GraphRunFailedEvent(error=str(e))
return
# trigger graph run success event
yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
except GraphRunFailedError as e:
yield GraphRunFailedEvent(error=e.error)
return
except Exception as e:
logger.exception("Unknown Error when graph running")
yield GraphRunFailedEvent(error=str(e))
raise e
finally:
if self.is_main_thread_pool and self.thread_pool_id in GraphEngine.workflow_thread_pool_mapping:
del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id]
def _run(
self,
start_node_id: str,
in_parallel_id: Optional[str] = None,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None
) -> Generator[GraphEngineEvent, None, None]:
parallel_start_node_id = None
if in_parallel_id:
parallel_start_node_id = start_node_id
next_node_id = start_node_id
previous_route_node_state: Optional[RouteNodeState] = None
while True:
# max steps reached
if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
raise GraphRunFailedError('Max steps {} reached.'.format(self.max_execution_steps))
# or max execution time reached
if self._is_timed_out(
start_at=self.graph_runtime_state.start_at,
max_execution_time=self.max_execution_time
):
raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))
# init route node state
route_node_state = self.graph_runtime_state.node_run_state.create_node_state(
node_id=next_node_id
)
# get node config
node_id = route_node_state.node_id
node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config:
raise GraphRunFailedError(f'Node {node_id} config not found.')
# convert to specific node
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
if not node_cls:
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
# init workflow run state
node_instance = node_cls( # type: ignore
id=route_node_state.id,
config=node_config,
graph_init_params=self.init_params,
graph=self.graph,
graph_runtime_state=self.graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=self.thread_pool_id
)
try:
# run node
generator = self._run_node(
node_instance=node_instance,
route_node_state=route_node_state,
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
)
for item in generator:
if isinstance(item, NodeRunStartedEvent):
self.graph_runtime_state.node_run_steps += 1
item.route_node_state.index = self.graph_runtime_state.node_run_steps
yield item
self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state
# append route
if previous_route_node_state:
self.graph_runtime_state.node_run_state.add_route(
source_node_state_id=previous_route_node_state.id,
target_node_state_id=route_node_state.id
)
except Exception as e:
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = str(e)
yield NodeRunFailedEvent(
error=str(e),
id=node_instance.id,
node_id=next_node_id,
node_type=node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
)
raise e
# stop traversal once the node just executed is an END node
if (self.graph.node_id_config_mapping[next_node_id]
.get("data", {}).get("type", "").lower() == NodeType.END.value):
break
previous_route_node_state = route_node_state
# get next node ids
edge_mappings = self.graph.edge_mapping.get(next_node_id)
if not edge_mappings:
break
if len(edge_mappings) == 1:
edge = edge_mappings[0]
if edge.run_condition:
result = ConditionManager.get_condition_handler(
init_params=self.init_params,
graph=self.graph,
run_condition=edge.run_condition,
).check(
graph_runtime_state=self.graph_runtime_state,
previous_route_node_state=previous_route_node_state
)
if not result:
break
next_node_id = edge.target_node_id
else:
final_node_id = None
if any(edge.run_condition for edge in edge_mappings):
# if edges have run conditions, choose which branch to take based on the condition results
condition_edge_mappings = {}
for edge in edge_mappings:
if edge.run_condition:
run_condition_hash = edge.run_condition.hash
if run_condition_hash not in condition_edge_mappings:
condition_edge_mappings[run_condition_hash] = []
condition_edge_mappings[run_condition_hash].append(edge)
for _, sub_edge_mappings in condition_edge_mappings.items():
if len(sub_edge_mappings) == 0:
continue
edge = sub_edge_mappings[0]
result = ConditionManager.get_condition_handler(
init_params=self.init_params,
graph=self.graph,
run_condition=edge.run_condition,
).check(
graph_runtime_state=self.graph_runtime_state,
previous_route_node_state=previous_route_node_state,
)
if not result:
continue
if len(sub_edge_mappings) == 1:
final_node_id = edge.target_node_id
else:
parallel_generator = self._run_parallel_branches(
edge_mappings=sub_edge_mappings,
in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id
)
for item in parallel_generator:
if isinstance(item, str):
final_node_id = item
else:
yield item
break
if not final_node_id:
break
next_node_id = final_node_id
else:
parallel_generator = self._run_parallel_branches(
edge_mappings=edge_mappings,
in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id
)
for item in parallel_generator:
if isinstance(item, str):
final_node_id = item
else:
yield item
if not final_node_id:
break
next_node_id = final_node_id
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
break
def _run_parallel_branches(
self,
edge_mappings: list[GraphEdge],
in_parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
) -> Generator[GraphEngineEvent | str, None, None]:
# if edges have no run conditions, run all target branches in parallel
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
if not parallel_id:
node_id = edge_mappings[0].target_node_id
node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config:
raise GraphRunFailedError(f'Node {node_id} config not found.')
node_title = node_config.get('data', {}).get('title')
raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.')
parallel = self.graph.parallel_mapping.get(parallel_id)
if not parallel:
raise GraphRunFailedError(f'Parallel {parallel_id} not found.')
# run parallel branches in worker threads and collect their events through a queue
q: queue.Queue = queue.Queue()
# Create a list to store the futures
futures = []
# submit one branch per edge that belongs to this parallel
for edge in edge_mappings:
if (
edge.target_node_id not in self.graph.node_parallel_mapping
or self.graph.node_parallel_mapping.get(edge.target_node_id, '') != parallel_id
):
continue
futures.append(
self.thread_pool.submit(self._run_parallel_node, **{
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
'q': q,
'parallel_id': parallel_id,
'parallel_start_node_id': edge.target_node_id,
'parent_parallel_id': in_parallel_id,
'parent_parallel_start_node_id': parallel_start_node_id,
})
)
succeeded_count = 0
while True:
try:
event = q.get(timeout=1)
if event is None:
break
yield event
if event.parallel_id == parallel_id:
if isinstance(event, ParallelBranchRunSucceededEvent):
succeeded_count += 1
if succeeded_count == len(futures):
q.put(None)
continue
elif isinstance(event, ParallelBranchRunFailedEvent):
raise GraphRunFailedError(event.error)
except queue.Empty:
continue
# wait for all branch threads to finish
wait(futures)
# get final node id
final_node_id = parallel.end_to_node_id
if final_node_id:
yield final_node_id
def _run_parallel_node(
self,
flask_app: Flask,
q: queue.Queue,
parallel_id: str,
parallel_start_node_id: str,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
) -> None:
"""
Run a single parallel branch
"""
with flask_app.app_context():
try:
q.put(ParallelBranchRunStartedEvent(
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
))
# run node
generator = self._run(
start_node_id=parallel_start_node_id,
in_parallel_id=parallel_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
)
for item in generator:
q.put(item)
# trigger parallel branch run success event
q.put(ParallelBranchRunSucceededEvent(
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
))
except GraphRunFailedError as e:
q.put(ParallelBranchRunFailedEvent(
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
error=e.error
))
except Exception as e:
logger.exception("Unknown Error when generating in parallel")
q.put(ParallelBranchRunFailedEvent(
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
error=str(e)
))
finally:
db.session.remove()
def _run_node(
self,
node_instance: BaseNode,
route_node_state: RouteNodeState,
parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
) -> Generator[GraphEngineEvent, None, None]:
"""
Run node
"""
# trigger node run start event
yield NodeRunStartedEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
predecessor_node_id=node_instance.previous_node_id,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
)
db.session.close()
try:
# run node
generator = node_instance.run()
for item in generator:
if isinstance(item, GraphEngineEvent):
if isinstance(item, BaseIterationEvent):
# add parallel info to iteration event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id
item.parent_parallel_id = parent_parallel_id
item.parent_parallel_start_node_id = parent_parallel_start_node_id
yield item
else:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
route_node_state.set_finished(run_result=run_result)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or 'Unknown error.',
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
)
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# accumulate total_tokens into the graph runtime state
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)
if run_result.llm_usage:
# use the latest usage
self.graph_runtime_state.llm_usage += run_result.llm_usage
# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value
)
# add parallel info to run result metadata
if parallel_id and parallel_start_node_id:
if not run_result.metadata:
run_result.metadata = {}
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
if parent_parallel_id and parent_parallel_start_node_id:
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = parent_parallel_start_node_id
yield NodeRunSucceededEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
)
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
retriever_resources=item.retriever_resources,
context=item.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
)
except GenerateTaskStoppedException:
# trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id
)
return
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}")
raise e
finally:
db.session.close()
def _append_variables_recursively(self,
node_id: str,
variable_key_list: list[str],
variable_value: VariableValue):
"""
Append variables recursively
:param node_id: node id
:param variable_key_list: variable key list
:param variable_value: variable value
:return:
"""
self.graph_runtime_state.variable_pool.add(
[node_id] + variable_key_list,
variable_value
)
# if variable_value is a dict, then recursively append variables
if isinstance(variable_value, dict):
for key, value in variable_value.items():
# construct new key list
new_key_list = variable_key_list + [key]
self._append_variables_recursively(
node_id=node_id,
variable_key_list=new_key_list,
variable_value=value
)
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
"""
Check timeout
:param start_at: start time
:param max_execution_time: max execution time
:return:
"""
return time.perf_counter() - start_at > max_execution_time
class GraphRunFailedError(Exception):
def __init__(self, error: str):
self.error = error
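For reference, a minimal consumer of the engine above. It assumes an engine constructed with the signature shown earlier; run_graph is an illustrative helper, not part of the commit:

def run_graph(engine: GraphEngine) -> dict:
    outputs: dict = {}
    for event in engine.run():
        if isinstance(event, NodeRunStreamChunkEvent):
            print(event.chunk_content, end='', flush=True)  # stream chunks as they arrive
        elif isinstance(event, GraphRunSucceededEvent):
            outputs = event.outputs or {}
        elif isinstance(event, GraphRunFailedEvent):
            raise RuntimeError(event.error)
    return outputs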

View File

@ -1,9 +1,8 @@
from typing import cast
from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
GenerateRouteChunk,
@ -19,24 +18,26 @@ class AnswerNode(BaseNode):
_node_data_cls = AnswerNodeData
_node_type: NodeType = NodeType.ANSWER
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:return:
"""
node_data = self.node_data
node_data = cast(AnswerNodeData, node_data)
# generate routes
generate_routes = self.extract_generate_route_from_node_data(node_data)
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data)
answer = ''
for part in generate_routes:
if part.type == "var":
if part.type == GenerateRouteChunk.ChunkType.VAR:
part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector
value = variable_pool.get(value_selector)
value = self.graph_runtime_state.variable_pool.get(
value_selector
)
if value:
answer += value.markdown
else:
@ -51,70 +52,16 @@ class AnswerNode(BaseNode):
)
@classmethod
def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
"""
Extract generate route selectors
:param config: node config
:return:
"""
node_data = cls._node_data_cls(**config.get("data", {}))
node_data = cast(AnswerNodeData, node_data)
return cls.extract_generate_route_from_node_data(node_data)
@classmethod
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
"""
Extract generate route from node data
:param node_data: node data object
:return:
"""
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
value_selector_mapping = {
variable_selector.variable: variable_selector.value_selector
for variable_selector in variable_selectors
}
variable_keys = list(value_selector_mapping.keys())
# format answer template
template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
template_variable_keys = template_parser.variable_keys
# Take the intersection of variable_keys and template_variable_keys
variable_keys = list(set(variable_keys) & set(template_variable_keys))
template = node_data.answer
for var in variable_keys:
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
generate_routes = []
for part in template.split('Ω'):
if part:
if cls._is_variable(part, variable_keys):
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
value_selector = value_selector_mapping[var_key]
generate_routes.append(VarGenerateRouteChunk(
value_selector=value_selector
))
else:
generate_routes.append(TextGenerateRouteChunk(
text=part
))
return generate_routes
@classmethod
def _is_variable(cls, part, variable_keys):
cleaned_part = part.replace('{{', '').replace('}}', '')
return part.startswith('{{') and cleaned_part in variable_keys
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: AnswerNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
@ -126,6 +73,6 @@ class AnswerNode(BaseNode):
variable_mapping = {}
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector
return variable_mapping

View File

@ -0,0 +1,169 @@
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
AnswerStreamGenerateRoute,
GenerateRouteChunk,
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerStreamGeneratorRouter:
@classmethod
def init(cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined]
) -> AnswerStreamGenerateRoute:
"""
Get stream generate routes.
:return:
"""
# parse stream output node value selectors of answer nodes
answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
for answer_node_id, node_config in node_id_config_mapping.items():
if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value:
continue
# get generate route for stream output
generate_route = cls._extract_generate_route_selectors(node_config)
answer_generate_route[answer_node_id] = generate_route
# fetch answer dependencies
answer_node_ids = list(answer_generate_route.keys())
answer_dependencies = cls._fetch_answers_dependencies(
answer_node_ids=answer_node_ids,
reverse_edge_mapping=reverse_edge_mapping,
node_id_config_mapping=node_id_config_mapping
)
return AnswerStreamGenerateRoute(
answer_generate_route=answer_generate_route,
answer_dependencies=answer_dependencies
)
@classmethod
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
"""
Extract generate route from node data
:param node_data: node data object
:return:
"""
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
value_selector_mapping = {
variable_selector.variable: variable_selector.value_selector
for variable_selector in variable_selectors
}
variable_keys = list(value_selector_mapping.keys())
# format answer template
template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
template_variable_keys = template_parser.variable_keys
# Take the intersection of variable_keys and template_variable_keys
variable_keys = list(set(variable_keys) & set(template_variable_keys))
template = node_data.answer
for var in variable_keys:
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
generate_routes: list[GenerateRouteChunk] = []
for part in template.split('Ω'):
if part:
if cls._is_variable(part, variable_keys):
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
value_selector = value_selector_mapping[var_key]
generate_routes.append(VarGenerateRouteChunk(
value_selector=value_selector
))
else:
generate_routes.append(TextGenerateRouteChunk(
text=part
))
return generate_routes
@classmethod
def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
"""
Extract generate route selectors
:param config: node config
:return:
"""
node_data = AnswerNodeData(**config.get("data", {}))
return cls.extract_generate_route_from_node_data(node_data)
@classmethod
def _is_variable(cls, part, variable_keys):
cleaned_part = part.replace('{{', '').replace('}}', '')
return part.startswith('{{') and cleaned_part in variable_keys
@classmethod
def _fetch_answers_dependencies(cls,
answer_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict]
) -> dict[str, list[str]]:
"""
Fetch answer dependencies
:param answer_node_ids: answer node ids
:param reverse_edge_mapping: reverse edge mapping
:param node_id_config_mapping: node id config mapping
:return:
"""
answer_dependencies: dict[str, list[str]] = {}
for answer_node_id in answer_node_ids:
if answer_dependencies.get(answer_node_id) is None:
answer_dependencies[answer_node_id] = []
cls._recursive_fetch_answer_dependencies(
current_node_id=answer_node_id,
answer_node_id=answer_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies
)
return answer_dependencies
@classmethod
def _recursive_fetch_answer_dependencies(cls,
current_node_id: str,
answer_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
answer_dependencies: dict[str, list[str]]
) -> None:
"""
Recursive fetch answer dependencies
:param current_node_id: current node id
:param answer_node_id: answer node id
:param node_id_config_mapping: node id config mapping
:param reverse_edge_mapping: reverse edge mapping
:param answer_dependencies: answer dependencies
:return:
"""
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
if source_node_type in (
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value,
):
answer_dependencies[answer_node_id].append(source_node_id)
else:
cls._recursive_fetch_answer_dependencies(
current_node_id=source_node_id,
answer_node_id=answer_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies
)
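The Ω-marker trick in extract_generate_route_from_node_data is easy to miss: each known variable is wrapped in a sentinel character so that a single split() yields alternating text and variable parts in template order. A standalone sketch of the same technique with plain strings (no Dify imports):

template = 'Hello {{name}}, the answer is {{text}}.'
variable_keys = ['name', 'text']

marked = template
for var in variable_keys:
    marked = marked.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')

for part in marked.split('Ω'):
    if not part:
        continue
    cleaned = part.replace('{{', '').replace('}}', '')
    if part.startswith('{{') and cleaned in variable_keys:
        print('VAR :', cleaned)   # becomes a VarGenerateRouteChunk
    else:
        print('TEXT:', part)      # becomes a TextGenerateRouteChunk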

View File

@ -0,0 +1,221 @@
import logging
from collections.abc import Generator
from typing import Optional, cast
from core.file.file_obj import FileVar
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
logger = logging.getLogger(__name__)
class AnswerStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
super().__init__(graph, variable_pool)
self.generate_routes = graph.answer_stream_generate_routes
self.route_position = {answer_node_id: 0 for answer_node_id in self.generate_routes.answer_generate_route}
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
def process(self,
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
for event in generator:
if isinstance(event, NodeRunStartedEvent):
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
self.reset()
yield event
elif isinstance(event, NodeRunStreamChunkEvent):
if event.in_iteration_id:
yield event
continue
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
]
else:
stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
] = stream_out_answer_node_ids
for _ in stream_out_answer_node_ids:
yield event
elif isinstance(event, NodeRunSucceededEvent):
yield event
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
# update self.route_position after all stream events have finished
for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
self.route_position[answer_node_id] += 1
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
# remove unreachable nodes
self._remove_unreachable_nodes(event)
# generate stream outputs
yield from self._generate_stream_outputs_when_node_finished(event)
else:
yield event
def reset(self) -> None:
self.route_position = {answer_node_id: 0 for answer_node_id in self.generate_routes.answer_generate_route}
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _generate_stream_outputs_when_node_finished(self,
event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
"""
Generate stream outputs.
:param event: node run succeeded event
:return:
"""
for answer_node_id, position in self.route_position.items():
# emit only for the answer node itself, or once the node is pending with all of its dependencies finished
if (event.route_node_state.node_id != answer_node_id
and (answer_node_id not in self.rest_node_ids
or not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
continue
route_position = self.route_position[answer_node_id]
route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]
for route_chunk in route_chunks:
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
yield NodeRunStreamChunkEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=route_chunk.text,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
if not value_selector:
break
value = self.variable_pool.get(
value_selector
)
if value is None:
break
text = value.markdown
if text:
yield NodeRunStreamChunkEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=text,
from_variable_selector=value_selector,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
self.route_position[answer_node_id] += 1
def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
"""
Get answer node ids that should stream out the given chunk
:param event: node run stream chunk event
:return:
"""
stream_output_value_selector = event.from_variable_selector
if not stream_output_value_selector:
return []
stream_out_answer_node_ids = []
for answer_node_id, route_position in self.route_position.items():
if answer_node_id not in self.rest_node_ids:
continue
# stream only once every dependency of this answer node has finished
if all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
continue
route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]
if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
continue
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
# only stream when the chunk's value selector matches this route chunk
if value_selector != stream_output_value_selector:
continue
stream_out_answer_node_ids.append(answer_node_id)
return stream_out_answer_node_ids
@classmethod
def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]:
"""
Fetch files from variable value
:param value: variable value
:return:
"""
if not value:
return []
files = []
if isinstance(value, list):
for item in value:
file_var = cls._get_file_var_from_value(item)
if file_var:
files.append(file_var)
elif isinstance(value, dict):
file_var = cls._get_file_var_from_value(value)
if file_var:
files.append(file_var)
return files
@classmethod
def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]:
"""
Get file var from value
:param value: variable value
:return:
"""
if not value:
return None
if isinstance(value, dict):
if '__variant' in value and value['__variant'] == FileVar.__name__:
return value
elif isinstance(value, FileVar):
return value.to_dict()
return None
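Condensed, the gating rule implemented by _get_stream_out_answer_node_ids: a chunk may stream through an answer node only if that node is still pending, all of its branch dependencies have finished, and the chunk's value selector matches the node's next var chunk. A plain-dict sketch of the rule (all names here are illustrative, not part of the commit):

def nodes_ready_to_stream(selector, rest_node_ids, dependencies, routes, positions):
    """Return ids of answer nodes whose next route chunk is `selector`."""
    ready = []
    for node_id, position in positions.items():
        if node_id not in rest_node_ids:
            continue  # node already finished
        if any(dep in rest_node_ids for dep in dependencies[node_id]):
            continue  # an upstream dependency is still running
        route = routes[node_id]
        if position < len(route) and route[position] == selector:
            ready.append(node_id)
    return ready

print(nodes_ready_to_stream(
    selector=('llm', 'text'),
    rest_node_ids={'answer'},
    dependencies={'answer': ['if-else']},  # if-else has already finished
    routes={'answer': [('llm', 'text')]},
    positions={'answer': 0},
))  # -> ['answer']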

View File

@ -0,0 +1,71 @@
from abc import ABC, abstractmethod
from collections.abc import Generator
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph
class StreamProcessor(ABC):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
self.graph = graph
self.variable_pool = variable_pool
self.rest_node_ids = graph.node_ids.copy()
@abstractmethod
def process(self,
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
raise NotImplementedError
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids:
return
# remove finished node id
self.rest_node_ids.remove(finished_node_id)
run_result = event.route_node_state.node_run_result
if not run_result:
return
if run_result.edge_source_handle:
reachable_node_ids = []
unreachable_first_node_ids = []
for edge in self.graph.edge_mapping[finished_node_id]:
if (edge.run_condition
and edge.run_condition.branch_identify
and run_result.edge_source_handle == edge.run_condition.branch_identify):
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
continue
else:
unreachable_first_node_ids.append(edge.target_node_id)
for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
continue
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
"""
remove target node ids until merge
"""
if node_id not in self.rest_node_ids:
return
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
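The two recursive helpers above implement branch pruning: when a node finishes on one side of a condition, every node reachable only through the untaken branches is dropped from rest_node_ids, so stream gating never waits on a branch that can no longer run. A standalone sketch over a plain adjacency map (prune_unreachable is an illustrative stand-in with the same recursion shape):

def prune_unreachable(node_id, edges, rest, reachable):
    # Remove node_id and its exclusive descendants from the pending set.
    if node_id not in rest:
        return
    rest.discard(node_id)
    for target in edges.get(node_id, []):
        if target in reachable:
            continue  # shared with a live branch: keep it
        prune_unreachable(target, edges, rest, reachable)

rest = {'b', 'c', 'd'}
prune_unreachable('c', edges={'c': ['d']}, rest=rest, reachable={'b'})
print(rest)  # {'b'}: c and its exclusive successor d were pruned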

View File

@ -1,5 +1,6 @@
from enum import Enum
from pydantic import BaseModel
from pydantic import BaseModel, Field
from core.workflow.entities.base_node_data_entities import BaseNodeData
@ -8,27 +9,54 @@ class AnswerNodeData(BaseNodeData):
"""
Answer Node Data.
"""
answer: str
answer: str = Field(..., description="answer template string")
class GenerateRouteChunk(BaseModel):
"""
Generate Route Chunk.
"""
type: str
class ChunkType(Enum):
VAR = "var"
TEXT = "text"
type: ChunkType = Field(..., description="generate route chunk type")
class VarGenerateRouteChunk(GenerateRouteChunk):
"""
Var Generate Route Chunk.
"""
type: str = "var"
value_selector: list[str]
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR
"""generate route chunk type"""
value_selector: list[str] = Field(..., description="value selector")
class TextGenerateRouteChunk(GenerateRouteChunk):
"""
Text Generate Route Chunk.
"""
type: str = "text"
text: str
type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT
"""generate route chunk type"""
text: str = Field(..., description="text")
class AnswerNodeDoubleLink(BaseModel):
node_id: str = Field(..., description="node id")
source_node_ids: list[str] = Field(..., description="source node ids")
target_node_ids: list[str] = Field(..., description="target node ids")
class AnswerStreamGenerateRoute(BaseModel):
"""
AnswerStreamGenerateRoute entity
"""
answer_dependencies: dict[str, list[str]] = Field(
...,
description="answer dependencies (answer node id -> dependent answer node ids)"
)
answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field(
...,
description="answer generate route (answer node id -> generate route chunks)"
)
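A short sketch of the tightened entities in use, assuming the import path used elsewhere in this commit; the ChunkType enum replaces the old free-form type strings:

from core.workflow.nodes.answer.entities import (
    GenerateRouteChunk,
    TextGenerateRouteChunk,
    VarGenerateRouteChunk,
)

chunks: list[GenerateRouteChunk] = [
    TextGenerateRouteChunk(text='Hello '),
    VarGenerateRouteChunk(value_selector=['llm', 'text']),
]
assert chunks[0].type == GenerateRouteChunk.ChunkType.TEXT
assert chunks[1].type == GenerateRouteChunk.ChunkType.VAR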

View File

@ -1,142 +1,103 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from enum import Enum
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from models import WorkflowNodeExecutionStatus
class UserFrom(Enum):
"""
User from
"""
ACCOUNT = "account"
END_USER = "end-user"
@classmethod
def value_of(cls, value: str) -> "UserFrom":
"""
Value of
:param value: value
:return:
"""
for item in cls:
if item.value == value:
return item
raise ValueError(f"Invalid value: {value}")
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
class BaseNode(ABC):
_node_data_cls: type[BaseNodeData]
_node_type: NodeType
tenant_id: str
app_id: str
workflow_id: str
user_id: str
user_from: UserFrom
invoke_from: InvokeFrom
workflow_call_depth: int
node_id: str
node_data: BaseNodeData
node_run_result: Optional[NodeRunResult] = None
callbacks: Sequence[WorkflowCallback]
is_answer_previous_node: bool = False
def __init__(self, tenant_id: str,
app_id: str,
workflow_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
def __init__(self,
id: str,
config: Mapping[str, Any],
callbacks: Sequence[WorkflowCallback] | None = None,
workflow_call_depth: int = 0) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
self.workflow_id = workflow_id
self.user_id = user_id
self.user_from = user_from
self.invoke_from = invoke_from
self.workflow_call_depth = workflow_call_depth
graph_init_params: GraphInitParams,
graph: Graph,
graph_runtime_state: GraphRuntimeState,
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None) -> None:
self.id = id
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
self.workflow_type = graph_init_params.workflow_type
self.workflow_id = graph_init_params.workflow_id
self.graph_config = graph_init_params.graph_config
self.user_id = graph_init_params.user_id
self.user_from = graph_init_params.user_from
self.invoke_from = graph_init_params.invoke_from
self.workflow_call_depth = graph_init_params.call_depth
self.graph = graph
self.graph_runtime_state = graph_runtime_state
self.previous_node_id = previous_node_id
self.thread_pool_id = thread_pool_id
# TODO: May need to check if key exists.
self.node_id = config["id"]
if not self.node_id:
node_id = config.get("id")
if not node_id:
raise ValueError("Node ID is required.")
self.node_id = node_id
self.node_data = self._node_data_cls(**config.get("data", {}))
self.callbacks = callbacks or []
@abstractmethod
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) \
-> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
"""
Run node
:param variable_pool: variable pool
:return:
"""
raise NotImplementedError
def run(self, variable_pool: VariablePool) -> NodeRunResult:
def run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
"""
Run node entry
:param variable_pool: variable pool
:return:
"""
try:
result = self._run(
variable_pool=variable_pool
)
self.node_run_result = result
return result
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
result = self._run()
def publish_text_chunk(self, text: str, value_selector: list[str] | None = None) -> None:
"""
Publish text chunk
:param text: chunk text
:param value_selector: value selector
:return:
"""
if self.callbacks:
for callback in self.callbacks:
callback.on_node_text_chunk(
node_id=self.node_id,
text=text,
metadata={
"node_type": self.node_type,
"is_answer_previous_node": self.is_answer_previous_node,
"value_selector": value_selector
}
)
if isinstance(result, NodeRunResult):
yield RunCompletedEvent(
run_result=result
)
else:
yield from result
@classmethod
def extract_variable_selector_to_variable_mapping(cls, config: dict):
def extract_variable_selector_to_variable_mapping(cls, graph_config: Mapping[str, Any], config: dict) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param config: node config
:return:
"""
node_id = config.get("id")
if not node_id:
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
node_data = cls._node_data_cls(**config.get("data", {}))
return cls._extract_variable_selector_to_variable_mapping(node_data)
return cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config,
node_id=node_id,
node_data=node_data
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> Mapping[str, Sequence[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: BaseNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
@ -158,38 +119,3 @@ class BaseNode(ABC):
:return:
"""
return self._node_type
class BaseIterationNode(BaseNode):
@abstractmethod
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
"""
Run node
:param variable_pool: variable pool
:return:
"""
raise NotImplementedError
def run(self, variable_pool: VariablePool) -> BaseIterationState:
"""
Run node entry
:param variable_pool: variable pool
:return:
"""
return self._run(variable_pool=variable_pool)
def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
"""
Get next iteration start node id based on the graph.
:param graph: graph
:return: next node id
"""
return self._get_next_iteration(variable_pool, state)
@abstractmethod
def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
"""
Get next iteration start node id based on the graph.
:param graph: graph
:return: next node id
"""
raise NotImplementedError
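Worth spelling out: under the new contract, _run may return a plain NodeRunResult or yield events, and run() normalizes both shapes into an event stream. A hedged sketch of a custom node against that contract (EchoNodeData and EchoNode are illustrative, not part of the commit; imports follow the ones at the top of this file):

from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from models import WorkflowNodeExecutionStatus

class EchoNodeData(BaseNodeData):
    text: str

class EchoNode(BaseNode):
    _node_data_cls = EchoNodeData
    _node_type = NodeType.ANSWER  # any registered node type works for the sketch

    def _run(self) -> NodeRunResult:
        # Returning a NodeRunResult is enough; run() wraps it
        # into a RunCompletedEvent for the engine.
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={'text': self.node_data.text},
        )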

View File

@ -1,4 +1,5 @@
from typing import Optional, Union, cast
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union, cast
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
@ -6,7 +7,6 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.code.entities import CodeNodeData
from models.workflow import WorkflowNodeExecutionStatus
@ -33,13 +33,13 @@ class CodeNode(BaseNode):
return code_provider.get_default_config()
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run code
:param variable_pool: variable pool
:return:
"""
node_data = cast(CodeNodeData, self.node_data)
node_data = self.node_data
node_data = cast(CodeNodeData, node_data)
# Get code language
code_language = node_data.code_language
@ -49,7 +49,7 @@ class CodeNode(BaseNode):
variables = {}
for variable_selector in node_data.variables:
variable = variable_selector.variable
value = variable_pool.get_any(variable_selector.value_selector)
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
variables[variable] = value
# Run code
@ -311,13 +311,19 @@ class CodeNode(BaseNode):
return transformed_result
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: CodeNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: CodeNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {
variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
}

View File

@ -1,8 +1,7 @@
from typing import cast
from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.end.entities import EndNodeData
from models.workflow import WorkflowNodeExecutionStatus
@ -12,10 +11,9 @@ class EndNode(BaseNode):
_node_data_cls = EndNodeData
_node_type = NodeType.END
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:return:
"""
node_data = self.node_data
@ -24,7 +22,7 @@ class EndNode(BaseNode):
outputs = {}
for variable_selector in output_variables:
value = variable_pool.get_any(variable_selector.value_selector)
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
outputs[variable_selector.variable] = value
return NodeRunResult(
@ -34,52 +32,16 @@ class EndNode(BaseNode):
)
@classmethod
def extract_generate_nodes(cls, graph: dict, config: dict) -> list[str]:
"""
Extract generate nodes
:param graph: graph
:param config: node config
:return:
"""
node_data = cls._node_data_cls(**config.get("data", {}))
node_data = cast(EndNodeData, node_data)
return cls.extract_generate_nodes_from_node_data(graph, node_data)
@classmethod
def extract_generate_nodes_from_node_data(cls, graph: dict, node_data: EndNodeData) -> list[str]:
"""
Extract generate nodes from node data
:param graph: graph
:param node_data: node data object
:return:
"""
nodes = graph.get('nodes', [])
node_mapping = {node.get('id'): node for node in nodes}
variable_selectors = node_data.outputs
generate_nodes = []
for variable_selector in variable_selectors:
if not variable_selector.value_selector:
continue
node_id = variable_selector.value_selector[0]
if node_id != 'sys' and node_id in node_mapping:
node = node_mapping[node_id]
node_type = node.get('data', {}).get('type')
if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text':
generate_nodes.append(node_id)
# remove duplicates
generate_nodes = list(set(generate_nodes))
return generate_nodes
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: EndNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""

View File

@ -0,0 +1,148 @@
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam
class EndStreamGeneratorRouter:
@classmethod
def init(cls,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_parallel_mapping: dict[str, str]
) -> EndStreamParam:
"""
Get stream generate routes.
:return:
"""
# parse stream output node value selector of end nodes
end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {}
for end_node_id, node_config in node_id_config_mapping.items():
if not node_config.get('data', {}).get('type') == NodeType.END.value:
continue
# skip end node in parallel
if end_node_id in node_parallel_mapping:
continue
# get generate route for stream output
stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config)
end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors
# fetch end dependencies
end_node_ids = list(end_stream_variable_selectors_mapping.keys())
end_dependencies = cls._fetch_ends_dependencies(
end_node_ids=end_node_ids,
reverse_edge_mapping=reverse_edge_mapping,
node_id_config_mapping=node_id_config_mapping
)
return EndStreamParam(
end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping,
end_dependencies=end_dependencies
)
@classmethod
def extract_stream_variable_selector_from_node_data(cls,
node_id_config_mapping: dict[str, dict],
node_data: EndNodeData) -> list[list[str]]:
"""
Extract stream variable selector from node data
:param node_id_config_mapping: node id config mapping
:param node_data: node data object
:return:
"""
variable_selectors = node_data.outputs
value_selectors = []
for variable_selector in variable_selectors:
if not variable_selector.value_selector:
continue
node_id = variable_selector.value_selector[0]
if node_id != 'sys' and node_id in node_id_config_mapping:
node = node_id_config_mapping[node_id]
node_type = node.get('data', {}).get('type')
if (
variable_selector.value_selector not in value_selectors
and node_type == NodeType.LLM.value
and variable_selector.value_selector[1] == 'text'
):
value_selectors.append(variable_selector.value_selector)
return value_selectors
@classmethod
def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dict], config: dict) \
-> list[list[str]]:
"""
Extract stream variable selector from node config
:param node_id_config_mapping: node id config mapping
:param config: node config
:return:
"""
node_data = EndNodeData(**config.get("data", {}))
return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data)
@classmethod
def _fetch_ends_dependencies(cls,
end_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict]
) -> dict[str, list[str]]:
"""
Fetch end dependencies
:param end_node_ids: end node ids
:param reverse_edge_mapping: reverse edge mapping
:param node_id_config_mapping: node id config mapping
:return:
"""
end_dependencies: dict[str, list[str]] = {}
for end_node_id in end_node_ids:
if end_dependencies.get(end_node_id) is None:
end_dependencies[end_node_id] = []
cls._recursive_fetch_end_dependencies(
current_node_id=end_node_id,
end_node_id=end_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
end_dependencies=end_dependencies
)
return end_dependencies
@classmethod
def _recursive_fetch_end_dependencies(cls,
current_node_id: str,
end_node_id: str,
node_id_config_mapping: dict[str, dict],
reverse_edge_mapping: dict[str, list["GraphEdge"]],
# type: ignore[name-defined]
end_dependencies: dict[str, list[str]]
) -> None:
"""
Recursive fetch end dependencies
:param current_node_id: current node id
:param end_node_id: end node id
:param node_id_config_mapping: node id config mapping
:param reverse_edge_mapping: reverse edge mapping
:param end_dependencies: end dependencies
:return:
"""
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
for edge in reverse_edges:
source_node_id = edge.source_node_id
source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
if source_node_type in (
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value,
):
end_dependencies[end_node_id].append(source_node_id)
else:
cls._recursive_fetch_end_dependencies(
current_node_id=source_node_id,
end_node_id=end_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
end_dependencies=end_dependencies
)

View File

@ -0,0 +1,191 @@
import logging
from collections.abc import Generator
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
logger = logging.getLogger(__name__)
class EndStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
super().__init__(graph, variable_pool)
self.end_stream_param = graph.end_stream_param
self.route_position = {}
for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
self.route_position[end_node_id] = 0
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
self.has_outputed = False
self.outputed_node_ids = set()
def process(self,
generator: Generator[GraphEngineEvent, None, None]
) -> Generator[GraphEngineEvent, None, None]:
for event in generator:
if isinstance(event, NodeRunStartedEvent):
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
self.reset()
yield event
elif isinstance(event, NodeRunStreamChunkEvent):
if event.in_iteration_id:
if self.has_outputed and event.node_id not in self.outputed_node_ids:
event.chunk_content = '\n' + event.chunk_content
self.outputed_node_ids.add(event.node_id)
self.has_outputed = True
yield event
continue
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
]
else:
stream_out_end_node_ids = self._get_stream_out_end_node_ids(event)
self.current_stream_chunk_generating_node_ids[
event.route_node_state.node_id
] = stream_out_end_node_ids
if stream_out_end_node_ids:
if self.has_outputed and event.node_id not in self.outputed_node_ids:
event.chunk_content = '\n' + event.chunk_content
self.outputed_node_ids.add(event.node_id)
self.has_outputed = True
yield event
elif isinstance(event, NodeRunSucceededEvent):
yield event
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
# update self.route_position after all stream events have finished
for end_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
self.route_position[end_node_id] += 1
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
# remove unreachable nodes
self._remove_unreachable_nodes(event)
# generate stream outputs
yield from self._generate_stream_outputs_when_node_finished(event)
else:
yield event
def reset(self) -> None:
self.route_position = {}
for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
self.route_position[end_node_id] = 0
self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {}
def _generate_stream_outputs_when_node_finished(self,
event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]:
"""
Generate stream outputs.
:param event: node run succeeded event
:return:
"""
for end_node_id, position in self.route_position.items():
# skip unless this end node is still pending and all of its dependencies have finished
if (event.route_node_state.node_id != end_node_id
and (end_node_id not in self.rest_node_ids
or not all(dep_id not in self.rest_node_ids
for dep_id in self.end_stream_param.end_dependencies[end_node_id]))):
continue
route_position = self.route_position[end_node_id]
position = 0
value_selectors = []
for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
if position >= route_position:
value_selectors.append(current_value_selectors)
position += 1
for value_selector in value_selectors:
if not value_selector:
continue
value = self.variable_pool.get(
value_selector
)
if value is None:
break
text = value.markdown
if text:
current_node_id = value_selector[0]
if self.has_outputed and current_node_id not in self.outputed_node_ids:
text = '\n' + text
self.outputed_node_ids.add(current_node_id)
self.has_outputed = True
yield NodeRunStreamChunkEvent(
id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
chunk_content=text,
from_variable_selector=value_selector,
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
)
self.route_position[end_node_id] += 1
def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
"""
Get the end node ids this stream chunk should be streamed out for
:param event: node run stream chunk event
:return:
"""
if not event.from_variable_selector:
return []
stream_output_value_selector = event.from_variable_selector
if not stream_output_value_selector:
return []
stream_out_end_node_ids = []
for end_node_id, route_position in self.route_position.items():
if end_node_id not in self.rest_node_ids:
continue
# all of this end node's dependencies have finished (none left in rest_node_ids)
if all(dep_id not in self.rest_node_ids
for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]):
continue
position = 0
value_selector = None
for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
if position == route_position:
value_selector = current_value_selectors
break
position += 1
if not value_selector:
continue
# the chunk must come from the selector at this end node's current route position
if value_selector != stream_output_value_selector:
continue
stream_out_end_node_ids.append(end_node_id)
return stream_out_end_node_ids
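The has_outputed / outputed_node_ids bookkeeping above, isolated into a standalone sketch: once anything has been streamed, the first chunk from each new node gets a leading '\n' so interleaved parallel branches do not run together on one line.

from collections.abc import Iterable, Iterator

def separate_chunks(chunks: Iterable[tuple[str, str]]) -> Iterator[str]:
    """chunks: (node_id, text) pairs in arrival order."""
    has_output = False
    seen_node_ids: set[str] = set()
    for node_id, text in chunks:
        if has_output and node_id not in seen_node_ids:
            text = '\n' + text  # first chunk from a new node starts on its own line
        seen_node_ids.add(node_id)
        has_output = True
        yield text

assert list(separate_chunks([("a", "x"), ("a", "y"), ("b", "z")])) == ["x", "y", "\nz"]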

View File

@ -1,3 +1,5 @@
from pydantic import BaseModel, Field
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector
@ -7,3 +9,17 @@ class EndNodeData(BaseNodeData):
END Node Data.
"""
outputs: list[VariableSelector]
class EndStreamParam(BaseModel):
"""
EndStreamParam entity
"""
end_dependencies: dict[str, list[str]] = Field(
...,
description="end dependencies (end node id -> dependent node ids)"
)
end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field(
...,
description="end stream variable selector mapping (end node id -> stream variable selectors)"
)
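Constructing the entity above for a graph where one End node streams an LLM's text output and sits behind an if-else branch; the node ids are illustrative, and the import assumes this file is core/workflow/nodes/end/entities.py.

from core.workflow.nodes.end.entities import EndStreamParam

param = EndStreamParam(
    end_dependencies={"end-1": ["if-else-1"]},
    end_stream_variable_selector_mapping={"end-1": [["llm-1", "text"]]},
)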

View File

@ -0,0 +1,20 @@
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult
class RunCompletedEvent(BaseModel):
run_result: NodeRunResult = Field(..., description="run result")
class RunStreamChunkEvent(BaseModel):
chunk_content: str = Field(..., description="chunk content")
from_variable_selector: list[str] = Field(..., description="from variable selector")
class RunRetrieverResourceEvent(BaseModel):
retriever_resources: list[dict] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent
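A sketch of dispatching on the union above, using the classes just defined; the print handling is illustrative, and RunCompletedEvent construction is left to the node implementations shown later in this commit.

def handle(event: RunEvent) -> None:
    if isinstance(event, RunStreamChunkEvent):
        print(event.chunk_content, end="")  # forward streamed text as it arrives
    elif isinstance(event, RunRetrieverResourceEvent):
        print(f"context: {event.context}")  # retrieved context for the prompt
    elif isinstance(event, RunCompletedEvent):
        print(f"finished with status: {event.run_result.status}")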

View File

@ -1,15 +1,14 @@
import logging
from collections.abc import Mapping, Sequence
from mimetypes import guess_extension
from os import path
from typing import cast
from typing import Any, cast
from configs import dify_config
from core.app.segments import parser
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.http_request.entities import (
HttpRequestNodeData,
@ -48,17 +47,22 @@ class HttpRequestNode(BaseNode):
},
}
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data)
# TODO: Switch to use segment directly
if node_data.authorization.config and node_data.authorization.config.api_key:
node_data.authorization.config.api_key = parser.convert_template(template=node_data.authorization.config.api_key, variable_pool=variable_pool).text
node_data.authorization.config.api_key = parser.convert_template(
template=node_data.authorization.config.api_key,
variable_pool=self.graph_runtime_state.variable_pool
).text
# init http executor
http_executor = None
try:
http_executor = HttpExecutor(
node_data=node_data, timeout=self._get_request_timeout(node_data), variable_pool=variable_pool
node_data=node_data,
timeout=self._get_request_timeout(node_data),
variable_pool=self.graph_runtime_state.variable_pool
)
# invoke http executor
@ -102,13 +106,19 @@ class HttpRequestNode(BaseNode):
return timeout
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: HttpRequestNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
node_data = cast(HttpRequestNodeData, node_data)
try:
http_executor = HttpExecutor(node_data=node_data, timeout=HTTP_REQUEST_DEFAULT_TIMEOUT)
@ -116,7 +126,7 @@ class HttpRequestNode(BaseNode):
variable_mapping = {}
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping[node_id + '.' + variable_selector.variable] = variable_selector.value_selector
return variable_mapping
except Exception as e:
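The mapping-key change above recurs across the nodes in this commit: selectors are now keyed by '<node_id>.<variable>' instead of the bare variable name, so two instances of the same node type running in parallel cannot collide. The pattern in isolation (variable name illustrative):

def prefix_mapping(node_id: str, mapping: dict[str, list[str]]) -> dict[str, list[str]]:
    return {f'{node_id}.{key}': value for key, value in mapping.items()}

assert prefix_mapping('http-1', {'url_host': ['start', 'host']}) == {
    'http-1.url_host': ['start', 'host']
}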

View File

@ -3,20 +3,7 @@ from typing import Literal, Optional
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
class Condition(BaseModel):
"""
Condition entity
"""
variable_selector: list[str]
comparison_operator: Literal[
# for string or array
"contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "regex match",
# for number
"=", "", ">", "<", "", "", "null", "not null"
]
value: Optional[str] = None
from core.workflow.utils.condition.entities import Condition
class IfElseNodeData(BaseNodeData):

View File

@ -1,13 +1,10 @@
import re
from collections.abc import Sequence
from typing import Optional, cast
from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.if_else.entities import Condition, IfElseNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.utils.condition.processor import ConditionProcessor
from models.workflow import WorkflowNodeExecutionStatus
@ -15,31 +12,35 @@ class IfElseNode(BaseNode):
_node_data_cls = IfElseNodeData
_node_type = NodeType.IF_ELSE
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:return:
"""
node_data = self.node_data
node_data = cast(IfElseNodeData, node_data)
node_inputs = {
node_inputs: dict[str, list] = {
"conditions": []
}
process_datas = {
process_datas: dict[str, list] = {
"condition_results": []
}
input_conditions = []
final_result = False
selected_case_id = None
condition_processor = ConditionProcessor()
try:
# Check if the new cases structure is used
if node_data.cases:
for case in node_data.cases:
input_conditions, group_result = self.process_conditions(variable_pool, case.conditions)
input_conditions, group_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=case.conditions
)
# Apply the logical operator for the current case
final_result = all(group_result) if case.logical_operator == "and" else any(group_result)
@ -58,7 +59,10 @@ class IfElseNode(BaseNode):
else:
# Fallback to old structure if cases are not defined
input_conditions, group_result = self.process_conditions(variable_pool, node_data.conditions)
input_conditions, group_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=node_data.conditions
)
final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result)
@ -94,376 +98,17 @@ class IfElseNode(BaseNode):
return data
def evaluate_condition(
self, actual_value: Optional[str | list], expected_value: str, comparison_operator: str
) -> bool:
"""
Evaluate condition
:param actual_value: actual value
:param expected_value: expected value
:param comparison_operator: comparison operator
:return: bool
"""
if comparison_operator == "contains":
return self._assert_contains(actual_value, expected_value)
elif comparison_operator == "not contains":
return self._assert_not_contains(actual_value, expected_value)
elif comparison_operator == "start with":
return self._assert_start_with(actual_value, expected_value)
elif comparison_operator == "end with":
return self._assert_end_with(actual_value, expected_value)
elif comparison_operator == "is":
return self._assert_is(actual_value, expected_value)
elif comparison_operator == "is not":
return self._assert_is_not(actual_value, expected_value)
elif comparison_operator == "empty":
return self._assert_empty(actual_value)
elif comparison_operator == "not empty":
return self._assert_not_empty(actual_value)
elif comparison_operator == "=":
return self._assert_equal(actual_value, expected_value)
elif comparison_operator == "":
return self._assert_not_equal(actual_value, expected_value)
elif comparison_operator == ">":
return self._assert_greater_than(actual_value, expected_value)
elif comparison_operator == "<":
return self._assert_less_than(actual_value, expected_value)
elif comparison_operator == "":
return self._assert_greater_than_or_equal(actual_value, expected_value)
elif comparison_operator == "":
return self._assert_less_than_or_equal(actual_value, expected_value)
elif comparison_operator == "null":
return self._assert_null(actual_value)
elif comparison_operator == "not null":
return self._assert_not_null(actual_value)
elif comparison_operator == "regex match":
return self._assert_regex_match(actual_value, expected_value)
else:
raise ValueError(f"Invalid comparison operator: {comparison_operator}")
def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]):
input_conditions = []
group_result = []
for condition in conditions:
actual_variable = variable_pool.get_any(condition.variable_selector)
if condition.value is not None:
variable_template_parser = VariableTemplateParser(template=condition.value)
expected_value = variable_template_parser.extract_variable_selectors()
variable_selectors = variable_template_parser.extract_variable_selectors()
if variable_selectors:
for variable_selector in variable_selectors:
value = variable_pool.get_any(variable_selector.value_selector)
expected_value = variable_template_parser.format({variable_selector.variable: value})
else:
expected_value = condition.value
else:
expected_value = None
comparison_operator = condition.comparison_operator
input_conditions.append(
{
"actual_value": actual_variable,
"expected_value": expected_value,
"comparison_operator": comparison_operator
}
)
result = self.evaluate_condition(actual_variable, expected_value, comparison_operator)
group_result.append(result)
return input_conditions, group_result
def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
"""
Assert contains
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str | list):
raise ValueError('Invalid actual value type: string or array')
if expected_value not in actual_value:
return False
return True
def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
"""
Assert not contains
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return True
if not isinstance(actual_value, str | list):
raise ValueError('Invalid actual value type: string or array')
if expected_value in actual_value:
return False
return True
def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert start with
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if not actual_value.startswith(expected_value):
return False
return True
def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert end with
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if not actual_value.endswith(expected_value):
return False
return True
def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert is
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if actual_value != expected_value:
return False
return True
def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert is not
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if actual_value == expected_value:
return False
return True
def _assert_empty(self, actual_value: Optional[str]) -> bool:
"""
Assert empty
:param actual_value: actual value
:return:
"""
if not actual_value:
return True
return False
def _assert_regex_match(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert regex match
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
return re.search(expected_value, actual_value) is not None
def _assert_not_empty(self, actual_value: Optional[str]) -> bool:
"""
Assert not empty
:param actual_value: actual value
:return:
"""
if actual_value:
return True
return False
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value != expected_value:
return False
return True
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert not equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value == expected_value:
return False
return True
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert greater than
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value <= expected_value:
return False
return True
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert less than
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value >= expected_value:
return False
return True
def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert greater than or equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value < expected_value:
return False
return True
def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
"""
Assert less than or equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value > expected_value:
return False
return True
def _assert_null(self, actual_value: Optional[int | float]) -> bool:
"""
Assert null
:param actual_value: actual value
:return:
"""
if actual_value is None:
return True
return False
def _assert_not_null(self, actual_value: Optional[int | float]) -> bool:
"""
Assert not null
:param actual_value: actual value
:return:
"""
if actual_value is not None:
return True
return False
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IfElseNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""

View File

@ -1,6 +1,6 @@
from typing import Any, Optional
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
class IterationNodeData(BaseIterationNodeData):
@ -11,6 +11,13 @@ class IterationNodeData(BaseIterationNodeData):
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
class IterationStartNodeData(BaseNodeData):
"""
Iteration Start Node Data.
"""
pass
class IterationState(BaseIterationState):
"""
Iteration State.

View File

@ -1,124 +1,371 @@
from typing import cast
import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime, timezone
from typing import Any, cast
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseIterationState
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseIterationNode
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
BaseParallelBranchEvent,
GraphRunFailedEvent,
InNodeEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.utils.condition.entities import Condition
from models.workflow import WorkflowNodeExecutionStatus
logger = logging.getLogger(__name__)
class IterationNode(BaseIterationNode):
class IterationNode(BaseNode):
"""
Iteration Node.
"""
_node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
"""
Run the node.
"""
self.node_data = cast(IterationNodeData, self.node_data)
iterator = variable_pool.get_any(self.node_data.iterator_selector)
iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
if not isinstance(iterator, list):
raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.")
state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={
'iterator_selector': iterator
}, outputs=[], metadata=IterationState.MetaData(
iterator_length=len(iterator) if iterator is not None else 0
))
if not iterator_list_segment:
raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found")
self._set_current_iteration_variable(variable_pool, state)
return state
iterator_list_value = iterator_list_segment.to_object()
def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str:
"""
Get next iteration start node id based on the graph.
:param graph: graph
:return: next node id
"""
# resolve current output
self._resolve_current_output(variable_pool, state)
# move to next iteration
self._next_iteration(variable_pool, state)
if not isinstance(iterator_list_value, list):
raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
node_data = cast(IterationNodeData, self.node_data)
if self._reached_iteration_limit(variable_pool, state):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs = {
"iterator_selector": iterator_list_value
}
graph_config = self.graph_config
if not self.node_data.start_node_id:
raise ValueError(f'field start_node_id in iteration {self.node_id} not found')
root_node_id = self.node_data.start_node_id
# init graph
iteration_graph = Graph.init(
graph_config=graph_config,
root_node_id=root_node_id
)
if not iteration_graph:
raise ValueError('iteration graph not found')
leaf_node_ids = iteration_graph.get_leaf_node_ids()
iteration_leaf_node_ids = []
for leaf_node_id in leaf_node_ids:
node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id)
if not node_config:
continue
leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id")
if not leaf_node_iteration_id:
continue
if leaf_node_iteration_id != self.node_id:
continue
iteration_leaf_node_ids.append(leaf_node_id)
# add a conditional back-edge from each leaf node to the root node
iteration_graph.add_extra_edge(
source_node_id=leaf_node_id,
target_node_id=root_node_id,
run_condition=RunCondition(
type="condition",
conditions=[
Condition(
variable_selector=[self.node_id, "index"],
comparison_operator="<",
value=str(len(iterator_list_value))
)
]
)
)
variable_pool = self.graph_runtime_state.variable_pool
# append iteration variable (item, index) to variable pool
variable_pool.add(
[self.node_id, 'index'],
0
)
variable_pool.add(
[self.node_id, 'item'],
iterator_list_value[0]
)
# init graph engine
from core.workflow.graph_engine.graph_engine import GraphEngine
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
workflow_type=self.workflow_type,
workflow_id=self.workflow_id,
user_id=self.user_id,
user_from=self.user_from,
invoke_from=self.invoke_from,
call_depth=self.workflow_call_depth,
graph=iteration_graph,
graph_config=graph_config,
variable_pool=variable_pool,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
)
start_at = datetime.now(timezone.utc).replace(tzinfo=None)
yield IterationRunStartedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
metadata={
"iterator_length": len(iterator_list_value)
},
predecessor_node_id=self.previous_node_id
)
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=0,
pre_iteration_output=None
)
outputs: list[Any] = []
try:
# run workflow
rst = graph_engine.run()
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
if isinstance(event, BaseNodeEvent) and event.node_type == NodeType.ITERATION_START:
continue
if isinstance(event, NodeRunSucceededEvent):
if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
event.route_node_state.node_run_result.metadata = metadata
yield event
# handle iteration run result
if event.route_node_state.node_id in iteration_leaf_node_ids:
# append to iteration output variable list
current_iteration_output = variable_pool.get_any(self.node_data.output_selector)
outputs.append(current_iteration_output)
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
variable_pool.remove_node(node_id)
# move to next iteration
current_index = variable_pool.get([self.node_id, 'index'])
if current_index is None:
raise ValueError(f'iteration {self.node_id} current index not found')
next_index = int(current_index.to_object()) + 1
variable_pool.add(
[self.node_id, 'index'],
next_index
)
if next_index < len(iterator_list_value):
variable_pool.add(
[self.node_id, 'item'],
iterator_list_value[next_index]
)
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
pre_iteration_output=jsonable_encoder(
current_iteration_output) if current_iteration_output else None
)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
)
break
else:
event = cast(InNodeEvent, event)
yield event
yield IterationRunSucceededEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={
'output': jsonable_encoder(state.outputs)
"output": jsonable_encoder(outputs)
},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
}
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
'output': jsonable_encoder(outputs)
}
)
)
except Exception as e:
# iteration run failed
logger.exception("Iteration run failed")
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={
"output": jsonable_encoder(outputs)
},
steps=len(iterator_list_value),
metadata={
"total_tokens": graph_engine.graph_runtime_state.total_tokens
},
error=str(e),
)
return node_data.start_node_id
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
)
finally:
# remove iteration variable (item, index) from variable pool after iteration run completed
variable_pool.remove([self.node_id, 'index'])
variable_pool.remove([self.node_id, 'item'])
def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState):
"""
Set current iteration variable.
:param variable_pool: variable pool
"""
node_data = cast(IterationNodeData, self.node_data)
variable_pool.add((self.node_id, 'index'), state.index)
# get the iterator value
iterator = variable_pool.get_any(node_data.iterator_selector)
if iterator is None or not isinstance(iterator, list):
return
if state.index < len(iterator):
variable_pool.add((self.node_id, 'item'), iterator[state.index])
def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
"""
Move to next iteration.
:param variable_pool: variable pool
"""
state.index += 1
self._set_current_iteration_variable(variable_pool, state)
def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState):
"""
Check if iteration limit is reached.
:return: True if iteration limit is reached, False otherwise
"""
node_data = cast(IterationNodeData, self.node_data)
iterator = variable_pool.get_any(node_data.iterator_selector)
if iterator is None or not isinstance(iterator, list):
return True
return state.index >= len(iterator)
def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState):
"""
Resolve current output.
:param variable_pool: variable pool
"""
output_selector = cast(IterationNodeData, self.node_data).output_selector
output = variable_pool.get_any(output_selector)
# clear the output for this iteration
variable_pool.remove([self.node_id] + output_selector[1:])
state.current_output = output
if output is not None:
# NOTE: This is a temporary patch to process double nested list (for example, DALL-E output in iteration).
if isinstance(output, list):
state.outputs.extend(output)
else:
state.outputs.append(output)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IterationNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {
'input_selector': node_data.iterator_selector,
}
variable_mapping = {
f'{node_id}.input_selector': node_data.iterator_selector,
}
# init graph
iteration_graph = Graph.init(
graph_config=graph_config,
root_node_id=node_data.start_node_id
)
if not iteration_graph:
raise ValueError('iteration graph not found')
for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items():
if sub_node_config.get('data', {}).get('iteration_id') != node_id:
continue
# variable selector to variable mapping
try:
# Get node class
from core.workflow.nodes.node_mapping import node_classes
node_type = NodeType.value_of(sub_node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
if not node_cls:
continue
node_cls = cast(BaseNode, node_cls)
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config,
config=sub_node_config
)
sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping)
except NotImplementedError:
sub_node_variable_mapping = {}
# remove iteration variables
sub_node_variable_mapping = {
sub_node_id + '.' + key: value for key, value in sub_node_variable_mapping.items()
if value[0] != node_id
}
variable_mapping.update(sub_node_variable_mapping)
# keep only mappings that reference variables from outside the iteration
variable_mapping = {
key: value for key, value in variable_mapping.items()
if value[0] not in iteration_graph.node_ids
}
return variable_mapping
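The control-flow idea behind the rewritten iteration node, reduced to a toy simulation: 'index' and 'item' live in the variable pool under the iteration node's id, and the conditional back-edge (index < len(items)) re-enters the sub-graph instead of an explicit Python loop inside the node. A sketch, with run_subgraph standing in for the embedded graph engine:

from collections.abc import Callable
from typing import Any

def run_iteration(items: list[Any], run_subgraph: Callable[[Any], Any]) -> list[Any]:
    pool = {"index": 0, "item": None}
    outputs: list[Any] = []
    while pool["index"] < len(items):  # the back-edge's run condition
        pool["item"] = items[pool["index"]]
        outputs.append(run_subgraph(pool["item"]))
        pool["index"] += 1  # advanced after each pass, as the node does above
    return outputs

assert run_iteration([1, 2, 3], lambda item: item * 2) == [2, 4, 6]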

View File

@ -0,0 +1,39 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData
from models.workflow import WorkflowNodeExecutionStatus
class IterationStartNode(BaseNode):
"""
Iteration Start Node.
"""
_node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START
def _run(self) -> NodeRunResult:
"""
Run the node.
"""
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IterationNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {}

View File

@ -1,3 +1,5 @@
import logging
from collections.abc import Mapping, Sequence
from typing import Any, cast
from sqlalchemy import func
@ -12,15 +14,15 @@ from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from models.workflow import WorkflowNodeExecutionStatus
logger = logging.getLogger(__name__)
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
@ -37,11 +39,11 @@ class KnowledgeRetrievalNode(BaseNode):
_node_data_cls = KnowledgeRetrievalNodeData
node_type = NodeType.KNOWLEDGE_RETRIEVAL
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data)
def _run(self) -> NodeRunResult:
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
# extract variables
variable = variable_pool.get_any(node_data.query_variable_selector)
variable = self.graph_runtime_state.variable_pool.get_any(node_data.query_variable_selector)
query = variable
variables = {
'query': query
@ -68,7 +70,7 @@ class KnowledgeRetrievalNode(BaseNode):
)
except Exception as e:
logger.exception("Error when running knowledge retrieval node")
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
@ -235,11 +237,21 @@ class KnowledgeRetrievalNode(BaseNode):
return context_list
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
node_data = node_data
node_data = cast(cls._node_data_cls, node_data)
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: KnowledgeRetrievalNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
variable_mapping = {}
variable_mapping['query'] = node_data.query_variable_selector
variable_mapping[node_id + '.query'] = node_data.query_variable_selector
return variable_mapping
def _fetch_model_config(self, node_data: KnowledgeRetrievalNodeData) -> tuple[

View File

@ -1,16 +1,17 @@
import json
from collections.abc import Generator
from collections.abc import Generator, Mapping, Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from pydantic import BaseModel
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
@ -25,7 +26,9 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
@ -43,17 +46,26 @@ if TYPE_CHECKING:
class ModelInvokeCompleted(BaseModel):
"""
Model invoke completed
"""
text: str
usage: LLMUsage
finish_reason: Optional[str] = None
class LLMNode(BaseNode):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
"""
Run node
:param variable_pool: variable pool
:return:
"""
node_data = cast(LLMNodeData, deepcopy(self.node_data))
variable_pool = self.graph_runtime_state.variable_pool
node_inputs = None
process_data = None
@ -80,10 +92,15 @@ class LLMNode(BaseNode):
node_inputs['#files#'] = [file.to_dict() for file in files]
# fetch context value
context = self._fetch_context(node_data, variable_pool)
generator = self._fetch_context(node_data, variable_pool)
context = None
for event in generator:
if isinstance(event, RunRetrieverResourceEvent):
context = event.context
yield event
if context:
node_inputs['#context#'] = context
node_inputs['#context#'] = context # type: ignore
# fetch model config
model_instance, model_config = self._fetch_model_config(node_data.model)
@ -115,19 +132,34 @@ class LLMNode(BaseNode):
}
# handle invoke result
result_text, usage, finish_reason = self._invoke_llm(
generator = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop
)
result_text = ''
usage = LLMUsage.empty_usage()
finish_reason = None
for event in generator:
if isinstance(event, RunStreamChunkEvent):
yield event
elif isinstance(event, ModelInvokeCompleted):
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
break
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=node_inputs,
process_data=process_data
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=node_inputs,
process_data=process_data
)
)
return
outputs = {
'text': result_text,
@ -135,22 +167,26 @@ class LLMNode(BaseNode):
'finish_reason': finish_reason
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=node_inputs,
process_data=process_data,
outputs=outputs,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency
}
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=node_inputs,
process_data=process_data,
outputs=outputs,
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency
},
llm_usage=usage
)
)
def _invoke_llm(self, node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: list[str]) -> tuple[str, LLMUsage]:
stop: Optional[list[str]] = None) \
-> Generator[RunEvent | ModelInvokeCompleted, None, None]:
"""
Invoke large language model
:param node_data_model: node data model
@ -170,23 +206,31 @@ class LLMNode(BaseNode):
)
# handle invoke result
text, usage, finish_reason = self._handle_invoke_result(
generator = self._handle_invoke_result(
invoke_result=invoke_result
)
usage = LLMUsage.empty_usage()
for event in generator:
yield event
if isinstance(event, ModelInvokeCompleted):
usage = event.usage
# deduct quota
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
return text, usage, finish_reason
def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
-> Generator[RunEvent | ModelInvokeCompleted, None, None]:
"""
Handle invoke result
:param invoke_result: invoke result
:return:
"""
if isinstance(invoke_result, LLMResult):
return
model = None
prompt_messages = []
prompt_messages: list[PromptMessage] = []
full_text = ''
usage = None
finish_reason = None
@ -194,7 +238,10 @@ class LLMNode(BaseNode):
text = result.delta.message.content
full_text += text
self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
yield RunStreamChunkEvent(
chunk_content=text,
from_variable_selector=[self.node_id, 'text']
)
if not model:
model = result.model
@ -211,11 +258,15 @@ class LLMNode(BaseNode):
if not usage:
usage = LLMUsage.empty_usage()
return full_text, usage, finish_reason
yield ModelInvokeCompleted(
text=full_text,
usage=usage,
finish_reason=finish_reason
)
def _transform_chat_messages(self,
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
"""
Transform chat messages
@ -224,13 +275,13 @@ class LLMNode(BaseNode):
"""
if isinstance(messages, LLMNodeCompletionModelPromptTemplate):
if messages.edition_type == 'jinja2':
if messages.edition_type == 'jinja2' and messages.jinja2_text:
messages.text = messages.jinja2_text
return messages
for message in messages:
if message.edition_type == 'jinja2':
if message.edition_type == 'jinja2' and message.jinja2_text:
message.text = message.jinja2_text
return messages
@ -348,7 +399,7 @@ class LLMNode(BaseNode):
return files
def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Optional[str]:
def _fetch_context(self, node_data: LLMNodeData, variable_pool: VariablePool) -> Generator[RunEvent, None, None]:
"""
Fetch context
:param node_data: node data
@ -356,15 +407,18 @@ class LLMNode(BaseNode):
:return:
"""
if not node_data.context.enabled:
return None
return
if not node_data.context.variable_selector:
return None
return
context_value = variable_pool.get_any(node_data.context.variable_selector)
if context_value:
if isinstance(context_value, str):
return context_value
yield RunRetrieverResourceEvent(
retriever_resources=[],
context=context_value
)
elif isinstance(context_value, list):
context_str = ''
original_retriever_resource = []
@ -381,17 +435,10 @@ class LLMNode(BaseNode):
if retriever_resource:
original_retriever_resource.append(retriever_resource)
if self.callbacks and original_retriever_resource:
for callback in self.callbacks:
callback.on_event(
event=QueueRetrieverResourcesEvent(
retriever_resources=original_retriever_resource
)
)
return context_str.strip()
return None
yield RunRetrieverResourceEvent(
retriever_resources=original_retriever_resource,
context=context_str.strip()
)
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]:
"""
@ -574,7 +621,8 @@ class LLMNode(BaseNode):
if not isinstance(prompt_message.content, str):
prompt_message_content = []
for content_item in prompt_message.content:
if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(content_item, ImagePromptMessageContent):
if vision_enabled and content_item.type == PromptMessageContentType.IMAGE and isinstance(
content_item, ImagePromptMessageContent):
# Override vision config if LLM node has vision config
if vision_detail:
content_item.detail = ImagePromptMessageContent.DETAIL(vision_detail)
@ -646,13 +694,19 @@ class LLMNode(BaseNode):
db.session.commit()
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: LLMNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: LLMNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
prompt_template = node_data.prompt_template
variable_selectors = []
@ -702,6 +756,10 @@ class LLMNode(BaseNode):
for variable_selector in node_data.prompt_config.jinja2_variables or []:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {
node_id + '.' + key: value for key, value in variable_mapping.items()
}
return variable_mapping
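A sketch of how a caller can drain a generator-style node run like the one above into a final result; the drive loop itself is illustrative, but the event classes are the ones added in this commit.

from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent

def drain_node_run(events):
    """events: the generator yielded by a node's _run(); returns its NodeRunResult."""
    for event in events:
        if isinstance(event, RunStreamChunkEvent):
            print(event.chunk_content, end='')  # stream text to the client as it arrives
        elif isinstance(event, RunCompletedEvent):
            return event.run_result  # the terminal event carries the node's result
    return None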
@classmethod

View File

@ -1,20 +1,34 @@
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseIterationNode
from typing import Any
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
from core.workflow.utils.condition.entities import Condition
class LoopNode(BaseIterationNode):
class LoopNode(BaseNode):
"""
Loop Node.
"""
_node_data_cls = LoopNodeData
_node_type = NodeType.LOOP
def _run(self, variable_pool: VariablePool) -> LoopState:
return super()._run(variable_pool)
def _run(self) -> LoopState:
return super()._run()
def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str:
@classmethod
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
"""
Get next iteration start node id based on the graph.
Get conditions.
"""
node_id = node_config.get('id')
if not node_id:
return []
# TODO waiting for implementation
return [Condition(
variable_selector=[node_id, 'index'],
comparison_operator="",
value_type="value_selector",
value_selector=[]
)]
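For comparison, the iteration node above builds a concrete run condition of the same shape for its back-edge (node id illustrative):

from core.workflow.utils.condition.entities import Condition

condition = Condition(
    variable_selector=['iteration-1', 'index'],
    comparison_operator='<',
    value='3',  # str(len(iterator_list_value)) in the iteration node
)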

View File

@ -0,0 +1,37 @@
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
from core.workflow.nodes.variable_assigner import VariableAssignerNode
node_classes = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.ANSWER: AnswerNode,
NodeType.LLM: LLMNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.IF_ELSE: IfElseNode,
NodeType.CODE: CodeNode,
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.TOOL: ToolNode,
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: IterationNode,
NodeType.ITERATION_START: IterationStartNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
}
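The registry above is consumed by mapping a config's type string to a class, as the iteration node does when collecting sub-node variable mappings:

from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.node_mapping import node_classes

node_type = NodeType.value_of('llm')  # the node config's data.type field
node_cls = node_classes.get(node_type)  # -> LLMNode, or None for unknown types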

View File

@ -1,6 +1,7 @@
import json
import uuid
from typing import Optional, cast
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@ -66,12 +67,12 @@ class ParameterExtractorNode(LLMNode):
}
}
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run the node.
"""
node_data = cast(ParameterExtractorNodeData, self.node_data)
variable = variable_pool.get_any(node_data.query)
variable = self.graph_runtime_state.variable_pool.get_any(node_data.query)
if not variable:
raise ValueError("Input variable content not found or is empty")
query = variable
@ -92,17 +93,20 @@ class ParameterExtractorNode(LLMNode):
raise ValueError("Model schema not found")
# fetch memory
memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance)
if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \
and node_data.reasoning_mode == 'function_call':
# use function call
prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
node_data, query, variable_pool, model_config, memory
node_data, query, self.graph_runtime_state.variable_pool, model_config, memory
)
else:
# use prompt engineering
prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config,
prompt_messages = self._generate_prompt_engineering_prompt(node_data,
query,
self.graph_runtime_state.variable_pool,
model_config,
memory)
prompt_message_tools = []
@ -172,7 +176,8 @@ class ParameterExtractorNode(LLMNode):
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency
}
},
llm_usage=usage
)
def _invoke_llm(self, node_data_model: ModelConfig,
@ -697,15 +702,19 @@ class ParameterExtractorNode(LLMNode):
return self._model_instance, self._model_config
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[
str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ParameterExtractorNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
node_data = node_data
variable_mapping = {
'query': node_data.query
}
@ -715,4 +724,8 @@ class ParameterExtractorNode(LLMNode):
for selector in variable_template_parser.extract_variable_selectors():
variable_mapping[selector.variable] = selector.value_selector
variable_mapping = {
node_id + '.' + key: value for key, value in variable_mapping.items()
}
return variable_mapping

View File

@ -1,10 +1,12 @@
import json
import logging
from typing import Optional, Union, cast
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.utils.encoders import jsonable_encoder
@ -13,10 +15,9 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.llm.llm_node import LLMNode, ModelInvokeCompleted
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
from core.workflow.nodes.question_classifier.template_prompts import (
QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
@ -36,9 +37,10 @@ class QuestionClassifierNode(LLMNode):
_node_data_cls = QuestionClassifierNodeData
node_type = NodeType.QUESTION_CLASSIFIER
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
node_data: QuestionClassifierNodeData = cast(self._node_data_cls, self.node_data)
node_data = cast(QuestionClassifierNodeData, node_data)
variable_pool = self.graph_runtime_state.variable_pool
# extract variables
variable = variable_pool.get(node_data.query_variable_selector)
@ -63,12 +65,23 @@ class QuestionClassifierNode(LLMNode):
)
# handle invoke result
result_text, usage, finish_reason = self._invoke_llm(
generator = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop
)
result_text = ''
usage = LLMUsage.empty_usage()
finish_reason = None
for event in generator:
if isinstance(event, ModelInvokeCompleted):
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
break
category_name = node_data.classes[0].name
category_id = node_data.classes[0].id
try:
@ -109,7 +122,8 @@ class QuestionClassifierNode(LLMNode):
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency
}
},
llm_usage=usage
)
except ValueError as e:
@ -121,13 +135,24 @@ class QuestionClassifierNode(LLMNode):
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
NodeRunMetadataKey.CURRENCY: usage.currency
}
},
llm_usage=usage
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
node_data = node_data
node_data = cast(cls._node_data_cls, node_data)
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: QuestionClassifierNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
variable_mapping = {'query': node_data.query_variable_selector}
variable_selectors = []
if node_data.instruction:
@ -135,6 +160,11 @@ class QuestionClassifierNode(LLMNode):
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {
node_id + '.' + key: value for key, value in variable_mapping.items()
}
return variable_mapping
@classmethod
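As the hunk above shows, _invoke_llm now returns a generator instead of a (text, usage, finish_reason) tuple, and callers drain it until a ModelInvokeCompleted event arrives. A condensed sketch of that consumption pattern (the generator here stands in for the real model call):

result_text = ''
usage = LLMUsage.empty_usage()
finish_reason = None
for event in generator:
    if isinstance(event, ModelInvokeCompleted):
        # the completed event carries the full text, token usage and finish reason
        result_text, usage, finish_reason = event.text, event.usage, event.finish_reason
        break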
View File
@ -1,7 +1,9 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
from collections.abc import Mapping, Sequence
from typing import Any
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus
@ -11,14 +13,13 @@ class StartNode(BaseNode):
_node_data_cls = StartNodeData
_node_type = NodeType.START
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run node
:param variable_pool: variable pool
:return:
"""
node_inputs = dict(variable_pool.user_inputs)
system_inputs = variable_pool.system_variables
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]
@ -30,9 +31,16 @@ class StartNode(BaseNode):
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: StartNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""

View File

@ -1,15 +1,16 @@
import os
from typing import Optional, cast
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from models.workflow import WorkflowNodeExecutionStatus
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000'))
class TemplateTransformNode(BaseNode):
_node_data_cls = TemplateTransformNodeData
_node_type = NodeType.TEMPLATE_TRANSFORM
@ -34,7 +35,7 @@ class TemplateTransformNode(BaseNode):
}
}
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run node
"""
@ -45,7 +46,7 @@ class TemplateTransformNode(BaseNode):
variables = {}
for variable_selector in node_data.variables:
variable_name = variable_selector.variable
value = variable_pool.get_any(variable_selector.value_selector)
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
variables[variable_name] = value
# Run code
try:
@ -60,7 +61,7 @@ class TemplateTransformNode(BaseNode):
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e)
)
if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
return NodeRunResult(
inputs=variables,
@ -75,14 +76,21 @@ class TemplateTransformNode(BaseNode):
'output': result['result']
}
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: TemplateTransformNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {
variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
}
node_id + '.' + variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables
}
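Every _run in this commit drops its variable_pool parameter and reads from the shared graph runtime state instead. A hypothetical minimal node (the class name and selector are invented for illustration) showing the new shape:

class EchoNode(BaseNode):
    def _run(self) -> NodeRunResult:
        # the pool now travels on GraphRuntimeState rather than being passed in
        text = self.graph_runtime_state.variable_pool.get_any(['start', 'text'])
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs={'text': text},
            outputs={'output': text},
        )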
View File
@ -26,7 +26,7 @@ class ToolNode(BaseNode):
_node_data_cls = ToolNodeData
_node_type = NodeType.TOOL
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
"""
Run the tool node
"""
@ -56,8 +56,8 @@ class ToolNode(BaseNode):
# get parameters
tool_parameters = tool_runtime.get_runtime_parameters() or []
parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data)
parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=variable_pool, node_data=node_data, for_log=True)
parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data)
parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data, for_log=True)
try:
messages = ToolEngine.workflow_invoke(
@ -66,6 +66,7 @@ class ToolNode(BaseNode):
user_id=self.user_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
thread_pool_id=self.thread_pool_id,
)
except Exception as e:
return NodeRunResult(
@ -145,7 +146,8 @@ class ToolNode(BaseNode):
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\
-> tuple[str, list[FileVar], list[dict]]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
@ -221,9 +223,16 @@ class ToolNode(BaseNode):
return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON]
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ToolNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
@ -239,4 +248,8 @@ class ToolNode(BaseNode):
elif input.type == 'constant':
pass
result = {
node_id + '.' + key: value for key, value in result.items()
}
return result
View File
@ -1,8 +1,7 @@
from typing import cast
from collections.abc import Mapping, Sequence
from typing import Any, cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
from models.workflow import WorkflowNodeExecutionStatus
@ -12,7 +11,7 @@ class VariableAggregatorNode(BaseNode):
_node_data_cls = VariableAssignerNodeData
_node_type = NodeType.VARIABLE_AGGREGATOR
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
node_data = cast(VariableAssignerNodeData, self.node_data)
# Get variables
outputs = {}
@ -20,7 +19,7 @@ class VariableAggregatorNode(BaseNode):
if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled:
for selector in node_data.variables:
variable = variable_pool.get_any(selector)
variable = self.graph_runtime_state.variable_pool.get_any(selector)
if variable is not None:
outputs = {
"output": variable
@ -33,7 +32,7 @@ class VariableAggregatorNode(BaseNode):
else:
for group in node_data.advanced_settings.groups:
for selector in group.variables:
variable = variable_pool.get_any(selector)
variable = self.graph_runtime_state.variable_pool.get_any(selector)
if variable is not None:
outputs[group.group_name] = {
@ -49,5 +48,17 @@ class VariableAggregatorNode(BaseNode):
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: VariableAssignerNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
return {}
View File
@ -6,7 +6,6 @@ from sqlalchemy.orm import Session
from core.app.segments import SegmentType, Variable, factory
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from extensions.ext_database import db
from models import ConversationVariable, WorkflowNodeExecutionStatus
@ -19,23 +18,23 @@ class VariableAssignerNode(BaseNode):
_node_data_cls: type[BaseNodeData] = VariableAssignerData
_node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
def _run(self) -> NodeRunResult:
data = cast(VariableAssignerData, self.node_data)
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = variable_pool.get(data.assigned_variable_selector)
original_variable = self.graph_runtime_state.variable_pool.get(data.assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableAssignerNodeError('assigned variable not found')
match data.write_mode:
case WriteMode.OVER_WRITE:
income_value = variable_pool.get(data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_variable = original_variable.model_copy(update={'value': income_value.value})
case WriteMode.APPEND:
income_value = variable_pool.get(data.input_variable_selector)
income_value = self.graph_runtime_state.variable_pool.get(data.input_variable_selector)
if not income_value:
raise VariableAssignerNodeError('input value not found')
updated_value = original_variable.value + [income_value.value]
@ -49,11 +48,11 @@ class VariableAssignerNode(BaseNode):
raise VariableAssignerNodeError(f'unsupported write mode: {data.write_mode}')
# Over write the variable.
variable_pool.add(data.assigned_variable_selector, updated_variable)
self.graph_runtime_state.variable_pool.add(data.assigned_variable_selector, updated_variable)
# TODO: Move database operation to the pipeline.
# Update conversation variable.
conversation_id = variable_pool.get(['sys', 'conversation_id'])
conversation_id = self.graph_runtime_state.variable_pool.get(['sys', 'conversation_id'])
if not conversation_id:
raise VariableAssignerNodeError('conversation_id not found')
update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
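The two write modes above differ only in how the incoming value combines with the stored one; a simplified sketch with plain values:

original_value = ['hello']          # current value of the conversation variable
over_written = 'hi'                 # WriteMode.OVER_WRITE replaces the value outright
appended = original_value + ['hi']  # WriteMode.APPEND yields ['hello', 'hi']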
View File
@ -0,0 +1,17 @@
from typing import Literal, Optional
from pydantic import BaseModel
class Condition(BaseModel):
"""
Condition entity
"""
variable_selector: list[str]
comparison_operator: Literal[
# for string or array
"contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty",
# for number
"=", "", ">", "<", "", "", "null", "not null"
]
value: Optional[str] = None
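A hypothetical instantiation of the entity above (the selector and threshold are invented):

cond = Condition(
    variable_selector=['start', 'age'],
    comparison_operator='≥',
    value='18',
)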
View File
@ -0,0 +1,383 @@
from collections.abc import Sequence
from typing import Any, Optional
from core.file.file_obj import FileVar
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.variable_template_parser import VariableTemplateParser
class ConditionProcessor:
def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]):
input_conditions = []
group_result = []
index = 0
for condition in conditions:
index += 1
actual_value = variable_pool.get_any(
condition.variable_selector
)
expected_value = None
if condition.value is not None:
variable_template_parser = VariableTemplateParser(template=condition.value)
variable_selectors = variable_template_parser.extract_variable_selectors()
if variable_selectors:
for variable_selector in variable_selectors:
value = variable_pool.get_any(
variable_selector.value_selector
)
expected_value = variable_template_parser.format({variable_selector.variable: value})
if expected_value is None:
expected_value = condition.value
else:
expected_value = condition.value
comparison_operator = condition.comparison_operator
input_conditions.append(
{
"actual_value": actual_value,
"expected_value": expected_value,
"comparison_operator": comparison_operator
}
)
result = self.evaluate_condition(actual_value, comparison_operator, expected_value)
group_result.append(result)
return input_conditions, group_result
def evaluate_condition(
self,
actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | FileVar],
comparison_operator: str,
expected_value: Optional[str] = None
) -> bool:
"""
Evaluate condition
:param actual_value: actual value
:param expected_value: expected value
:param comparison_operator: comparison operator
:return: bool
"""
if comparison_operator == "contains":
return self._assert_contains(actual_value, expected_value)
elif comparison_operator == "not contains":
return self._assert_not_contains(actual_value, expected_value)
elif comparison_operator == "start with":
return self._assert_start_with(actual_value, expected_value)
elif comparison_operator == "end with":
return self._assert_end_with(actual_value, expected_value)
elif comparison_operator == "is":
return self._assert_is(actual_value, expected_value)
elif comparison_operator == "is not":
return self._assert_is_not(actual_value, expected_value)
elif comparison_operator == "empty":
return self._assert_empty(actual_value)
elif comparison_operator == "not empty":
return self._assert_not_empty(actual_value)
elif comparison_operator == "=":
return self._assert_equal(actual_value, expected_value)
elif comparison_operator == "":
return self._assert_not_equal(actual_value, expected_value)
elif comparison_operator == ">":
return self._assert_greater_than(actual_value, expected_value)
elif comparison_operator == "<":
return self._assert_less_than(actual_value, expected_value)
elif comparison_operator == "":
return self._assert_greater_than_or_equal(actual_value, expected_value)
elif comparison_operator == "":
return self._assert_less_than_or_equal(actual_value, expected_value)
elif comparison_operator == "null":
return self._assert_null(actual_value)
elif comparison_operator == "not null":
return self._assert_not_null(actual_value)
else:
raise ValueError(f"Invalid comparison operator: {comparison_operator}")
def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
"""
Assert contains
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str | list):
raise ValueError('Invalid actual value type: string or array')
if expected_value not in actual_value:
return False
return True
def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
"""
Assert not contains
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return True
if not isinstance(actual_value, str | list):
raise ValueError('Invalid actual value type: string or array')
if expected_value in actual_value:
return False
return True
def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert start with
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if not actual_value.startswith(expected_value):
return False
return True
def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert end with
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if not actual_value:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if not actual_value.endswith(expected_value):
return False
return True
def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert is
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if actual_value != expected_value:
return False
return True
def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool:
"""
Assert is not
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, str):
raise ValueError('Invalid actual value type: string')
if actual_value == expected_value:
return False
return True
def _assert_empty(self, actual_value: Optional[str]) -> bool:
"""
Assert empty
:param actual_value: actual value
:return:
"""
if not actual_value:
return True
return False
def _assert_not_empty(self, actual_value: Optional[str]) -> bool:
"""
Assert not empty
:param actual_value: actual value
:return:
"""
if actual_value:
return True
return False
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
"""
Assert equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value != expected_value:
return False
return True
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
"""
Assert not equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value == expected_value:
return False
return True
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
"""
Assert greater than
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value <= expected_value:
return False
return True
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
"""
Assert less than
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value >= expected_value:
return False
return True
def _assert_greater_than_or_equal(self, actual_value: Optional[int | float],
expected_value: str | int | float) -> bool:
"""
Assert greater than or equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value < expected_value:
return False
return True
def _assert_less_than_or_equal(self, actual_value: Optional[int | float],
expected_value: str | int | float) -> bool:
"""
Assert less than or equal
:param actual_value: actual value
:param expected_value: expected value
:return:
"""
if actual_value is None:
return False
if not isinstance(actual_value, int | float):
raise ValueError('Invalid actual value type: number')
if isinstance(actual_value, int):
expected_value = int(expected_value)
else:
expected_value = float(expected_value)
if actual_value > expected_value:
return False
return True
def _assert_null(self, actual_value: Optional[int | float]) -> bool:
"""
Assert null
:param actual_value: actual value
:return:
"""
if actual_value is None:
return True
return False
def _assert_not_null(self, actual_value: Optional[int | float]) -> bool:
"""
Assert not null
:param actual_value: actual value
:return:
"""
if actual_value is not None:
return True
return False
class ConditionAssertionError(Exception):
def __init__(self, message: str, conditions: list[dict], sub_condition_compare_results: list[dict]) -> None:
self.message = message
self.conditions = conditions
self.sub_condition_compare_results = sub_condition_compare_results
super().__init__(self.message)
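A few usage checks for evaluate_condition as defined above, with arbitrary values:

processor = ConditionProcessor()
assert processor.evaluate_condition(5, '>', '3') is True   # numeric compare, expected value cast from str
assert processor.evaluate_condition('hello', 'start with', 'he') is True
assert processor.evaluate_condition(None, 'null') is True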
File diff suppressed because it is too large
View File
@ -0,0 +1,314 @@
import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from configs import dify_config
from core.app.app_config.entities import FileExtraConfig
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunEvent
from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.node_mapping import node_classes
from models.workflow import (
Workflow,
WorkflowType,
)
logger = logging.getLogger(__name__)
class WorkflowEntry:
def __init__(
self,
tenant_id: str,
app_id: str,
workflow_id: str,
workflow_type: WorkflowType,
graph_config: Mapping[str, Any],
graph: Graph,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
call_depth: int,
variable_pool: VariablePool,
thread_pool_id: Optional[str] = None
) -> None:
"""
Init workflow entry
:param tenant_id: tenant id
:param app_id: app id
:param workflow_id: workflow id
:param workflow_type: workflow type
:param graph_config: workflow graph config
:param graph: workflow graph
:param user_id: user id
:param user_from: user from
:param invoke_from: invoke from
:param call_depth: call depth
:param variable_pool: variable pool
:param thread_pool_id: thread pool id
"""
# check call depth
workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
if call_depth > workflow_call_max_depth:
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
# init workflow run state
self.graph_engine = GraphEngine(
tenant_id=tenant_id,
app_id=app_id,
workflow_type=workflow_type,
workflow_id=workflow_id,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
call_depth=call_depth,
graph=graph,
graph_config=graph_config,
variable_pool=variable_pool,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=thread_pool_id
)
def run(
self,
*,
callbacks: Sequence[WorkflowCallback],
) -> Generator[GraphEngineEvent, None, None]:
"""
:param callbacks: workflow callbacks
"""
graph_engine = self.graph_engine
try:
# run workflow
generator = graph_engine.run()
for event in generator:
if callbacks:
for callback in callbacks:
callback.on_event(
event=event
)
yield event
except GenerateTaskStoppedException:
pass
except Exception as e:
logger.exception("Unknown Error when workflow entry running")
if callbacks:
for callback in callbacks:
callback.on_event(
event=GraphRunFailedEvent(
error=str(e)
)
)
return
@classmethod
def single_step_run(
cls,
workflow: Workflow,
node_id: str,
user_id: str,
user_inputs: dict
) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]:
"""
Single step run workflow node
:param workflow: Workflow instance
:param node_id: node id
:param user_id: user id
:param user_inputs: user inputs
:return:
"""
# fetch node info from workflow graph
graph = workflow.graph_dict
if not graph:
raise ValueError('workflow graph not found')
nodes = graph.get('nodes')
if not nodes:
raise ValueError('nodes not found in workflow graph')
# fetch node config from node id
node_config = None
for node in nodes:
if node.get('id') == node_id:
node_config = node
break
if not node_config:
raise ValueError('node id not found in workflow graph')
# Get node class
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
if not node_cls:
raise ValueError(f'Node class not found for node type {node_type}')
node_cls = cast(type[BaseNode], node_cls)
# init variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
environment_variables=workflow.environment_variables,
)
# init graph
graph = Graph.init(
graph_config=workflow.graph_dict
)
# init workflow run state
node_instance: BaseNode = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_type=WorkflowType.value_of(workflow.type),
workflow_id=workflow.id,
graph_config=workflow.graph_dict,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0
),
graph=graph,
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter()
)
)
try:
# variable selector to variable mapping
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict,
config=node_config
)
except NotImplementedError:
variable_mapping = {}
cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
node_type=node_type,
node_data=node_instance.node_data
)
# run node
generator = node_instance.run()
return node_instance, generator
except Exception as e:
raise WorkflowNodeRunFailedError(
node_instance=node_instance,
error=str(e)
)
@classmethod
def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]:
"""
Handle special values
:param value: value
:return:
"""
if not value:
return None
new_value = dict(value) if value else {}
if isinstance(new_value, dict):
for key, val in new_value.items():
if isinstance(val, FileVar):
new_value[key] = val.to_dict()
elif isinstance(val, list):
new_val = []
for v in val:
if isinstance(v, FileVar):
new_val.append(v.to_dict())
else:
new_val.append(v)
new_value[key] = new_val
return new_value
@classmethod
def mapping_user_inputs_to_variable_pool(
cls,
variable_mapping: Mapping[str, Sequence[str]],
user_inputs: dict,
variable_pool: VariablePool,
tenant_id: str,
node_type: NodeType,
node_data: BaseNodeData
) -> None:
for node_variable, variable_selector in variable_mapping.items():
# fetch node id and variable key from node_variable
node_variable_list = node_variable.split('.')
if len(node_variable_list) < 1:
raise ValueError(f'Invalid node variable {node_variable}')
node_variable_key = '.'.join(node_variable_list[1:])
if (
node_variable_key not in user_inputs
and node_variable not in user_inputs
) and not variable_pool.get(variable_selector):
raise ValueError(f'Variable key {node_variable} not found in user inputs.')
# fetch variable node id from variable selector
variable_node_id = variable_selector[0]
variable_key_list = variable_selector[1:]
variable_key_list = cast(list[str], variable_key_list)
# get input value
input_value = user_inputs.get(node_variable)
if not input_value:
input_value = user_inputs.get(node_variable_key)
# FIXME: temp fix for image type
if node_type == NodeType.LLM:
new_value = []
if isinstance(input_value, list):
node_data = cast(LLMNodeData, node_data)
detail = node_data.vision.configs.detail if node_data.vision.configs else None
for item in input_value:
if isinstance(item, dict) and 'type' in item and item['type'] == 'image':
transfer_method = FileTransferMethod.value_of(item.get('transfer_method'))
file = FileVar(
tenant_id=tenant_id,
type=FileType.IMAGE,
transfer_method=transfer_method,
url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
related_id=item.get(
'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None),
)
new_value.append(file)
if new_value:
input_value = new_value
# append variable and value to variable pool
variable_pool.add([variable_node_id] + variable_key_list, input_value)
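A hedged usage sketch of WorkflowEntry.single_step_run; `workflow` is assumed to be a persisted Workflow containing a node with id 'code', and the input key follows the node-prefixed convention introduced above:

node_instance, generator = WorkflowEntry.single_step_run(
    workflow=workflow,
    node_id='code',
    user_id='user-1',
    user_inputs={'code.args1': 1},
)
for event in generator:
    if isinstance(event, RunCompletedEvent):
        print(event.run_result.outputs)
        break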
View File
@ -0,0 +1,35 @@
"""add node_execution_id into node_executions
Revision ID: 675b5321501b
Revises: 030f4915f36a
Create Date: 2024-08-12 10:54:02.259331
"""
import sqlalchemy as sa
from alembic import op
import models as models
# revision identifiers, used by Alembic.
revision = '675b5321501b'
down_revision = '030f4915f36a'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.add_column(sa.Column('node_execution_id', sa.String(length=255), nullable=True))
batch_op.create_index('workflow_node_execution_id_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'node_execution_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.drop_index('workflow_node_execution_id_idx')
batch_op.drop_column('node_execution_id')
# ### end Alembic commands ###
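A sketch of applying this revision through Alembic's Python API (the config path is assumed; the project may drive migrations through Flask-Migrate instead):

from alembic import command
from alembic.config import Config

command.upgrade(Config('alembic.ini'), '675b5321501b')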
View File
@ -581,6 +581,8 @@ class WorkflowNodeExecution(db.Model):
'triggered_from', 'workflow_run_id'),
db.Index('workflow_node_execution_node_run_idx', 'tenant_id', 'app_id', 'workflow_id',
'triggered_from', 'node_id'),
db.Index('workflow_node_execution_id_idx', 'tenant_id', 'app_id', 'workflow_id',
'triggered_from', 'node_execution_id'),
)
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
@ -591,6 +593,7 @@ class WorkflowNodeExecution(db.Model):
workflow_run_id = db.Column(StringUUID)
index = db.Column(db.Integer, nullable=False)
predecessor_node_id = db.Column(db.String(255))
node_execution_id = db.Column(db.String(255), nullable=True)
node_id = db.Column(db.String(255), nullable=False)
node_type = db.Column(db.String(255), nullable=False)
title = db.Column(db.String(255), nullable=False)
View File
@ -13,8 +13,9 @@ from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
current_dsl_version = "0.1.1"
current_dsl_version = "0.1.2"
dsl_to_dify_version_mapping: dict[str, str] = {
"0.1.2": "0.8.0",
"0.1.1": "0.6.0", # dsl version -> from dify version
}
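The mapping above lets import code report which Dify release a DSL document targets; a trivial lookup using the names defined in this hunk:

assert dsl_to_dify_version_mapping[current_dsl_version] == '0.8.0'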
View File
@ -12,6 +12,7 @@ from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting import RateLimit
from models.model import Account, App, AppMode, EndUser
from models.workflow import Workflow
from services.errors.llm import InvokeRateLimitError
from services.workflow_service import WorkflowService
@ -103,9 +104,7 @@ class AppGenerateService:
return max_active_requests
@classmethod
def generate_single_iteration(
cls, app_model: App, user: Union[Account, EndUser], node_id: str, args: Any, streaming: bool = True
):
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator().single_iteration_generate(
@ -142,7 +141,7 @@ class AppGenerateService:
)
@classmethod
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Any:
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom) -> Workflow:
"""
Get workflow
:param app_model: app model
View File
@ -8,9 +8,11 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.segments import Variable
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.node_mapping import node_classes
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
from models.account import Account
@ -172,8 +174,13 @@ class WorkflowService:
Get default block configs
"""
# return default block config
workflow_engine_manager = WorkflowEngineManager()
return workflow_engine_manager.get_default_configs()
default_block_configs = []
for node_type, node_class in node_classes.items():
default_config = node_class.get_default_config()
if default_config:
default_block_configs.append(default_config)
return default_block_configs
def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
"""
@ -182,11 +189,18 @@ class WorkflowService:
:param filters: filter by node config parameters.
:return:
"""
node_type = NodeType.value_of(node_type)
node_type_enum: NodeType = NodeType.value_of(node_type)
# return default block config
workflow_engine_manager = WorkflowEngineManager()
return workflow_engine_manager.get_default_config(node_type, filters)
node_class = node_classes.get(node_type_enum)
if not node_class:
return None
default_config = node_class.get_default_config(filters=filters)
if not default_config:
return None
return default_config
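Usage sketch of the rewritten lookup; 'code' is assumed to be a registered node type that exposes a default config:

service = WorkflowService()
default_config = service.get_default_block_config('code')
if default_config is None:
    print('no default config for this node type')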
def run_draft_workflow_node(
self, app_model: App, node_id: str, user_inputs: dict, account: Account
@ -200,82 +214,68 @@ class WorkflowService:
raise ValueError("Workflow not initialized")
# run draft workflow node
workflow_engine_manager = WorkflowEngineManager()
start_at = time.perf_counter()
try:
node_instance, node_run_result = workflow_engine_manager.single_step_run_workflow_node(
node_instance, generator = WorkflowEntry.single_step_run(
workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
user_id=account.id,
)
node_run_result: NodeRunResult | None = None
for event in generator:
if isinstance(event, RunCompletedEvent):
node_run_result = event.run_result
# sign output files
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
break
if not node_run_result:
raise ValueError("Node run failed with no run result")
run_succeeded = node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
error = node_run_result.error if not run_succeeded else None
except WorkflowNodeRunFailedError as e:
workflow_node_execution = WorkflowNodeExecution(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
workflow_id=draft_workflow.id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
index=1,
node_id=e.node_id,
node_type=e.node_type.value,
title=e.node_title,
status=WorkflowNodeExecutionStatus.FAILED.value,
error=e.error,
elapsed_time=time.perf_counter() - start_at,
created_by_role=CreatedByRole.ACCOUNT.value,
created_by=account.id,
created_at=datetime.now(timezone.utc).replace(tzinfo=None),
finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
)
db.session.add(workflow_node_execution)
db.session.commit()
node_instance = e.node_instance
run_succeeded = False
node_run_result = None
error = e.error
return workflow_node_execution
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = app_model.tenant_id
workflow_node_execution.app_id = app_model.id
workflow_node_execution.workflow_id = draft_workflow.id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
workflow_node_execution.index = 1
workflow_node_execution.node_id = node_id
workflow_node_execution.node_type = node_instance.node_type.value
workflow_node_execution.title = node_instance.node_data.title
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value
workflow_node_execution.created_by = account.id
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if run_succeeded and node_run_result:
# create workflow node execution
workflow_node_execution = WorkflowNodeExecution(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
workflow_id=draft_workflow.id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
index=1,
node_id=node_id,
node_type=node_instance.node_type.value,
title=node_instance.node_data.title,
inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None,
process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None,
outputs=json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None,
execution_metadata=(
json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
),
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
elapsed_time=time.perf_counter() - start_at,
created_by_role=CreatedByRole.ACCOUNT.value,
created_by=account.id,
created_at=datetime.now(timezone.utc).replace(tzinfo=None),
finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
workflow_node_execution.inputs = json.dumps(node_run_result.inputs) if node_run_result.inputs else None
workflow_node_execution.process_data = (
json.dumps(node_run_result.process_data) if node_run_result.process_data else None
)
workflow_node_execution.outputs = (
json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None
)
workflow_node_execution.execution_metadata = (
json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
)
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
else:
# create workflow node execution
workflow_node_execution = WorkflowNodeExecution(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
workflow_id=draft_workflow.id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
index=1,
node_id=node_id,
node_type=node_instance.node_type.value,
title=node_instance.node_data.title,
status=node_run_result.status.value,
error=node_run_result.error,
elapsed_time=time.perf_counter() - start_at,
created_by_role=CreatedByRole.ACCOUNT.value,
created_by=account.id,
created_at=datetime.now(timezone.utc).replace(tzinfo=None),
finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
db.session.add(workflow_node_execution)
db.session.commit()
@ -321,25 +321,3 @@ class WorkflowService:
)
else:
raise ValueError(f"Invalid app mode: {app_model.mode}")
@classmethod
def get_elapsed_time(cls, workflow_run_id: str) -> float:
"""
Get elapsed time
"""
elapsed_time = 0.0
# fetch workflow node execution by workflow_run_id
workflow_nodes = (
db.session.query(WorkflowNodeExecution)
.filter(WorkflowNodeExecution.workflow_run_id == workflow_run_id)
.order_by(WorkflowNodeExecution.created_at.asc())
.all()
)
if not workflow_nodes:
return elapsed_time
for node in workflow_nodes:
elapsed_time += node.elapsed_time
return elapsed_time
View File
@ -1,17 +1,72 @@
import time
import uuid
from os import getenv
from typing import cast
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.code.code_node import CodeNode
from models.workflow import WorkflowNodeExecutionStatus
from core.workflow.nodes.code.entities import CodeNodeData
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000"))
def init_code_node(code_config: dict):
graph_config = {
"edges": [
{
"id": "start-source-code-target",
"source": "start",
"target": "code",
},
],
"nodes": [{"data": {"type": "start"}, "id": "start"}, code_config],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
variable_pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
variable_pool.add(["code", "123", "args1"], 1)
variable_pool.add(["code", "123", "args2"], 2)
node = CodeNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=code_config,
)
return node
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_code(setup_code_executor_mock):
code = """
@ -22,44 +77,36 @@ def test_execute_code(setup_code_executor_mock):
"""
# trim first 4 spaces at the beginning of each line
code = "\n".join([line[4:] for line in code.split("\n")])
node = CodeNode(
tenant_id="1",
app_id="1",
workflow_id="1",
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
config={
"id": "1",
"data": {
"outputs": {
"result": {
"type": "number",
},
},
"title": "123",
"variables": [
{
"variable": "args1",
"value_selector": ["1", "123", "args1"],
},
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
],
"answer": "123",
"code_language": "python3",
"code": code,
},
},
)
# construct variable pool
pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
pool.add(["1", "123", "args1"], 1)
pool.add(["1", "123", "args2"], 2)
code_config = {
"id": "code",
"data": {
"outputs": {
"result": {
"type": "number",
},
},
"title": "123",
"variables": [
{
"variable": "args1",
"value_selector": ["1", "123", "args1"],
},
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
],
"answer": "123",
"code_language": "python3",
"code": code,
},
}
node = init_code_node(code_config)
# execute node
result = node.run(pool)
result = node._run()
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] == 3
assert result.error is None
@ -74,44 +121,34 @@ def test_execute_code_output_validator(setup_code_executor_mock):
"""
# trim first 4 spaces at the beginning of each line
code = "\n".join([line[4:] for line in code.split("\n")])
node = CodeNode(
tenant_id="1",
app_id="1",
workflow_id="1",
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
config={
"id": "1",
"data": {
"outputs": {
"result": {
"type": "string",
},
},
"title": "123",
"variables": [
{
"variable": "args1",
"value_selector": ["1", "123", "args1"],
},
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
],
"answer": "123",
"code_language": "python3",
"code": code,
},
},
)
# construct variable pool
pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
pool.add(["1", "123", "args1"], 1)
pool.add(["1", "123", "args2"], 2)
code_config = {
"id": "code",
"data": {
"outputs": {
"result": {
"type": "string",
},
},
"title": "123",
"variables": [
{
"variable": "args1",
"value_selector": ["1", "123", "args1"],
},
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
],
"answer": "123",
"code_language": "python3",
"code": code,
},
}
node = init_code_node(code_config)
# execute node
result = node.run(pool)
result = node._run()
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert result.error == "Output variable `result` must be a string"
@ -127,65 +164,60 @@ def test_execute_code_output_validator_depth():
"""
# trim first 4 spaces at the beginning of each line
code = "\n".join([line[4:] for line in code.split("\n")])
node = CodeNode(
tenant_id="1",
app_id="1",
workflow_id="1",
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
config={
"id": "1",
"data": {
"outputs": {
"string_validator": {
"type": "string",
},
"number_validator": {
"type": "number",
},
"number_array_validator": {
"type": "array[number]",
},
"string_array_validator": {
"type": "array[string]",
},
"object_validator": {
"type": "object",
"children": {
"result": {
"type": "number",
},
"depth": {
"type": "object",
"children": {
"depth": {
"type": "object",
"children": {
"depth": {
"type": "number",
}
},
}
},
code_config = {
"id": "code",
"data": {
"outputs": {
"string_validator": {
"type": "string",
},
"number_validator": {
"type": "number",
},
"number_array_validator": {
"type": "array[number]",
},
"string_array_validator": {
"type": "array[string]",
},
"object_validator": {
"type": "object",
"children": {
"result": {
"type": "number",
},
"depth": {
"type": "object",
"children": {
"depth": {
"type": "object",
"children": {
"depth": {
"type": "number",
}
},
}
},
},
},
},
"title": "123",
"variables": [
{
"variable": "args1",
"value_selector": ["1", "123", "args1"],
},
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
],
"answer": "123",
"code_language": "python3",
"code": code,
},
"title": "123",
"variables": [
{
"variable": "args1",
"value_selector": ["1", "123", "args1"],
},
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
],
"answer": "123",
"code_language": "python3",
"code": code,
},
)
}
node = init_code_node(code_config)
# construct result
result = {
@ -196,6 +228,8 @@ def test_execute_code_output_validator_depth():
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
}
node.node_data = cast(CodeNodeData, node.node_data)
# validate
node._transform_result(result, node.node_data.outputs)
@ -250,35 +284,30 @@ def test_execute_code_output_object_list():
"""
# trim first 4 spaces at the beginning of each line
code = "\n".join([line[4:] for line in code.split("\n")])
node = CodeNode(
tenant_id="1",
app_id="1",
workflow_id="1",
user_id="1",
invoke_from=InvokeFrom.WEB_APP,
user_from=UserFrom.ACCOUNT,
config={
"id": "1",
"data": {
"outputs": {
"object_list": {
"type": "array[object]",
},
code_config = {
"id": "code",
"data": {
"outputs": {
"object_list": {
"type": "array[object]",
},
"title": "123",
"variables": [
{
"variable": "args1",
"value_selector": ["1", "123", "args1"],
},
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
],
"answer": "123",
"code_language": "python3",
"code": code,
},
"title": "123",
"variables": [
{
"variable": "args1",
"value_selector": ["1", "123", "args1"],
},
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
],
"answer": "123",
"code_language": "python3",
"code": code,
},
)
}
node = init_code_node(code_config)
# construct result
result = {
@ -295,6 +324,8 @@ def test_execute_code_output_object_list():
]
}
node.node_data = cast(CodeNodeData, node.node_data)
# validate
node._transform_result(result, node.node_data.outputs)
View File
@ -1,31 +1,69 @@
import time
import uuid
from urllib.parse import urlencode
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
BASIC_NODE_DATA = {
"tenant_id": "1",
"app_id": "1",
"workflow_id": "1",
"user_id": "1",
"user_from": UserFrom.ACCOUNT,
"invoke_from": InvokeFrom.WEB_APP,
}
# construct variable pool
pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
pool.add(["a", "b123", "args1"], 1)
pool.add(["a", "b123", "args2"], 2)
def init_http_node(config: dict):
graph_config = {
"edges": [
{
"id": "start-source-next-target",
"source": "start",
"target": "1",
},
],
"nodes": [{"data": {"type": "start"}, "id": "start"}, config],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
variable_pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
variable_pool.add(["a", "b123", "args1"], 1)
variable_pool.add(["a", "b123", "args2"], 2)
return HttpRequestNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_get(setup_http_mock):
node = HttpRequestNode(
node = init_http_node(
config={
"id": "1",
"data": {
@ -45,12 +83,11 @@ def test_get(setup_http_mock):
"params": "A:b",
"body": None,
},
},
**BASIC_NODE_DATA,
}
)
result = node.run(pool)
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")
assert "?A=b" in data
@ -59,7 +96,7 @@ def test_get(setup_http_mock):
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_no_auth(setup_http_mock):
node = HttpRequestNode(
node = init_http_node(
config={
"id": "1",
"data": {
@ -75,12 +112,11 @@ def test_no_auth(setup_http_mock):
"params": "A:b",
"body": None,
},
},
**BASIC_NODE_DATA,
}
)
result = node.run(pool)
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")
assert "?A=b" in data
@ -89,7 +125,7 @@ def test_no_auth(setup_http_mock):
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_authorization_header(setup_http_mock):
node = HttpRequestNode(
node = init_http_node(
config={
"id": "1",
"data": {
@ -109,12 +145,11 @@ def test_custom_authorization_header(setup_http_mock):
"params": "A:b",
"body": None,
},
},
**BASIC_NODE_DATA,
}
)
result = node.run(pool)
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")
assert "?A=b" in data
@ -123,7 +158,7 @@ def test_custom_authorization_header(setup_http_mock):
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_template(setup_http_mock):
node = HttpRequestNode(
node = init_http_node(
config={
"id": "1",
"data": {
@ -143,11 +178,11 @@ def test_template(setup_http_mock):
"params": "A:b\nTemplate:{{#a.b123.args2#}}",
"body": None,
},
},
**BASIC_NODE_DATA,
}
)
result = node.run(pool)
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")
assert "?A=b" in data
@ -158,7 +193,7 @@ def test_template(setup_http_mock):
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_json(setup_http_mock):
node = HttpRequestNode(
node = init_http_node(
config={
"id": "1",
"data": {
@ -178,11 +213,11 @@ def test_json(setup_http_mock):
"params": "A:b",
"body": {"type": "json", "data": '{"a": "{{#a.b123.args1#}}"}'},
},
},
**BASIC_NODE_DATA,
}
)
result = node.run(pool)
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")
assert '{"a": "1"}' in data
@ -190,7 +225,7 @@ def test_json(setup_http_mock):
def test_x_www_form_urlencoded(setup_http_mock):
node = HttpRequestNode(
node = init_http_node(
config={
"id": "1",
"data": {
@ -210,11 +245,11 @@ def test_x_www_form_urlencoded(setup_http_mock):
"params": "A:b",
"body": {"type": "x-www-form-urlencoded", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"},
},
},
**BASIC_NODE_DATA,
}
)
result = node.run(pool)
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")
assert "a=1&b=2" in data
@ -222,7 +257,7 @@ def test_x_www_form_urlencoded(setup_http_mock):
def test_form_data(setup_http_mock):
node = HttpRequestNode(
node = init_http_node(
config={
"id": "1",
"data": {
@ -242,11 +277,11 @@ def test_form_data(setup_http_mock):
"params": "A:b",
"body": {"type": "form-data", "data": "a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}"},
},
},
**BASIC_NODE_DATA,
}
)
result = node.run(pool)
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")
assert 'form-data; name="a"' in data
@ -257,7 +292,7 @@ def test_form_data(setup_http_mock):
def test_none_data(setup_http_mock):
node = HttpRequestNode(
node = init_http_node(
config={
"id": "1",
"data": {
@ -277,11 +312,11 @@ def test_none_data(setup_http_mock):
"params": "A:b",
"body": {"type": "none", "data": "123123123"},
},
},
**BASIC_NODE_DATA,
}
)
result = node.run(pool)
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")
assert "X-Header: 123" in data
@ -289,7 +324,7 @@ def test_none_data(setup_http_mock):
def test_mock_404(setup_http_mock):
node = HttpRequestNode(
node = init_http_node(
config={
"id": "1",
"data": {
@ -305,19 +340,19 @@ def test_mock_404(setup_http_mock):
"params": "",
"headers": "X-Header:123",
},
},
**BASIC_NODE_DATA,
}
)
result = node.run(pool)
result = node._run()
assert result.outputs is not None
resp = result.outputs
assert 404 == resp.get("status_code")
assert "Not Found" in resp.get("body")
assert "Not Found" in resp.get("body", "")
def test_multi_colons_parse(setup_http_mock):
node = HttpRequestNode(
node = init_http_node(
config={
"id": "1",
"data": {
@ -333,13 +368,14 @@ def test_multi_colons_parse(setup_http_mock):
"headers": "Referer:http://example3.com\nRedirect:http://example4.com",
"body": {"type": "form-data", "data": "Referer:http://example5.com\nRedirect:http://example6.com"},
},
},
**BASIC_NODE_DATA,
}
)
result = node.run(pool)
result = node._run()
assert result.process_data is not None
assert result.outputs is not None
resp = result.outputs
assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request")
assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request")
assert "http://example3.com" == resp.get("headers").get("referer")
assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "")
assert 'form-data; name="Redirect"\n\nhttp://example6.com' in result.process_data.get("request", "")
assert "http://example3.com" == resp.get("headers", {}).get("referer")

View File

@ -1,5 +1,8 @@
import json
import os
import time
import uuid
from collections.abc import Generator
from unittest.mock import MagicMock
import pytest
@ -10,28 +13,77 @@ from core.entities.provider_entities import CustomConfiguration, CustomProviderC
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers import ModelProviderFactory
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.llm.llm_node import LLMNode
from extensions.ext_database import db
from models.provider import ProviderType
from models.workflow import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_execute_llm(setup_openai_mock):
node = LLMNode(
def init_llm_node(config: dict) -> LLMNode:
graph_config = {
"edges": [
{
"id": "start-source-next-target",
"source": "start",
"target": "llm",
},
],
"nodes": [{"data": {"type": "start"}, "id": "start"}, config],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
invoke_from=InvokeFrom.WEB_APP,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "what's the weather today?",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
variable_pool.add(["abc", "output"], "sunny")
node = LLMNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
return node
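# init_llm_node mirrors init_http_node in the HTTP tests: the start -> llm graph,
# GraphInitParams, and a variable pool pre-seeded with chat system variables
# (sys.query, sys.conversation_id, ...) are assembled once, so each test only
# supplies its node config.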
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_execute_llm(setup_openai_mock):
node = init_llm_node(
config={
"id": "llm",
"data": {
@ -49,19 +101,6 @@ def test_execute_llm(setup_openai_mock):
},
)
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "what's the weather today?",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
user_inputs={},
environment_variables=[],
)
pool.add(["abc", "output"], "sunny")
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
provider_instance = ModelProviderFactory().get_provider_instance("openai")
@ -80,13 +119,15 @@ def test_execute_llm(setup_openai_mock):
model_type_instance=model_type_instance,
)
model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo")
model_schema = model_type_instance.get_model_schema("gpt-3.5-turbo")
assert model_schema is not None
model_config = ModelConfigWithCredentialsEntity(
model="gpt-3.5-turbo",
provider="openai",
mode="chat",
credentials=credentials,
parameters={},
model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"),
model_schema=model_schema,
provider_model_bundle=provider_model_bundle,
)
@ -96,11 +137,16 @@ def test_execute_llm(setup_openai_mock):
node._fetch_model_config = MagicMock(return_value=(model_instance, model_config))
# execute node
result = node.run(pool)
result = node._run()
assert isinstance(result, Generator)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["text"] is not None
assert result.outputs["usage"]["total_tokens"] > 0
for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert item.run_result.outputs is not None
assert item.run_result.outputs.get("text") is not None
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
@ -109,13 +155,7 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
"""
Test execute LLM node with jinja2
"""
node = LLMNode(
tenant_id="1",
app_id="1",
workflow_id="1",
user_id="1",
invoke_from=InvokeFrom.WEB_APP,
user_from=UserFrom.ACCOUNT,
node = init_llm_node(
config={
"id": "llm",
"data": {
@ -149,19 +189,6 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
},
)
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "what's the weather today?",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
user_inputs={},
environment_variables=[],
)
pool.add(["abc", "output"], "sunny")
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
provider_instance = ModelProviderFactory().get_provider_instance("openai")
@ -181,14 +208,15 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
)
model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model="gpt-3.5-turbo")
model_schema = model_type_instance.get_model_schema("gpt-3.5-turbo")
assert model_schema is not None
model_config = ModelConfigWithCredentialsEntity(
model="gpt-3.5-turbo",
provider="openai",
mode="chat",
credentials=credentials,
parameters={},
model_schema=model_type_instance.get_model_schema("gpt-3.5-turbo"),
model_schema=model_schema,
provider_model_bundle=provider_model_bundle,
)
@ -198,8 +226,11 @@ def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
node._fetch_model_config = MagicMock(return_value=(model_instance, model_config))
# execute node
result = node.run(pool)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert "sunny" in json.dumps(result.process_data)
assert "what's the weather today?" in json.dumps(result.process_data)
for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert "sunny" in json.dumps(item.run_result.process_data)
assert "what's the weather today?" in json.dumps(item.run_result.process_data)

View File

@ -1,5 +1,7 @@
import json
import os
import time
import uuid
from typing import Optional
from unittest.mock import MagicMock
@ -8,19 +10,21 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from extensions.ext_database import db
from models.provider import ProviderType
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from models.workflow import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@ -47,13 +51,15 @@ def get_mocked_fetch_model_config(
model_type_instance=model_type_instance,
)
model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model=model)
model_schema = model_type_instance.get_model_schema(model)
assert model_schema is not None
model_config = ModelConfigWithCredentialsEntity(
model=model,
provider=provider,
mode=mode,
credentials=credentials,
parameters={},
model_schema=model_type_instance.get_model_schema(model),
model_schema=model_schema,
provider_model_bundle=provider_model_bundle,
)
@ -74,18 +80,62 @@ def get_mocked_fetch_memory(memory_text: str):
return MagicMock(return_value=MemoryMock())
def init_parameter_extractor_node(config: dict):
graph_config = {
"edges": [
{
"id": "start-source-next-target",
"source": "start",
"target": "llm",
},
],
"nodes": [{"data": {"type": "start"}, "id": "start"}, config],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "what's the weather in SF",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
variable_pool.add(["a", "b123", "args1"], 1)
variable_pool.add(["a", "b123", "args2"], 2)
return ParameterExtractorNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
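# Same helper pattern as the HTTP and LLM tests: the per-test VariablePool
# construction is gone, and the query "what's the weather in SF" is seeded here once.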
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_function_calling_parameter_extractor(setup_openai_mock):
"""
Test function calling for parameter extractor.
"""
node = ParameterExtractorNode(
tenant_id="1",
app_id="1",
workflow_id="1",
user_id="1",
invoke_from=InvokeFrom.WEB_APP,
user_from=UserFrom.ACCOUNT,
node = init_parameter_extractor_node(
config={
"id": "llm",
"data": {
@ -98,7 +148,7 @@ def test_function_calling_parameter_extractor(setup_openai_mock):
"reasoning_mode": "function_call",
"memory": None,
},
},
}
)
node._fetch_model_config = get_mocked_fetch_model_config(
@ -121,9 +171,10 @@ def test_function_calling_parameter_extractor(setup_openai_mock):
environment_variables=[],
)
result = node.run(pool)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs.get("location") == "kawaii"
assert result.outputs.get("__reason") == None
@ -133,13 +184,7 @@ def test_instructions(setup_openai_mock):
"""
Test the parameter extractor with instructions.
"""
node = ParameterExtractorNode(
tenant_id="1",
app_id="1",
workflow_id="1",
user_id="1",
invoke_from=InvokeFrom.WEB_APP,
user_from=UserFrom.ACCOUNT,
node = init_parameter_extractor_node(
config={
"id": "llm",
"data": {
@ -163,29 +208,19 @@ def test_instructions(setup_openai_mock):
)
db.session.close = MagicMock()
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "what's the weather in SF",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
user_inputs={},
environment_variables=[],
)
result = node.run(pool)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs.get("location") == "kawaii"
assert result.outputs.get("__reason") == None
process_data = result.process_data
assert process_data is not None
process_data.get("prompts")
for prompt in process_data.get("prompts"):
for prompt in process_data.get("prompts", []):
if prompt.get("role") == "system":
assert "what's the weather in SF" in prompt.get("text")
@ -195,13 +230,7 @@ def test_chat_parameter_extractor(setup_anthropic_mock):
"""
Test chat parameter extractor.
"""
node = ParameterExtractorNode(
tenant_id="1",
app_id="1",
workflow_id="1",
user_id="1",
invoke_from=InvokeFrom.WEB_APP,
user_from=UserFrom.ACCOUNT,
node = init_parameter_extractor_node(
config={
"id": "llm",
"data": {
@ -225,27 +254,17 @@ def test_chat_parameter_extractor(setup_anthropic_mock):
)
db.session.close = MagicMock()
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "what's the weather in SF",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
user_inputs={},
environment_variables=[],
)
result = node.run(pool)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs.get("location") == ""
assert (
result.outputs.get("__reason")
== "Failed to extract result from function call or text response, using empty result."
)
prompts = result.process_data.get("prompts")
assert result.process_data is not None
prompts = result.process_data.get("prompts", [])
for prompt in prompts:
if prompt.get("role") == "user":
@ -258,13 +277,7 @@ def test_completion_parameter_extractor(setup_openai_mock):
"""
Test completion parameter extractor.
"""
node = ParameterExtractorNode(
tenant_id="1",
app_id="1",
workflow_id="1",
user_id="1",
invoke_from=InvokeFrom.WEB_APP,
user_from=UserFrom.ACCOUNT,
node = init_parameter_extractor_node(
config={
"id": "llm",
"data": {
@ -293,28 +306,18 @@ def test_completion_parameter_extractor(setup_openai_mock):
)
db.session.close = MagicMock()
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "what's the weather in SF",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
user_inputs={},
environment_variables=[],
)
result = node.run(pool)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs.get("location") == ""
assert (
result.outputs.get("__reason")
== "Failed to extract result from function call or text response, using empty result."
)
assert len(result.process_data.get("prompts")) == 1
assert "SF" in result.process_data.get("prompts")[0].get("text")
assert result.process_data is not None
assert len(result.process_data.get("prompts", [])) == 1
assert "SF" in result.process_data.get("prompts", [])[0].get("text")
def test_extract_json_response():
@ -322,13 +325,7 @@ def test_extract_json_response():
Test extract json response.
"""
node = ParameterExtractorNode(
tenant_id="1",
app_id="1",
workflow_id="1",
user_id="1",
invoke_from=InvokeFrom.WEB_APP,
user_from=UserFrom.ACCOUNT,
node = init_parameter_extractor_node(
config={
"id": "llm",
"data": {
@ -357,6 +354,7 @@ def test_extract_json_response():
hello world.
""")
assert result is not None
assert result["location"] == "kawaii"
@ -365,13 +363,7 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
"""
Test chat parameter extractor with memory.
"""
node = ParameterExtractorNode(
tenant_id="1",
app_id="1",
workflow_id="1",
user_id="1",
invoke_from=InvokeFrom.WEB_APP,
user_from=UserFrom.ACCOUNT,
node = init_parameter_extractor_node(
config={
"id": "llm",
"data": {
@ -396,27 +388,17 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock):
node._fetch_memory = get_mocked_fetch_memory("customized memory")
db.session.close = MagicMock()
# construct variable pool
pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "what's the weather in SF",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
user_inputs={},
environment_variables=[],
)
result = node.run(pool)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs.get("location") == ""
assert (
result.outputs.get("__reason")
== "Failed to extract result from function call or text response, using empty result."
)
prompts = result.process_data.get("prompts")
assert result.process_data is not None
prompts = result.process_data.get("prompts", [])
latest_role = None
for prompt in prompts:

View File

@ -1,46 +1,84 @@
import time
import uuid
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from models.workflow import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_code(setup_code_executor_mock):
code = """{{args2}}"""
node = TemplateTransformNode(
config = {
"id": "1",
"data": {
"title": "123",
"variables": [
{
"variable": "args1",
"value_selector": ["1", "123", "args1"],
},
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
],
"template": code,
},
}
graph_config = {
"edges": [
{
"id": "start-source-next-target",
"source": "start",
"target": "1",
},
],
"nodes": [{"data": {"type": "start"}, "id": "start"}, config],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
invoke_from=InvokeFrom.WEB_APP,
user_from=UserFrom.END_USER,
config={
"id": "1",
"data": {
"title": "123",
"variables": [
{
"variable": "args1",
"value_selector": ["1", "123", "args1"],
},
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
],
"template": code,
},
},
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
pool.add(["1", "123", "args1"], 1)
pool.add(["1", "123", "args2"], 3)
variable_pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
variable_pool.add(["1", "123", "args1"], 1)
variable_pool.add(["1", "123", "args2"], 3)
node = TemplateTransformNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
# execute node
result = node.run(pool)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["output"] == "3"

View File

@ -1,21 +1,62 @@
import time
import uuid
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import UserFrom
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.tool.tool_node import ToolNode
from models.workflow import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
def init_tool_node(config: dict):
graph_config = {
"edges": [
{
"id": "start-source-next-target",
"source": "start",
"target": "1",
},
],
"nodes": [{"data": {"type": "start"}, "id": "start"}, config],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
variable_pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
return ToolNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
def test_tool_variable_invoke():
pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
pool.add(["1", "123", "args1"], "1+1")
node = ToolNode(
tenant_id="1",
app_id="1",
workflow_id="1",
user_id="1",
invoke_from=InvokeFrom.WEB_APP,
user_from=UserFrom.ACCOUNT,
node = init_tool_node(
config={
"id": "1",
"data": {
@ -34,28 +75,22 @@ def test_tool_variable_invoke():
}
},
},
},
}
)
# execute node
result = node.run(pool)
node.graph_runtime_state.variable_pool.add(["1", "123", "args1"], "1+1")
# execute node
result = node._run()
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert "2" in result.outputs["text"]
assert result.outputs["files"] == []
def test_tool_mixed_invoke():
pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
pool.add(["1", "args1"], "1+1")
node = ToolNode(
tenant_id="1",
app_id="1",
workflow_id="1",
user_id="1",
invoke_from=InvokeFrom.WEB_APP,
user_from=UserFrom.ACCOUNT,
node = init_tool_node(
config={
"id": "1",
"data": {
@ -74,12 +109,15 @@ def test_tool_mixed_invoke():
}
},
},
},
}
)
# execute node
result = node.run(pool)
node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1")
# execute node
result = node._run()
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert "2" in result.outputs["text"]
assert result.outputs["files"] == []

View File

@ -1,7 +1,24 @@
import os
import pytest
from flask import Flask
# Getting the absolute path of the current file's directory
ABS_PATH = os.path.dirname(os.path.abspath(__file__))
# Getting the absolute path of the project's root directory
PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir))
CACHED_APP = Flask(__name__)
CACHED_APP.config.update({"TESTING": True})
@pytest.fixture()
def app() -> Flask:
return CACHED_APP
@pytest.fixture(autouse=True)
def _provide_app_context(app: Flask):
with app.app_context():
yield
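# A single cached Flask app is shared across the unit-test session, and the autouse
# fixture wraps every test in an app context. A test that needs the app just declares
# the fixture (hypothetical example, not part of this commit):
#
#     def test_app_is_in_testing_mode(app: Flask):
#         assert app.config["TESTING"] is True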

View File

@ -0,0 +1,791 @@
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.utils.condition.entities import Condition
def test_init():
graph_config = {
"edges": [
{
"id": "llm-source-answer-target",
"source": "llm",
"target": "answer",
},
{
"id": "start-source-qc-target",
"source": "start",
"target": "qc",
},
{
"id": "qc-1-llm-target",
"source": "qc",
"sourceHandle": "1",
"target": "llm",
},
{
"id": "qc-2-http-target",
"source": "qc",
"sourceHandle": "2",
"target": "http",
},
{
"id": "http-source-answer2-target",
"source": "http",
"target": "answer2",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer",
},
{
"data": {"type": "question-classifier"},
"id": "qc",
},
{
"data": {
"type": "http-request",
},
"id": "http",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer2",
},
],
}
graph = Graph.init(graph_config=graph_config)
start_node_id = "start"
assert graph.root_node_id == start_node_id
assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc"
assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")}
def test__init_iteration_graph():
graph_config = {
"edges": [
{
"id": "llm-answer",
"source": "llm",
"sourceHandle": "source",
"target": "answer",
},
{
"id": "iteration-source-llm-target",
"source": "iteration",
"sourceHandle": "source",
"target": "llm",
},
{
"id": "template-transform-in-iteration-source-llm-in-iteration-target",
"source": "template-transform-in-iteration",
"sourceHandle": "source",
"target": "llm-in-iteration",
},
{
"id": "llm-in-iteration-source-answer-in-iteration-target",
"source": "llm-in-iteration",
"sourceHandle": "source",
"target": "answer-in-iteration",
},
{
"id": "start-source-code-target",
"source": "start",
"sourceHandle": "source",
"target": "code",
},
{
"id": "code-source-iteration-target",
"source": "code",
"sourceHandle": "source",
"target": "iteration",
},
],
"nodes": [
{
"data": {
"type": "start",
},
"id": "start",
},
{
"data": {
"type": "llm",
},
"id": "llm",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer",
},
{
"data": {"type": "iteration"},
"id": "iteration",
},
{
"data": {
"type": "template-transform",
},
"id": "template-transform-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "llm",
},
"id": "llm-in-iteration",
"parentId": "iteration",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "code",
},
"id": "code",
},
],
}
graph = Graph.init(graph_config=graph_config, root_node_id="template-transform-in-iteration")
graph.add_extra_edge(
source_node_id="answer-in-iteration",
target_node_id="template-transform-in-iteration",
run_condition=RunCondition(
type="condition",
conditions=[Condition(variable_selector=["iteration", "index"], comparison_operator="", value="5")],
),
)
# iteration:
# [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration]
assert graph.root_node_id == "template-transform-in-iteration"
assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration"
assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration"
assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration"
def test_parallels_graph():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "answer",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "answer",
},
{
"id": "llm3-source-answer-target",
"source": "llm3",
"target": "answer",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm1",
},
{
"data": {
"type": "llm",
},
"id": "llm2",
},
{
"data": {
"type": "llm",
},
"id": "llm3",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer",
},
],
}
graph = Graph.init(graph_config=graph_config)
assert graph.root_node_id == "start"
for i in range(3):
start_edges = graph.edge_mapping.get("start")
assert start_edges is not None
assert start_edges[i].target_node_id == f"llm{i+1}"
llm_edges = graph.edge_mapping.get(f"llm{i+1}")
assert llm_edges is not None
assert llm_edges[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3
for node_id in ["llm1", "llm2", "llm3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph2():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "answer",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "answer",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm1",
},
{
"data": {
"type": "llm",
},
"id": "llm2",
},
{
"data": {
"type": "llm",
},
"id": "llm3",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer",
},
],
}
graph = Graph.init(graph_config=graph_config)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
if i < 2:
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3
for node_id in ["llm1", "llm2", "llm3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph3():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm1",
},
{
"data": {
"type": "llm",
},
"id": "llm2",
},
{
"data": {
"type": "llm",
},
"id": "llm3",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer",
},
],
}
graph = Graph.init(graph_config=graph_config)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3
for node_id in ["llm1", "llm2", "llm3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph4():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "code1",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "code2",
},
{
"id": "llm3-source-code3-target",
"source": "llm3",
"target": "code3",
},
{
"id": "code1-source-answer-target",
"source": "code1",
"target": "answer",
},
{
"id": "code2-source-answer-target",
"source": "code2",
"target": "answer",
},
{
"id": "code3-source-answer-target",
"source": "code3",
"target": "answer",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm1",
},
{
"data": {
"type": "code",
},
"id": "code1",
},
{
"data": {
"type": "llm",
},
"id": "llm2",
},
{
"data": {
"type": "code",
},
"id": "code2",
},
{
"data": {
"type": "llm",
},
"id": "llm3",
},
{
"data": {
"type": "code",
},
"id": "code3",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer",
},
],
}
graph = Graph.init(graph_config=graph_config)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}"
assert graph.edge_mapping.get(f"code{i + 1}") is not None
assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 6
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph5():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm4",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm5",
},
{
"id": "llm1-source-code1-target",
"source": "llm1",
"target": "code1",
},
{
"id": "llm2-source-code1-target",
"source": "llm2",
"target": "code1",
},
{
"id": "llm3-source-code2-target",
"source": "llm3",
"target": "code2",
},
{
"id": "llm4-source-code2-target",
"source": "llm4",
"target": "code2",
},
{
"id": "llm5-source-code3-target",
"source": "llm5",
"target": "code3",
},
{
"id": "code1-source-answer-target",
"source": "code1",
"target": "answer",
},
{
"id": "code2-source-answer-target",
"source": "code2",
"target": "answer",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm1",
},
{
"data": {
"type": "code",
},
"id": "code1",
},
{
"data": {
"type": "llm",
},
"id": "llm2",
},
{
"data": {
"type": "code",
},
"id": "code2",
},
{
"data": {
"type": "llm",
},
"id": "llm3",
},
{
"data": {
"type": "code",
},
"id": "code3",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer",
},
{
"data": {
"type": "llm",
},
"id": "llm4",
},
{
"data": {
"type": "llm",
},
"id": "llm5",
},
],
}
graph = Graph.init(graph_config=graph_config)
assert graph.root_node_id == "start"
for i in range(5):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert graph.edge_mapping.get("llm1") is not None
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
assert graph.edge_mapping.get("llm2") is not None
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1"
assert graph.edge_mapping.get("llm3") is not None
assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2"
assert graph.edge_mapping.get("llm4") is not None
assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2"
assert graph.edge_mapping.get("llm5") is not None
assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3"
assert graph.edge_mapping.get("code1") is not None
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
assert graph.edge_mapping.get("code2") is not None
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 8
for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]:
assert node_id in graph.node_parallel_mapping
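# Per the assertions above, the merges (llm1/llm2 -> code1, llm3/llm4 -> code2) stay
# inside the single parallel opened at start; all eight branch nodes share one
# parallel id.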
def test_parallels_graph6():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-code1-target",
"source": "llm1",
"target": "code1",
},
{
"id": "llm1-source-code2-target",
"source": "llm1",
"target": "code2",
},
{
"id": "llm2-source-code3-target",
"source": "llm2",
"target": "code3",
},
{
"id": "code1-source-answer-target",
"source": "code1",
"target": "answer",
},
{
"id": "code2-source-answer-target",
"source": "code2",
"target": "answer",
},
{
"id": "code3-source-answer-target",
"source": "code3",
"target": "answer",
},
{
"id": "llm3-source-answer-target",
"source": "llm3",
"target": "answer",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm1",
},
{
"data": {
"type": "code",
},
"id": "code1",
},
{
"data": {
"type": "llm",
},
"id": "llm2",
},
{
"data": {
"type": "code",
},
"id": "code2",
},
{
"data": {
"type": "llm",
},
"id": "llm3",
},
{
"data": {
"type": "code",
},
"id": "code3",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1"},
"id": "answer",
},
],
}
graph = Graph.init(graph_config=graph_config)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert graph.edge_mapping.get("llm1") is not None
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
assert graph.edge_mapping.get("llm1") is not None
assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2"
assert graph.edge_mapping.get("llm2") is not None
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3"
assert graph.edge_mapping.get("code1") is not None
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
assert graph.edge_mapping.get("code2") is not None
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
assert graph.edge_mapping.get("code3") is not None
assert graph.edge_mapping.get("code3")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 2
assert len(graph.node_parallel_mapping) == 6
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
assert node_id in graph.node_parallel_mapping
parent_parallel = None
child_parallel = None
for parallel in graph.parallel_mapping.values():
if parallel.parent_parallel_id is None:
parent_parallel = parallel
else:
child_parallel = parallel
assert parent_parallel is not None
assert child_parallel is not None
for node_id in ["llm1", "llm2", "llm3", "code3"]:
assert graph.node_parallel_mapping[node_id] == parent_parallel.id
for node_id in ["code1", "code2"]:
assert graph.node_parallel_mapping[node_id] == child_parallel.id
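# graph6 is the nested case: llm1/llm2/llm3 fan out from start (parent parallel),
# while code1/code2 fan out from llm1 inside it (child parallel), which is why
# parallel_mapping holds two entries linked by parent_parallel_id.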

View File

@ -0,0 +1,505 @@
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
BaseNodeEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.llm_node import LLMNode
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
@patch("extensions.ext_database.db.session.remove")
@patch("extensions.ext_database.db.session.close")
def test_run_parallel_in_workflow(mock_close, mock_remove):
graph_config = {
"edges": [
{
"id": "1",
"source": "start",
"target": "llm1",
},
{
"id": "2",
"source": "llm1",
"target": "llm2",
},
{
"id": "3",
"source": "llm1",
"target": "llm3",
},
{
"id": "4",
"source": "llm2",
"target": "end1",
},
{
"id": "5",
"source": "llm3",
"target": "end2",
},
],
"nodes": [
{
"data": {
"type": "start",
"title": "start",
"variables": [
{
"label": "query",
"max_length": 48,
"options": [],
"required": True,
"type": "text-input",
"variable": "query",
}
],
},
"id": "start",
},
{
"data": {
"type": "llm",
"title": "llm1",
"context": {"enabled": False, "variable_selector": []},
"model": {
"completion_params": {"temperature": 0.7},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai",
},
"prompt_template": [
{"role": "system", "text": "say hi"},
{"role": "user", "text": "{{#start.query#}}"},
],
"vision": {"configs": {"detail": "high"}, "enabled": False},
},
"id": "llm1",
},
{
"data": {
"type": "llm",
"title": "llm2",
"context": {"enabled": False, "variable_selector": []},
"model": {
"completion_params": {"temperature": 0.7},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai",
},
"prompt_template": [
{"role": "system", "text": "say bye"},
{"role": "user", "text": "{{#start.query#}}"},
],
"vision": {"configs": {"detail": "high"}, "enabled": False},
},
"id": "llm2",
},
{
"data": {
"type": "llm",
"title": "llm3",
"context": {"enabled": False, "variable_selector": []},
"model": {
"completion_params": {"temperature": 0.7},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai",
},
"prompt_template": [
{"role": "system", "text": "say good morning"},
{"role": "user", "text": "{{#start.query#}}"},
],
"vision": {"configs": {"detail": "high"}, "enabled": False},
},
"id": "llm3",
},
{
"data": {
"type": "end",
"title": "end1",
"outputs": [
{"value_selector": ["llm2", "text"], "variable": "result2"},
{"value_selector": ["start", "query"], "variable": "query"},
],
},
"id": "end1",
},
{
"data": {
"type": "end",
"title": "end2",
"outputs": [
{"value_selector": ["llm1", "text"], "variable": "result1"},
{"value_selector": ["llm3", "text"], "variable": "result3"},
],
},
"id": "end2",
},
],
}
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
)
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="333",
graph_config=graph_config,
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
max_execution_steps=500,
max_execution_time=1200,
)
def llm_generator(self):
contents = ["hi", "bye", "good morning"]
yield RunStreamChunkEvent(
chunk_content=contents[int(self.node_id[-1]) - 1], from_variable_selector=[self.node_id, "text"]
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={},
process_data={},
outputs={},
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: 1,
NodeRunMetadataKey.TOTAL_PRICE: 1,
NodeRunMetadataKey.CURRENCY: "USD",
},
)
)
# print("")
with patch.object(LLMNode, "_run", new=llm_generator):
items = []
generator = graph_engine.run()
for item in generator:
# print(type(item), item)
items.append(item)
if isinstance(item, NodeRunSucceededEvent):
assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
assert not isinstance(item, NodeRunFailedEvent)
assert not isinstance(item, GraphRunFailedEvent)
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ["llm2", "llm3", "end1", "end2"]:
assert item.parallel_id is not None
assert len(items) == 18
assert isinstance(items[0], GraphRunStartedEvent)
assert isinstance(items[1], NodeRunStartedEvent)
assert items[1].route_node_state.node_id == "start"
assert isinstance(items[2], NodeRunSucceededEvent)
assert items[2].route_node_state.node_id == "start"
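# Both branches after llm1 (llm2 -> end1 and llm3 -> end2) run as one parallel, so
# every event they emit must carry a parallel_id; event order is only pinned for the
# deterministic prefix (graph start, then the start node's started/succeeded pair).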
@patch("extensions.ext_database.db.session.remove")
@patch("extensions.ext_database.db.session.close")
def test_run_parallel_in_chatflow(mock_close, mock_remove):
graph_config = {
"edges": [
{
"id": "1",
"source": "start",
"target": "answer1",
},
{
"id": "2",
"source": "answer1",
"target": "answer2",
},
{
"id": "3",
"source": "answer1",
"target": "answer3",
},
{
"id": "4",
"source": "answer2",
"target": "answer4",
},
{
"id": "5",
"source": "answer3",
"target": "answer5",
},
],
"nodes": [
{"data": {"type": "start", "title": "start"}, "id": "start"},
{"data": {"type": "answer", "title": "answer1", "answer": "1"}, "id": "answer1"},
{
"data": {"type": "answer", "title": "answer2", "answer": "2"},
"id": "answer2",
},
{
"data": {"type": "answer", "title": "answer3", "answer": "3"},
"id": "answer3",
},
{
"data": {"type": "answer", "title": "answer4", "answer": "4"},
"id": "answer4",
},
{
"data": {"type": "answer", "title": "answer5", "answer": "5"},
"id": "answer5",
},
],
}
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "what's the weather in SF",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
user_inputs={},
)
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
workflow_id="333",
graph_config=graph_config,
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
max_execution_steps=500,
max_execution_time=1200,
)
# print("")
items = []
generator = graph_engine.run()
for item in generator:
# print(type(item), item)
items.append(item)
if isinstance(item, NodeRunSucceededEvent):
assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
assert not isinstance(item, NodeRunFailedEvent)
assert not isinstance(item, GraphRunFailedEvent)
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [
"answer2",
"answer3",
"answer4",
"answer5",
]:
assert item.parallel_id is not None
assert len(items) == 23
assert isinstance(items[0], GraphRunStartedEvent)
assert isinstance(items[1], NodeRunStartedEvent)
assert items[1].route_node_state.node_id == "start"
assert isinstance(items[2], NodeRunSucceededEvent)
assert items[2].route_node_state.node_id == "start"
@patch("extensions.ext_database.db.session.remove")
@patch("extensions.ext_database.db.session.close")
def test_run_branch(mock_close, mock_remove):
graph_config = {
"edges": [
{
"id": "1",
"source": "start",
"target": "if-else-1",
},
{
"id": "2",
"source": "if-else-1",
"sourceHandle": "true",
"target": "answer-1",
},
{
"id": "3",
"source": "if-else-1",
"sourceHandle": "false",
"target": "if-else-2",
},
{
"id": "4",
"source": "if-else-2",
"sourceHandle": "true",
"target": "answer-2",
},
{
"id": "5",
"source": "if-else-2",
"sourceHandle": "false",
"target": "answer-3",
},
],
"nodes": [
{
"data": {
"title": "Start",
"type": "start",
"variables": [
{
"label": "uid",
"max_length": 48,
"options": [],
"required": True,
"type": "text-input",
"variable": "uid",
}
],
},
"id": "start",
},
{
"data": {"answer": "1 {{#start.uid#}}", "title": "Answer", "type": "answer", "variables": []},
"id": "answer-1",
},
{
"data": {
"cases": [
{
"case_id": "true",
"conditions": [
{
"comparison_operator": "contains",
"id": "b0f02473-08b6-4a81-af91-15345dcb2ec8",
"value": "hi",
"varType": "string",
"variable_selector": ["sys", "query"],
}
],
"id": "true",
"logical_operator": "and",
}
],
"desc": "",
"title": "IF/ELSE",
"type": "if-else",
},
"id": "if-else-1",
},
{
"data": {
"cases": [
{
"case_id": "true",
"conditions": [
{
"comparison_operator": "contains",
"id": "ae895199-5608-433b-b5f0-0997ae1431e4",
"value": "takatost",
"varType": "string",
"variable_selector": ["sys", "query"],
}
],
"id": "true",
"logical_operator": "and",
}
],
"title": "IF/ELSE 2",
"type": "if-else",
},
"id": "if-else-2",
},
{
"data": {
"answer": "2",
"title": "Answer 2",
"type": "answer",
},
"id": "answer-2",
},
{
"data": {
"answer": "3",
"title": "Answer 3",
"type": "answer",
},
"id": "answer-3",
},
],
}
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "hi",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
user_inputs={"uid": "takato"},
)
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
workflow_id="333",
graph_config=graph_config,
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
max_execution_steps=500,
max_execution_time=1200,
)
# print("")
items = []
generator = graph_engine.run()
for item in generator:
# print(type(item), item)
items.append(item)
assert len(items) == 10
assert items[3].route_node_state.node_id == "if-else-1"
assert items[4].route_node_state.node_id == "if-else-1"
assert isinstance(items[5], NodeRunStreamChunkEvent)
assert items[5].chunk_content == "1 "
assert isinstance(items[6], NodeRunStreamChunkEvent)
assert items[6].chunk_content == "takato"
assert items[7].route_node_state.node_id == "answer-1"
assert items[8].route_node_state.node_id == "answer-1"
assert items[8].route_node_state.node_run_result.outputs["answer"] == "1 takato"
assert isinstance(items[9], GraphRunSucceededEvent)
# print(graph_engine.graph_runtime_state.model_dump_json(indent=2))

View File

@ -0,0 +1,82 @@
import time
import uuid
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
def test_execute_answer():
graph_config = {
"edges": [
{
"id": "start-source-llm-target",
"source": "start",
"target": "llm",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm",
},
],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
user_inputs={},
environment_variables=[],
)
pool.add(["start", "weather"], "sunny")
pool.add(["llm", "text"], "You are a helpful AI.")
node = AnswerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
},
)
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."

View File

@ -0,0 +1,109 @@
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
def test_init():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm3-source-llm4-target",
"source": "llm3",
"target": "llm4",
},
{
"id": "llm3-source-llm5-target",
"source": "llm3",
"target": "llm5",
},
{
"id": "llm4-source-answer2-target",
"source": "llm4",
"target": "answer2",
},
{
"id": "llm5-source-answer-target",
"source": "llm5",
"target": "answer",
},
{
"id": "answer2-source-answer-target",
"source": "answer2",
"target": "answer",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "answer",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "answer",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm1",
},
{
"data": {
"type": "llm",
},
"id": "llm2",
},
{
"data": {
"type": "llm",
},
"id": "llm3",
},
{
"data": {
"type": "llm",
},
"id": "llm4",
},
{
"data": {
"type": "llm",
},
"id": "llm5",
},
{
"data": {"type": "answer", "title": "answer", "answer": "1{{#llm2.text#}}2"},
"id": "answer",
},
{
"data": {"type": "answer", "title": "answer2", "answer": "1{{#llm3.text#}}2"},
"id": "answer2",
},
],
}
graph = Graph.init(graph_config=graph_config)
answer_stream_generate_route = AnswerStreamGeneratorRouter.init(
node_id_config_mapping=graph.node_id_config_mapping, reverse_edge_mapping=graph.reverse_edge_mapping
)
assert answer_stream_generate_route.answer_dependencies["answer"] == ["answer2"]
assert answer_stream_generate_route.answer_dependencies["answer2"] == []

View File

@ -0,0 +1,216 @@
import uuid
from collections.abc import Generator
from datetime import datetime, timezone
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.start.entities import StartNodeData
def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
if next_node_id == "start":
yield from _publish_events(graph, next_node_id)
for edge in graph.edge_mapping.get(next_node_id, []):
yield from _publish_events(graph, edge.target_node_id)
for edge in graph.edge_mapping.get(next_node_id, []):
yield from _recursive_process(graph, edge.target_node_id)
def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None))
parallel_id = graph.node_parallel_mapping.get(next_node_id)
parallel_start_node_id = None
if parallel_id:
parallel = graph.parallel_mapping.get(parallel_id)
parallel_start_node_id = parallel.start_from_node_id if parallel else None
node_execution_id = str(uuid.uuid4())
node_config = graph.node_id_config_mapping[next_node_id]
node_type = NodeType.value_of(node_config.get("data", {}).get("type"))
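    # Node data is a stand-in here: the same minimal StartNodeData is attached to every event.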
    mock_node_data = StartNodeData(title="demo", variables=[])
yield NodeRunStartedEvent(
id=node_execution_id,
node_id=next_node_id,
node_type=node_type,
node_data=mock_node_data,
route_node_state=route_node_state,
parallel_id=graph.node_parallel_mapping.get(next_node_id),
parallel_start_node_id=parallel_start_node_id,
)
if "llm" in next_node_id:
length = int(next_node_id[-1])
for i in range(0, length):
yield NodeRunStreamChunkEvent(
id=node_execution_id,
node_id=next_node_id,
node_type=node_type,
node_data=mock_node_data,
chunk_content=str(i),
route_node_state=route_node_state,
from_variable_selector=[next_node_id, "text"],
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
)
route_node_state.status = RouteNodeState.Status.SUCCESS
route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
yield NodeRunSucceededEvent(
id=node_execution_id,
node_id=next_node_id,
node_type=node_type,
node_data=mock_node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
)
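

# End-to-end check: route the fake engine events through AnswerStreamProcessor
# and verify that answer chunks are held back until their dependencies finish.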
def test_process():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm3-source-llm4-target",
"source": "llm3",
"target": "llm4",
},
{
"id": "llm3-source-llm5-target",
"source": "llm3",
"target": "llm5",
},
{
"id": "llm4-source-answer2-target",
"source": "llm4",
"target": "answer2",
},
{
"id": "llm5-source-answer-target",
"source": "llm5",
"target": "answer",
},
{
"id": "answer2-source-answer-target",
"source": "answer2",
"target": "answer",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "answer",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "answer",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "llm",
},
"id": "llm1",
},
{
"data": {
"type": "llm",
},
"id": "llm2",
},
{
"data": {
"type": "llm",
},
"id": "llm3",
},
{
"data": {
"type": "llm",
},
"id": "llm4",
},
{
"data": {
"type": "llm",
},
"id": "llm5",
},
{
"data": {"type": "answer", "title": "answer", "answer": "a{{#llm2.text#}}b"},
"id": "answer",
},
{
"data": {"type": "answer", "title": "answer2", "answer": "c{{#llm3.text#}}d"},
"id": "answer2",
},
],
}
graph = Graph.init(graph_config=graph_config)
variable_pool = VariablePool(
system_variables={
SystemVariableKey.QUERY: "what's the weather in SF",
SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa",
},
user_inputs={},
)
answer_stream_processor = AnswerStreamProcessor(graph=graph, variable_pool=variable_pool)
    def graph_generator() -> Generator[GraphEngineEvent, None, None]:
        for event in _recursive_process(graph, "start"):
            if isinstance(event, NodeRunSucceededEvent) and "llm" in event.route_node_state.node_id:
                # Once llmN has succeeded, publish its full streamed text ("01...")
                # to the variable pool so answer templates can resolve {{#llmN.text#}}.
                variable_pool.add(
                    [event.route_node_state.node_id, "text"],
                    "".join(str(i) for i in range(int(event.route_node_state.node_id[-1]))),
                )
            yield event

    result_generator = answer_stream_processor.process(graph_generator())
    stream_contents = ""
    for event in result_generator:
        if isinstance(event, NodeRunStreamChunkEvent):
            stream_contents += event.chunk_content
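
    # Expected order: "answer2" streams first ("c" + llm3's "012" + "d"), then
    # "answer" ("a" + llm2's "01" + "b"), so the joined chunks read "c012da01b".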
assert stream_contents == "c012da01b"

Some files were not shown because too many files have changed in this diff.