refactor(core): improve type annotations and file handling consistency

- Use more precise type annotations with Sequence and Mapping for task entities.
- Ensure raw_prompt is assigned properly after replacement in advanced prompt transform.
- Remove unused generator return type from _fetch_context method.
- Refactor tool node file handling to retrieve more comprehensive file attributes, ensuring file existence validation in the database.
This commit is contained in:
-LAN- 2024-10-15 15:34:14 +08:00
parent dba67cd87a
commit a36ef8430e
4 changed files with 40 additions and 20 deletions

View File

@ -213,7 +213,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
created_by: Optional[dict] = None
created_at: int
finished_at: int
files: Optional[list[dict]] = []
files: Optional[Sequence[Mapping[str, Any]]] = []
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
workflow_run_id: str
@ -298,7 +298,7 @@ class NodeFinishStreamResponse(StreamResponse):
execution_metadata: Optional[dict] = None
created_at: int
finished_at: int
files: Optional[list[dict]] = []
files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None

View File

@ -150,7 +150,7 @@ class AdvancedPromptTransform(PromptTransform):
for k, v in inputs.items():
if k.startswith("#"):
vp.add(k[1:-1].split("."), v)
raw_prompt.replace("{{#context#}}", context or "")
raw_prompt = raw_prompt.replace("{{#context#}}", context or "")
prompt = vp.convert_template(raw_prompt).text
else:
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)

View File

@ -359,7 +359,7 @@ class LLMNode(BaseNode[LLMNodeData]):
return []
raise ValueError(f"Invalid variable type: {type(variable)}")
def _fetch_context(self, node_data: LLMNodeData) -> Generator[RunEvent, None, None]:
def _fetch_context(self, node_data: LLMNodeData):
if not node_data.context.enabled:
return

View File

@ -2,6 +2,9 @@ from collections.abc import Mapping, Sequence
from os import path
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.models import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@ -14,6 +17,8 @@ from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from enums import NodeType
from extensions.ext_database import db
from models import ToolFile
from models.workflow import WorkflowNodeExecutionStatus
@ -167,45 +172,59 @@ class ToolNode(BaseNode[ToolNodeData]):
result = []
for response in tool_response:
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
url = response.message
ext = path.splitext(url)[1]
mimetype = response.meta.get("mime_type", "image/jpeg")
tool_file_id = response.save_as or url.split("/")[-1]
url = str(response.message) if response.message else None
ext = path.splitext(url)[1] if url else ".bin"
tool_file_id = response.save_as or str(url).split("/")[-1].split(".")[0]
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
# get tool file id
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
result.append(
File(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=transfer_method,
remote_url=url,
related_id=tool_file_id,
filename=tool_file_id,
related_id=tool_file.id,
filename=tool_file.name,
extension=ext,
mime_type=mimetype,
mime_type=tool_file.mimetype,
size=tool_file.size,
)
)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
tool_file_id = str(response.message).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
result.append(
File(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file_id,
filename=response.save_as,
related_id=tool_file.id,
filename=tool_file.name,
extension=path.splitext(response.save_as)[1],
mime_type=response.meta.get("mime_type", "application/octet-stream"),
mime_type=tool_file.mimetype,
size=tool_file.size,
)
)
elif response.type == ToolInvokeMessage.MessageType.LINK:
url = str(response.message)
transfer_method = FileTransferMethod.TOOL_FILE
mimetype = response.meta.get("mime_type", "application/octet-stream")
tool_file_id = url.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
if "." in url:
extension = "." + url.split("/")[-1].split(".")[1]
else:
@ -215,10 +234,11 @@ class ToolNode(BaseNode[ToolNodeData]):
type=FileType(response.save_as),
transfer_method=transfer_method,
remote_url=url,
filename=tool_file_id,
related_id=tool_file_id,
filename=tool_file.name,
related_id=tool_file.id,
extension=extension,
mime_type=mimetype,
mime_type=tool_file.mimetype,
size=tool_file.size,
)
result.append(file)