feat: invoke node

This commit is contained in:
Yeuoly 2024-09-24 20:15:13 +08:00
parent 68c10a1672
commit a91951b374
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
4 changed files with 43 additions and 15 deletions

View File

@ -1,4 +1,5 @@
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.parameter_extractor.entities import (
ModelConfig as ParameterExtractorModelConfig,
)
@ -36,7 +37,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
:param model_config: ModelConfig
:param instruction: str
:param query: str
:return: dict with __reason, __is_success, and other parameters
:return: dict
"""
workflow_service = WorkflowService()
node_id = "1919810"
@ -50,6 +51,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
instruction=instruction, # instruct with variables are not supported
)
node_data_dict = node_data.model_dump()
node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR.value
execution = workflow_service.run_free_workflow_node(
node_data_dict,
tenant_id=tenant_id,
@ -60,10 +62,10 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
},
)
output = execution.outputs_dict
return output or {
"__reason": "No parameters extracted",
"__is_success": False,
return {
"inputs": execution.inputs_dict,
"outputs": execution.outputs_dict,
"process_data": execution.process_data_dict,
}
@classmethod
@ -85,7 +87,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
:param classes: list[ClassConfig]
:param instruction: str
:param query: str
:return: dict with class_name
:return: dict
"""
workflow_service = WorkflowService()
node_id = "1919810"
@ -108,7 +110,8 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
},
)
output = execution.outputs_dict
return output or {
"class_name": classes[0].name,
}
return {
"inputs": execution.inputs_dict,
"outputs": execution.outputs_dict,
"process_data": execution.process_data_dict,
}

View File

@ -14,16 +14,18 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelType
from core.workflow.nodes.question_classifier.entities import (
ClassConfig,
ModelConfig as QuestionClassifierModelConfig,
)
from core.workflow.nodes.parameter_extractor.entities import (
ModelConfig as ParameterExtractorModelConfig,
)
from core.workflow.nodes.parameter_extractor.entities import (
ParameterConfig,
)
from core.workflow.nodes.question_classifier.entities import (
ClassConfig,
)
from core.workflow.nodes.question_classifier.entities import (
ModelConfig as QuestionClassifierModelConfig,
)
class RequestInvokeTool(BaseModel):

View File

@ -221,8 +221,27 @@ class WorkflowEntry:
"""
# generate a fake graph
node_config = {"id": node_id, "width": 114, "height": 514, "type": "custom", "data": node_data}
start_node_config = {
"id": "start",
"width": 114,
"height": 514,
"type": "custom",
"data": {
"type": NodeType.START.value,
"title": "Start",
"desc": "Start",
},
}
graph_dict = {
"nodes": [node_config],
"nodes": [start_node_config, node_config],
"edges": [
{
"source": "start",
"target": node_id,
"sourceHandle": "source",
"targetHandle": "target",
}
],
}
node_type = NodeType.value_of(node_data.get("type", ""))

View File

@ -230,6 +230,10 @@ class WorkflowService:
node_id=node_id,
)
workflow_node_execution.app_id = app_model.id
workflow_node_execution.created_by = account.id
workflow_node_execution.workflow_id = draft_workflow.id
db.session.add(workflow_node_execution)
db.session.commit()