mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
refactor(core): improve type annotations and file handling consistency
- Use more precise type annotations with Sequence and Mapping for task entities. - Ensure raw_prompt is assigned properly after replacement in advanced prompt transform. - Remove unused generator return type from _fetch_context method. - Refactor tool node file handling to retrieve more comprehensive file attributes, ensuring file existence validation in the database.
This commit is contained in:
parent
dba67cd87a
commit
a36ef8430e
|
@ -213,7 +213,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
|||
created_by: Optional[dict] = None
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Optional[list[dict]] = []
|
||||
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||
|
||||
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
|
||||
workflow_run_id: str
|
||||
|
@ -298,7 +298,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
|||
execution_metadata: Optional[dict] = None
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Optional[list[dict]] = []
|
||||
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
parent_parallel_id: Optional[str] = None
|
||||
|
|
|
@ -150,7 +150,7 @@ class AdvancedPromptTransform(PromptTransform):
|
|||
for k, v in inputs.items():
|
||||
if k.startswith("#"):
|
||||
vp.add(k[1:-1].split("."), v)
|
||||
raw_prompt.replace("{{#context#}}", context or "")
|
||||
raw_prompt = raw_prompt.replace("{{#context#}}", context or "")
|
||||
prompt = vp.convert_template(raw_prompt).text
|
||||
else:
|
||||
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
||||
|
|
|
@ -359,7 +359,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
return []
|
||||
raise ValueError(f"Invalid variable type: {type(variable)}")
|
||||
|
||||
def _fetch_context(self, node_data: LLMNodeData) -> Generator[RunEvent, None, None]:
|
||||
def _fetch_context(self, node_data: LLMNodeData):
|
||||
if not node_data.context.enabled:
|
||||
return
|
||||
|
||||
|
|
|
@ -2,6 +2,9 @@ from collections.abc import Mapping, Sequence
|
|||
from os import path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file.models import File, FileTransferMethod, FileType
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
|
@ -14,6 +17,8 @@ from core.workflow.nodes.base_node import BaseNode
|
|||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import ToolFile
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
|
@ -167,45 +172,59 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
result = []
|
||||
for response in tool_response:
|
||||
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||
url = response.message
|
||||
ext = path.splitext(url)[1]
|
||||
mimetype = response.meta.get("mime_type", "image/jpeg")
|
||||
tool_file_id = response.save_as or url.split("/")[-1]
|
||||
url = str(response.message) if response.message else None
|
||||
ext = path.splitext(url)[1] if url else ".bin"
|
||||
tool_file_id = response.save_as or str(url).split("/")[-1].split(".")[0]
|
||||
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
|
||||
# get tool file id
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ValueError(f"tool file {tool_file_id} not exists")
|
||||
|
||||
result.append(
|
||||
File(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=url,
|
||||
related_id=tool_file_id,
|
||||
filename=tool_file_id,
|
||||
related_id=tool_file.id,
|
||||
filename=tool_file.name,
|
||||
extension=ext,
|
||||
mime_type=mimetype,
|
||||
mime_type=tool_file.mimetype,
|
||||
size=tool_file.size,
|
||||
)
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
tool_file_id = str(response.message).split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ValueError(f"tool file {tool_file_id} not exists")
|
||||
result.append(
|
||||
File(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id=tool_file_id,
|
||||
filename=response.save_as,
|
||||
related_id=tool_file.id,
|
||||
filename=tool_file.name,
|
||||
extension=path.splitext(response.save_as)[1],
|
||||
mime_type=response.meta.get("mime_type", "application/octet-stream"),
|
||||
mime_type=tool_file.mimetype,
|
||||
size=tool_file.size,
|
||||
)
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
url = str(response.message)
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
mimetype = response.meta.get("mime_type", "application/octet-stream")
|
||||
tool_file_id = url.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ValueError(f"tool file {tool_file_id} not exists")
|
||||
if "." in url:
|
||||
extension = "." + url.split("/")[-1].split(".")[1]
|
||||
else:
|
||||
|
@ -215,10 +234,11 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
type=FileType(response.save_as),
|
||||
transfer_method=transfer_method,
|
||||
remote_url=url,
|
||||
filename=tool_file_id,
|
||||
related_id=tool_file_id,
|
||||
filename=tool_file.name,
|
||||
related_id=tool_file.id,
|
||||
extension=extension,
|
||||
mime_type=mimetype,
|
||||
mime_type=tool_file.mimetype,
|
||||
size=tool_file.size,
|
||||
)
|
||||
result.append(file)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user