feat: add backwards invoke node api
This commit is contained in:
parent 592f85f7a9
commit 68c10a1672
@@ -8,13 +8,15 @@ from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
 from controllers.inner_api.wraps import plugin_inner_api_only
 from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
 from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
+from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
 from core.plugin.encrypt import PluginEncrypter
 from core.plugin.entities.request import (
     RequestInvokeApp,
     RequestInvokeEncrypt,
     RequestInvokeLLM,
     RequestInvokeModeration,
-    RequestInvokeNode,
+    RequestInvokeParameterExtractorNode,
+    RequestInvokeQuestionClassifierNode,
     RequestInvokeRerank,
     RequestInvokeSpeech2Text,
     RequestInvokeTextEmbedding,
@@ -96,23 +98,46 @@ class PluginInvokeToolApi(Resource):
             yield (
                 ToolInvokeMessage(
                     type=ToolInvokeMessage.MessageType.TEXT,
-                    message=ToolInvokeMessage.TextMessage(text='helloworld'),
+                    message=ToolInvokeMessage.TextMessage(text="helloworld"),
                 )
                 .model_dump_json()
                 .encode()
-                + b'\n\n'
+                + b"\n\n"
             )

         return compact_generate_response(generator())


-class PluginInvokeNodeApi(Resource):
+class PluginInvokeParameterExtractorNodeApi(Resource):
     @setup_required
     @plugin_inner_api_only
     @get_tenant
-    @plugin_data(payload_type=RequestInvokeNode)
-    def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeNode):
-        pass
+    @plugin_data(payload_type=RequestInvokeParameterExtractorNode)
+    def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode):
+        return PluginNodeBackwardsInvocation.invoke_parameter_extractor(
+            tenant_id=tenant_model.id,
+            user_id=user_id,
+            parameters=payload.parameters,
+            model_config=payload.model,
+            instruction=payload.instruction,
+            query=payload.query,
+        )
+
+
+class PluginInvokeQuestionClassifierNodeApi(Resource):
+    @setup_required
+    @plugin_inner_api_only
+    @get_tenant
+    @plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
+    def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode):
+        return PluginNodeBackwardsInvocation.invoke_question_classifier(
+            tenant_id=tenant_model.id,
+            user_id=user_id,
+            query=payload.query,
+            model_config=payload.model,
+            classes=payload.classes,
+            instruction=payload.instruction,
+        )


 class PluginInvokeAppApi(Resource):
@@ -127,15 +152,13 @@ class PluginInvokeAppApi(Resource):
             tenant_id=tenant_model.id,
             conversation_id=payload.conversation_id,
             query=payload.query,
-            stream=payload.response_mode == 'streaming',
+            stream=payload.response_mode == "streaming",
             inputs=payload.inputs,
-            files=payload.files
-        )
-
-        return compact_generate_response(
-            PluginAppBackwardsInvocation.convert_to_event_stream(response)
+            files=payload.files,
         )
+
+        return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response))


 class PluginInvokeEncryptApi(Resource):
     @setup_required
@@ -149,13 +172,14 @@ class PluginInvokeEncryptApi(Resource):
         return PluginEncrypter.invoke_encrypt(tenant_model, payload)


-api.add_resource(PluginInvokeLLMApi, '/invoke/llm')
-api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding')
-api.add_resource(PluginInvokeRerankApi, '/invoke/rerank')
-api.add_resource(PluginInvokeTTSApi, '/invoke/tts')
-api.add_resource(PluginInvokeSpeech2TextApi, '/invoke/speech2text')
-api.add_resource(PluginInvokeModerationApi, '/invoke/moderation')
-api.add_resource(PluginInvokeToolApi, '/invoke/tool')
-api.add_resource(PluginInvokeNodeApi, '/invoke/node')
-api.add_resource(PluginInvokeAppApi, '/invoke/app')
-api.add_resource(PluginInvokeEncryptApi, '/invoke/encrypt')
+api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
+api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
+api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
+api.add_resource(PluginInvokeTTSApi, "/invoke/tts")
+api.add_resource(PluginInvokeSpeech2TextApi, "/invoke/speech2text")
+api.add_resource(PluginInvokeModerationApi, "/invoke/moderation")
+api.add_resource(PluginInvokeToolApi, "/invoke/tool")
+api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extractor")
+api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier")
+api.add_resource(PluginInvokeAppApi, "/invoke/app")
+api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
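For reference, a minimal client-side sketch of calling the new parameter-extractor route. The inner-API base URL and auth header are hypothetical placeholders, and the nested parameter/model field names are illustrative assumptions rather than values taken from this diff:

import requests

# Hypothetical inner-API location and auth header; adjust to your deployment.
BASE_URL = "http://localhost:5001/inner/api/plugin"
HEADERS = {"X-Inner-Api-Key": "<inner-api-key>", "Content-Type": "application/json"}

# Payload mirrors RequestInvokeParameterExtractorNode: parameters, model, instruction, query.
# The nested parameter/model fields below are illustrative assumptions.
payload = {
    "parameters": [
        {"name": "city", "type": "string", "description": "destination city", "required": True},
    ],
    "model": {"provider": "openai", "name": "gpt-4o-mini", "mode": "chat", "completion_params": {}},
    "instruction": "Extract the destination city from the user query.",
    "query": "I want to fly to Tokyo next week.",
}

resp = requests.post(f"{BASE_URL}/invoke/parameter-extractor", json=payload, headers=HEADERS, timeout=60)
result = resp.json()
if not result.get("__is_success", True):
    # fallback dict returned when the node produced no output
    print("extraction failed:", result.get("__reason"))
else:
    print("extracted:", result)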

api/core/plugin/backwards_invocation/node.py  (new file, 114 lines)
@@ -0,0 +1,114 @@
+from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
+from core.workflow.nodes.parameter_extractor.entities import (
+    ModelConfig as ParameterExtractorModelConfig,
+)
+from core.workflow.nodes.parameter_extractor.entities import (
+    ParameterConfig,
+    ParameterExtractorNodeData,
+)
+from core.workflow.nodes.question_classifier.entities import (
+    ClassConfig,
+    QuestionClassifierNodeData,
+)
+from core.workflow.nodes.question_classifier.entities import (
+    ModelConfig as QuestionClassifierModelConfig,
+)
+from services.workflow_service import WorkflowService
+
+
+class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
+    @classmethod
+    def invoke_parameter_extractor(
+        cls,
+        tenant_id: str,
+        user_id: str,
+        parameters: list[ParameterConfig],
+        model_config: ParameterExtractorModelConfig,
+        instruction: str,
+        query: str,
+    ) -> dict:
+        """
+        Invoke parameter extractor node.
+
+        :param tenant_id: str
+        :param user_id: str
+        :param parameters: list[ParameterConfig]
+        :param model_config: ModelConfig
+        :param instruction: str
+        :param query: str
+        :return: dict with __reason, __is_success, and other parameters
+        """
+        workflow_service = WorkflowService()
+        node_id = "1919810"
+        node_data = ParameterExtractorNodeData(
+            title="parameter_extractor",
+            desc="parameter_extractor",
+            parameters=parameters,
+            reasoning_mode="function_call",
+            query=[node_id, "query"],
+            model=model_config,
+            instruction=instruction,  # instruct with variables are not supported
+        )
+        node_data_dict = node_data.model_dump()
+        execution = workflow_service.run_free_workflow_node(
+            node_data_dict,
+            tenant_id=tenant_id,
+            user_id=user_id,
+            node_id=node_id,
+            user_inputs={
+                f"{node_id}.query": query,
+            },
+        )
+
+        output = execution.outputs_dict
+        return output or {
+            "__reason": "No parameters extracted",
+            "__is_success": False,
+        }
+
+    @classmethod
+    def invoke_question_classifier(
+        cls,
+        tenant_id: str,
+        user_id: str,
+        model_config: QuestionClassifierModelConfig,
+        classes: list[ClassConfig],
+        instruction: str,
+        query: str,
+    ) -> dict:
+        """
+        Invoke question classifier node.
+
+        :param tenant_id: str
+        :param user_id: str
+        :param model_config: ModelConfig
+        :param classes: list[ClassConfig]
+        :param instruction: str
+        :param query: str
+        :return: dict with class_name
+        """
+        workflow_service = WorkflowService()
+        node_id = "1919810"
+        node_data = QuestionClassifierNodeData(
+            title="question_classifier",
+            desc="question_classifier",
+            query_variable_selector=[node_id, "query"],
+            model=model_config,
+            classes=classes,
+            instruction=instruction,  # instruct with variables are not supported
+        )
+        node_data_dict = node_data.model_dump()
+        execution = workflow_service.run_free_workflow_node(
+            node_data_dict,
+            tenant_id=tenant_id,
+            user_id=user_id,
+            node_id=node_id,
+            user_inputs={
+                f"{node_id}.query": query,
+            },
+        )
+
+        output = execution.outputs_dict
+        return output or {
+            "class_name": classes[0].name,
+        }
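A server-side sketch of the new backwards invocation in use, mirroring what the controller above does for the question-classifier route. The ClassConfig/ModelConfig field values are illustrative assumptions, not taken from this diff:

from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
from core.workflow.nodes.question_classifier.entities import ClassConfig, ModelConfig

# Illustrative configs; field names beyond those shown in this diff are assumptions.
model_config = ModelConfig(provider="openai", name="gpt-4o-mini", mode="chat", completion_params={})
classes = [ClassConfig(id="1", name="billing"), ClassConfig(id="2", name="technical support")]

result = PluginNodeBackwardsInvocation.invoke_question_classifier(
    tenant_id="<tenant-id>",
    user_id="<user-id>",
    model_config=model_config,
    classes=classes,
    instruction="Route the question to the right team.",
    query="My invoice amount looks wrong.",
)

# Falls back to the first class name when the node returns no output, as shown above.
print(result["class_name"])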
@@ -14,6 +14,16 @@ from core.model_runtime.entities.message_entities import (
     UserPromptMessage,
 )
 from core.model_runtime.entities.model_entities import ModelType
+from core.workflow.nodes.question_classifier.entities import (
+    ClassConfig,
+    ModelConfig as QuestionClassifierModelConfig,
+)
+from core.workflow.nodes.parameter_extractor.entities import (
+    ModelConfig as ParameterExtractorModelConfig,
+)
+from core.workflow.nodes.parameter_extractor.entities import (
+    ParameterConfig,
+)


 class RequestInvokeTool(BaseModel):
@@ -92,11 +102,27 @@ class RequestInvokeModeration(BaseModel):
     """


-class RequestInvokeNode(BaseModel):
+class RequestInvokeParameterExtractorNode(BaseModel):
     """
-    Request to invoke node
+    Request to invoke parameter extractor node
     """

+    parameters: list[ParameterConfig]
+    model: ParameterExtractorModelConfig
+    instruction: str
+    query: str
+
+
+class RequestInvokeQuestionClassifierNode(BaseModel):
+    """
+    Request to invoke question classifier node
+    """
+
+    query: str
+    model: QuestionClassifierModelConfig
+    classes: list[ClassConfig]
+    instruction: str


 class RequestInvokeApp(BaseModel):
     """
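The new request entities are plain pydantic models, so a raw JSON body with the fields above can be validated directly (which is presumably what the plugin_data decorator does). A small sketch, with the nested model/parameter values again being illustrative assumptions:

from core.plugin.entities.request import RequestInvokeParameterExtractorNode

raw_body = {
    "parameters": [{"name": "city", "type": "string", "description": "destination city", "required": True}],
    "model": {"provider": "openai", "name": "gpt-4o-mini", "mode": "chat", "completion_params": {}},
    "instruction": "Extract the destination city.",
    "query": "I want to fly to Tokyo next week.",
}

# pydantic v2 validation of the incoming body into the typed request model
payload = RequestInvokeParameterExtractorNode.model_validate(raw_body)
print(payload.query, payload.parameters[0])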
@@ -205,6 +205,88 @@ class WorkflowEntry:
         except Exception as e:
             raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))

+    @classmethod
+    def run_free_node(
+        cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
+    ) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]:
+        """
+        Run free node
+
+        NOTE: only parameter_extractor/question_classifier are supported
+
+        :param node_data: node data
+        :param user_id: user id
+        :param user_inputs: user inputs
+        :return:
+        """
+        # generate a fake graph
+        node_config = {"id": node_id, "width": 114, "height": 514, "type": "custom", "data": node_data}
+        graph_dict = {
+            "nodes": [node_config],
+        }
+
+        node_type = NodeType.value_of(node_data.get("type", ""))
+        if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}:
+            raise ValueError(f"Node type {node_type} not supported")
+
+        node_cls = node_classes.get(node_type)
+        if not node_cls:
+            raise ValueError(f"Node class not found for node type {node_type}")
+
+        graph = Graph.init(graph_config=graph_dict)
+
+        # init variable pool
+        variable_pool = VariablePool(
+            system_variables={},
+            user_inputs={},
+            environment_variables=[],
+        )
+
+        node_cls = cast(type[BaseNode], node_cls)
+        # init workflow run state
+        node_instance: BaseNode = node_cls(
+            id=str(uuid.uuid4()),
+            config=node_config,
+            graph_init_params=GraphInitParams(
+                tenant_id=tenant_id,
+                app_id="",
+                workflow_type=WorkflowType.WORKFLOW,
+                workflow_id="",
+                graph_config=graph_dict,
+                user_id=user_id,
+                user_from=UserFrom.ACCOUNT,
+                invoke_from=InvokeFrom.DEBUGGER,
+                call_depth=0,
+            ),
+            graph=graph,
+            graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+        )
+
+        try:
+            # variable selector to variable mapping
+            try:
+                variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
+                    graph_config=graph_dict, config=node_config
+                )
+            except NotImplementedError:
+                variable_mapping = {}
+
+            cls.mapping_user_inputs_to_variable_pool(
+                variable_mapping=variable_mapping,
+                user_inputs=user_inputs,
+                variable_pool=variable_pool,
+                tenant_id=tenant_id,
+                node_type=node_type,
+                node_data=node_instance.node_data,
+            )
+
+            # run node
+            generator = node_instance.run()
+
+            return node_instance, generator
+        except Exception as e:
+            raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
+
     @classmethod
     def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]:
         """
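For orientation, a sketch of consuming the (node_instance, generator) pair that run_free_node returns, mirroring the event loop in WorkflowService._handle_node_run_result further down. It assumes node_data_dict was built upstream and that RunCompletedEvent exposes the node result as run_result:

from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.workflow_entry import WorkflowEntry

# node_data_dict is assumed to be built upstream (e.g. ParameterExtractorNodeData(...).model_dump(),
# as PluginNodeBackwardsInvocation does above); "1919810" matches the fake node id used there.
node_instance, generator = WorkflowEntry.run_free_node(
    node_data=node_data_dict,
    node_id="1919810",
    tenant_id="<tenant-id>",
    user_id="<user-id>",
    user_inputs={"1919810.query": "I want to fly to Tokyo next week."},
)

run_result = None
for event in generator:
    # RunCompletedEvent is assumed to carry the NodeRunResult as `run_result`,
    # which is what the service-side event loop below extracts.
    if isinstance(event, RunCompletedEvent):
        run_result = event.run_result

if run_result is not None:
    print(run_result.outputs)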
@@ -1,8 +1,8 @@
 import json
 import time
-from collections.abc import Sequence
+from collections.abc import Callable, Generator, Sequence
 from datetime import datetime, timezone
-from typing import Optional
+from typing import Any, Optional

 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
@@ -10,7 +10,9 @@ from core.app.segments import Variable
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.errors import WorkflowNodeRunFailedError
-from core.workflow.nodes.event import RunCompletedEvent
+from core.workflow.graph_engine.entities.event import InNodeEvent
+from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.event import RunCompletedEvent, RunEvent
 from core.workflow.nodes.node_mapping import node_classes
 from core.workflow.workflow_entry import WorkflowEntry
 from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
@@ -216,13 +218,64 @@ class WorkflowService:
         # run draft workflow node
         start_at = time.perf_counter()

-        try:
-            node_instance, generator = WorkflowEntry.single_step_run(
+        workflow_node_execution = self._handle_node_run_result(
+            getter=lambda: WorkflowEntry.single_step_run(
                 workflow=draft_workflow,
                 node_id=node_id,
                 user_inputs=user_inputs,
                 user_id=account.id,
-            )
+            ),
+            start_at=start_at,
+            tenant_id=app_model.tenant_id,
+            node_id=node_id,
+        )
+
+        db.session.add(workflow_node_execution)
+        db.session.commit()
+
+        return workflow_node_execution
+
+    def run_free_workflow_node(
+        self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
+    ) -> WorkflowNodeExecution:
+        """
+        Run draft workflow node
+        """
+        # run draft workflow node
+        start_at = time.perf_counter()
+
+        workflow_node_execution = self._handle_node_run_result(
+            getter=lambda: WorkflowEntry.run_free_node(
+                node_id=node_id,
+                node_data=node_data,
+                tenant_id=tenant_id,
+                user_id=user_id,
+                user_inputs=user_inputs,
+            ),
+            start_at=start_at,
+            tenant_id=tenant_id,
+            node_id=node_id
+        )
+
+        return workflow_node_execution
+
+    def _handle_node_run_result(
+        self,
+        getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]],
+        start_at: float,
+        tenant_id: str,
+        node_id: str,
+    ):
+        """
+        Handle node run result
+
+        :param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]]
+        :param start_at: float
+        :param tenant_id: str
+        :param node_id: str
+        """
+        try:
+            node_instance, generator = getter()

             node_run_result: NodeRunResult | None = None
             for event in generator:
@@ -245,9 +298,7 @@
                 error = e.error

         workflow_node_execution = WorkflowNodeExecution()
-        workflow_node_execution.tenant_id = app_model.tenant_id
-        workflow_node_execution.app_id = app_model.id
-        workflow_node_execution.workflow_id = draft_workflow.id
+        workflow_node_execution.tenant_id = tenant_id
         workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
         workflow_node_execution.index = 1
         workflow_node_execution.node_id = node_id
@@ -255,7 +306,6 @@
         workflow_node_execution.title = node_instance.node_data.title
         workflow_node_execution.elapsed_time = time.perf_counter() - start_at
         workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value
-        workflow_node_execution.created_by = account.id
         workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
         workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)

@@ -277,9 +327,6 @@
             workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
             workflow_node_execution.error = error

-        db.session.add(workflow_node_execution)
-        db.session.commit()
-
         return workflow_node_execution

     def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
@@ -302,10 +349,10 @@
         new_app = workflow_converter.convert_to_workflow(
             app_model=app_model,
             account=account,
-            name=args.get("name"),
-            icon_type=args.get("icon_type"),
-            icon=args.get("icon"),
-            icon_background=args.get("icon_background"),
+            name=args.get("name", ""),
+            icon_type=args.get("icon_type", ""),
+            icon=args.get("icon", ""),
+            icon_background=args.get("icon_background", ""),
         )

         return new_app
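The service-side refactor hinges on _handle_node_run_result taking a zero-argument getter callable, so the existing single-step draft run and the new free-node run share the same timing, event handling, and WorkflowNodeExecution bookkeeping. A generic, self-contained sketch of that pattern (not dify code):

import time
from collections.abc import Callable
from dataclasses import dataclass


@dataclass
class ExecutionRecord:
    elapsed_time: float
    succeeded: bool
    error: str | None = None


def handle_run_result(getter: Callable[[], object], start_at: float) -> ExecutionRecord:
    """Run whichever entry point the caller wrapped in `getter` and record the outcome."""
    try:
        getter()  # the lambda decides which entry point actually runs
        return ExecutionRecord(elapsed_time=time.perf_counter() - start_at, succeeded=True)
    except Exception as e:
        return ExecutionRecord(elapsed_time=time.perf_counter() - start_at, succeeded=False, error=str(e))


# Both call sites reuse the same handler by wrapping their own entry point in a lambda.
start = time.perf_counter()
record = handle_run_result(getter=lambda: print("pretend this is WorkflowEntry.run_free_node(...)"), start_at=start)
print(record)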