mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
chore(api/core): Improve FileVar's type hint and imports. (#7290)
This commit is contained in:
parent
6ff7fd80a1
commit
8f16165f92
|
@ -1,6 +1,6 @@
|
|||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
|
@ -14,7 +14,6 @@ from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChu
|
|||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
|
||||
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
||||
from core.file.file_obj import FileVar
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
|
@ -27,13 +26,16 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
|
|||
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
||||
from models.model import App, AppMode, Message, MessageAnnotation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
|
||||
class AppRunner:
|
||||
def get_pre_calculate_rest_tokens(self, app_record: App,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list[FileVar],
|
||||
files: list["FileVar"],
|
||||
query: Optional[str] = None) -> int:
|
||||
"""
|
||||
Get pre calculate rest tokens
|
||||
|
@ -126,7 +128,7 @@ class AppRunner:
|
|||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list[FileVar],
|
||||
files: list["FileVar"],
|
||||
query: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
memory: Optional[TokenBufferMemory] = None) \
|
||||
|
@ -366,7 +368,7 @@ class AppRunner:
|
|||
message_id=message_id,
|
||||
trace_manager=app_generate_entity.trace_manager
|
||||
)
|
||||
|
||||
|
||||
def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
prompt_messages: list[PromptMessage]) -> bool:
|
||||
|
@ -418,7 +420,7 @@ class AppRunner:
|
|||
inputs=inputs,
|
||||
query=query
|
||||
)
|
||||
|
||||
|
||||
def query_app_annotations_to_reply(self, app_record: App,
|
||||
message: Message,
|
||||
query: str,
|
||||
|
|
|
@ -166,4 +166,4 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
|||
node_id: str
|
||||
inputs: dict
|
||||
|
||||
single_iteration_run: Optional[SingleIterationRunEntity] = None
|
||||
single_iteration_run: Optional[SingleIterationRunEntity] = None
|
||||
|
|
|
@ -99,7 +99,7 @@ class MessageFileParser:
|
|||
# return all file objs
|
||||
return new_files
|
||||
|
||||
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig) -> list[FileVar]:
|
||||
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
|
||||
"""
|
||||
transform message files
|
||||
|
||||
|
@ -144,7 +144,7 @@ class MessageFileParser:
|
|||
|
||||
return type_file_objs
|
||||
|
||||
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig) -> FileVar:
|
||||
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
|
||||
"""
|
||||
transform file to file obj
|
||||
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
import enum
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from core.app.app_config.entities import PromptTemplateEntity
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file.file_obj import FileVar
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
|
@ -18,6 +17,9 @@ from core.prompt.prompt_transform import PromptTransform
|
|||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from models.model import AppMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
|
||||
class ModelMode(enum.Enum):
|
||||
COMPLETION = 'completion'
|
||||
|
@ -50,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
|
|||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict,
|
||||
query: str,
|
||||
files: list[FileVar],
|
||||
files: list["FileVar"],
|
||||
context: Optional[str],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) -> \
|
||||
|
@ -163,7 +165,7 @@ class SimplePromptTransform(PromptTransform):
|
|||
inputs: dict,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
files: list[FileVar],
|
||||
files: list["FileVar"],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) \
|
||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
|
@ -206,7 +208,7 @@ class SimplePromptTransform(PromptTransform):
|
|||
inputs: dict,
|
||||
query: str,
|
||||
context: Optional[str],
|
||||
files: list[FileVar],
|
||||
files: list["FileVar"],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) \
|
||||
-> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
|
@ -255,7 +257,7 @@ class SimplePromptTransform(PromptTransform):
|
|||
|
||||
return [self.get_last_user_message(prompt, files)], stops
|
||||
|
||||
def get_last_user_message(self, prompt: str, files: list[FileVar]) -> UserPromptMessage:
|
||||
def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage:
|
||||
if files:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
|
||||
for file in files:
|
||||
|
|
|
@ -2,13 +2,12 @@ from abc import ABC, abstractmethod
|
|||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.file_obj import FileVar
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
|
@ -23,6 +22,9 @@ from core.tools.entities.tool_entities import (
|
|||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
|
||||
class Tool(BaseModel, ABC):
|
||||
identity: Optional[ToolIdentity] = None
|
||||
|
@ -76,7 +78,7 @@ class Tool(BaseModel, ABC):
|
|||
description=self.description.model_copy() if self.description else None,
|
||||
runtime=Tool.Runtime(**runtime),
|
||||
)
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
|
@ -84,7 +86,7 @@ class Tool(BaseModel, ABC):
|
|||
|
||||
:return: the tool provider type
|
||||
"""
|
||||
|
||||
|
||||
def load_variables(self, variables: ToolRuntimeVariablePool):
|
||||
"""
|
||||
load variables from database
|
||||
|
@ -99,7 +101,7 @@ class Tool(BaseModel, ABC):
|
|||
"""
|
||||
if not self.variables:
|
||||
return
|
||||
|
||||
|
||||
self.variables.set_file(self.identity.name, variable_name, image_key)
|
||||
|
||||
def set_text_variable(self, variable_name: str, text: str) -> None:
|
||||
|
@ -108,9 +110,9 @@ class Tool(BaseModel, ABC):
|
|||
"""
|
||||
if not self.variables:
|
||||
return
|
||||
|
||||
|
||||
self.variables.set_text(self.identity.name, variable_name, text)
|
||||
|
||||
|
||||
def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]:
|
||||
"""
|
||||
get a variable
|
||||
|
@ -120,14 +122,14 @@ class Tool(BaseModel, ABC):
|
|||
"""
|
||||
if not self.variables:
|
||||
return None
|
||||
|
||||
|
||||
if isinstance(name, Enum):
|
||||
name = name.value
|
||||
|
||||
|
||||
for variable in self.variables.pool:
|
||||
if variable.name == name:
|
||||
return variable
|
||||
|
||||
|
||||
return None
|
||||
|
||||
def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]:
|
||||
|
@ -138,9 +140,9 @@ class Tool(BaseModel, ABC):
|
|||
"""
|
||||
if not self.variables:
|
||||
return None
|
||||
|
||||
|
||||
return self.get_variable(self.VARIABLE_KEY.IMAGE)
|
||||
|
||||
|
||||
def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]:
|
||||
"""
|
||||
get a variable file
|
||||
|
@ -151,7 +153,7 @@ class Tool(BaseModel, ABC):
|
|||
variable = self.get_variable(name)
|
||||
if not variable:
|
||||
return None
|
||||
|
||||
|
||||
if not isinstance(variable, ToolRuntimeImageVariable):
|
||||
return None
|
||||
|
||||
|
@ -160,9 +162,9 @@ class Tool(BaseModel, ABC):
|
|||
file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id)
|
||||
if not file_binary:
|
||||
return None
|
||||
|
||||
|
||||
return file_binary[0]
|
||||
|
||||
|
||||
def list_variables(self) -> list[ToolRuntimeVariable]:
|
||||
"""
|
||||
list all variables
|
||||
|
@ -171,9 +173,9 @@ class Tool(BaseModel, ABC):
|
|||
"""
|
||||
if not self.variables:
|
||||
return []
|
||||
|
||||
|
||||
return self.variables.pool
|
||||
|
||||
|
||||
def list_default_image_variables(self) -> list[ToolRuntimeVariable]:
|
||||
"""
|
||||
list all image variables
|
||||
|
@ -182,9 +184,9 @@ class Tool(BaseModel, ABC):
|
|||
"""
|
||||
if not self.variables:
|
||||
return []
|
||||
|
||||
|
||||
result = []
|
||||
|
||||
|
||||
for variable in self.variables.pool:
|
||||
if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value):
|
||||
result.append(variable)
|
||||
|
@ -225,7 +227,7 @@ class Tool(BaseModel, ABC):
|
|||
@abstractmethod
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
pass
|
||||
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials
|
||||
|
@ -244,7 +246,7 @@ class Tool(BaseModel, ABC):
|
|||
:return: the runtime parameters
|
||||
"""
|
||||
return self.parameters or []
|
||||
|
||||
|
||||
def get_all_runtime_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
get all runtime parameters
|
||||
|
@ -278,7 +280,7 @@ class Tool(BaseModel, ABC):
|
|||
parameters.append(parameter)
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create an image message
|
||||
|
@ -286,18 +288,18 @@ class Tool(BaseModel, ABC):
|
|||
:param image: the url of the image
|
||||
:return: the image message
|
||||
"""
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE,
|
||||
message=image,
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE,
|
||||
message=image,
|
||||
save_as=save_as)
|
||||
|
||||
def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage:
|
||||
|
||||
def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR,
|
||||
message='',
|
||||
meta={
|
||||
'file_var': file_var
|
||||
},
|
||||
save_as='')
|
||||
|
||||
|
||||
def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create a link message
|
||||
|
@ -305,10 +307,10 @@ class Tool(BaseModel, ABC):
|
|||
:param link: the url of the link
|
||||
:return: the link message
|
||||
"""
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=link,
|
||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=link,
|
||||
save_as=save_as)
|
||||
|
||||
|
||||
def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create a text message
|
||||
|
@ -321,7 +323,7 @@ class Tool(BaseModel, ABC):
|
|||
message=text,
|
||||
save_as=save_as
|
||||
)
|
||||
|
||||
|
||||
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
|
||||
"""
|
||||
create a blob message
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import logging
|
||||
from mimetypes import guess_extension
|
||||
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.file.file_obj import FileTransferMethod, FileType
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
|
@ -27,12 +27,12 @@ class ToolFileMessageTransformer:
|
|||
# try to download image
|
||||
try:
|
||||
file = ToolFileManager.create_file_by_url(
|
||||
user_id=user_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_url=message.message
|
||||
)
|
||||
|
||||
|
||||
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
|
||||
|
||||
result.append(ToolInvokeMessage(
|
||||
|
@ -55,14 +55,14 @@ class ToolFileMessageTransformer:
|
|||
# if message is str, encode it to bytes
|
||||
if isinstance(message.message, str):
|
||||
message.message = message.message.encode('utf-8')
|
||||
|
||||
|
||||
file = ToolFileManager.create_file_by_raw(
|
||||
user_id=user_id, tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_binary=message.message,
|
||||
mimetype=mimetype
|
||||
)
|
||||
|
||||
|
||||
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
|
||||
|
||||
# check if file is image
|
||||
|
@ -81,7 +81,7 @@ class ToolFileMessageTransformer:
|
|||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
|
||||
file_var: FileVar = message.meta.get('file_var')
|
||||
file_var = message.meta.get('file_var')
|
||||
if file_var:
|
||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
|
||||
|
@ -103,7 +103,7 @@ class ToolFileMessageTransformer:
|
|||
result.append(message)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
|
||||
return f'/files/tools/{tool_file_id}{extension or ".bin"}'
|
||||
return f'/files/tools/{tool_file_id}{extension or ".bin"}'
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
import json
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import Optional, cast
|
||||
from typing import TYPE_CHECKING, Optional, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.file.file_obj import FileVar
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
@ -39,6 +38,10 @@ from models.model import Conversation
|
|||
from models.provider import Provider, ProviderType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.file_obj import FileVar
|
||||
|
||||
|
||||
|
||||
class LLMNode(BaseNode):
|
||||
_node_data_cls = LLMNodeData
|
||||
|
@ -71,7 +74,7 @@ class LLMNode(BaseNode):
|
|||
node_inputs = {}
|
||||
|
||||
# fetch files
|
||||
files: list[FileVar] = self._fetch_files(node_data, variable_pool)
|
||||
files = self._fetch_files(node_data, variable_pool)
|
||||
|
||||
if files:
|
||||
node_inputs['#files#'] = [file.to_dict() for file in files]
|
||||
|
@ -322,7 +325,7 @@ class LLMNode(BaseNode):
|
|||
|
||||
return inputs
|
||||
|
||||
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
|
||||
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]:
|
||||
"""
|
||||
Fetch files
|
||||
:param node_data: node data
|
||||
|
@ -521,7 +524,7 @@ class LLMNode(BaseNode):
|
|||
query: Optional[str],
|
||||
query_prompt_template: Optional[str],
|
||||
inputs: dict[str, str],
|
||||
files: list[FileVar],
|
||||
files: list["FileVar"],
|
||||
context: Optional[str],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) \
|
||||
|
|
Loading…
Reference in New Issue
Block a user