diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 58c7d04b83..6fb387c15a 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,6 +1,6 @@ import time from collections.abc import Generator -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -14,7 +14,6 @@ from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChu from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.external_data_tool.external_data_fetch import ExternalDataFetch -from core.file.file_obj import FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage @@ -27,13 +26,16 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from models.model import App, AppMode, Message, MessageAnnotation +if TYPE_CHECKING: + from core.file.file_obj import FileVar + class AppRunner: def get_pre_calculate_rest_tokens(self, app_record: App, model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], - files: list[FileVar], + files: list["FileVar"], query: Optional[str] = None) -> int: """ Get pre calculate rest tokens @@ -126,7 +128,7 @@ class AppRunner: model_config: ModelConfigWithCredentialsEntity, prompt_template_entity: PromptTemplateEntity, inputs: dict[str, str], - files: list[FileVar], + files: list["FileVar"], query: Optional[str] = None, context: Optional[str] = None, memory: Optional[TokenBufferMemory] = None) \ @@ -366,7 +368,7 @@ class AppRunner: message_id=message_id, trace_manager=app_generate_entity.trace_manager ) - + def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity, queue_manager: AppQueueManager, prompt_messages: list[PromptMessage]) -> bool: @@ -418,7 +420,7 @@ class AppRunner: inputs=inputs, query=query ) - + def query_app_annotations_to_reply(self, app_record: App, message: Message, query: str, diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 9a861c29e2..6a1ab23041 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -166,4 +166,4 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): node_id: str inputs: dict - single_iteration_run: Optional[SingleIterationRunEntity] = None \ No newline at end of file + single_iteration_run: Optional[SingleIterationRunEntity] = None diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index 01b89907db..085ff07cfd 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -99,7 +99,7 @@ class MessageFileParser: # return all file objs return new_files - def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig) -> list[FileVar]: + def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig): """ transform message files @@ -144,7 +144,7 @@ class MessageFileParser: return type_file_objs - def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig) -> FileVar: + def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig): """ transform file to file obj diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 452b270348..fd7ed0181b 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -1,11 +1,10 @@ import enum import json import os -from typing import Optional +from typing import TYPE_CHECKING, Optional from core.app.app_config.entities import PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.file.file_obj import FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import ( PromptMessage, @@ -18,6 +17,9 @@ from core.prompt.prompt_transform import PromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode +if TYPE_CHECKING: + from core.file.file_obj import FileVar + class ModelMode(enum.Enum): COMPLETION = 'completion' @@ -50,7 +52,7 @@ class SimplePromptTransform(PromptTransform): prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: list[FileVar], + files: list["FileVar"], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) -> \ @@ -163,7 +165,7 @@ class SimplePromptTransform(PromptTransform): inputs: dict, query: str, context: Optional[str], - files: list[FileVar], + files: list["FileVar"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: @@ -206,7 +208,7 @@ class SimplePromptTransform(PromptTransform): inputs: dict, query: str, context: Optional[str], - files: list[FileVar], + files: list["FileVar"], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) \ -> tuple[list[PromptMessage], Optional[list[str]]]: @@ -255,7 +257,7 @@ class SimplePromptTransform(PromptTransform): return [self.get_last_user_message(prompt, files)], stops - def get_last_user_message(self, prompt: str, files: list[FileVar]) -> UserPromptMessage: + def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage: if files: prompt_message_contents = [TextPromptMessageContent(data=prompt)] for file in files: diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 5d561911d1..d990131b5f 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -2,13 +2,12 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from copy import deepcopy from enum import Enum -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from pydantic import BaseModel, ConfigDict, field_validator from pydantic_core.core_schema import ValidationInfo from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.file_obj import FileVar from core.tools.entities.tool_entities import ( ToolDescription, ToolIdentity, @@ -23,6 +22,9 @@ from core.tools.entities.tool_entities import ( from core.tools.tool_file_manager import ToolFileManager from core.tools.utils.tool_parameter_converter import ToolParameterConverter +if TYPE_CHECKING: + from core.file.file_obj import FileVar + class Tool(BaseModel, ABC): identity: Optional[ToolIdentity] = None @@ -76,7 +78,7 @@ class Tool(BaseModel, ABC): description=self.description.model_copy() if self.description else None, runtime=Tool.Runtime(**runtime), ) - + @abstractmethod def tool_provider_type(self) -> ToolProviderType: """ @@ -84,7 +86,7 @@ class Tool(BaseModel, ABC): :return: the tool provider type """ - + def load_variables(self, variables: ToolRuntimeVariablePool): """ load variables from database @@ -99,7 +101,7 @@ class Tool(BaseModel, ABC): """ if not self.variables: return - + self.variables.set_file(self.identity.name, variable_name, image_key) def set_text_variable(self, variable_name: str, text: str) -> None: @@ -108,9 +110,9 @@ class Tool(BaseModel, ABC): """ if not self.variables: return - + self.variables.set_text(self.identity.name, variable_name, text) - + def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]: """ get a variable @@ -120,14 +122,14 @@ class Tool(BaseModel, ABC): """ if not self.variables: return None - + if isinstance(name, Enum): name = name.value - + for variable in self.variables.pool: if variable.name == name: return variable - + return None def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]: @@ -138,9 +140,9 @@ class Tool(BaseModel, ABC): """ if not self.variables: return None - + return self.get_variable(self.VARIABLE_KEY.IMAGE) - + def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: """ get a variable file @@ -151,7 +153,7 @@ class Tool(BaseModel, ABC): variable = self.get_variable(name) if not variable: return None - + if not isinstance(variable, ToolRuntimeImageVariable): return None @@ -160,9 +162,9 @@ class Tool(BaseModel, ABC): file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id) if not file_binary: return None - + return file_binary[0] - + def list_variables(self) -> list[ToolRuntimeVariable]: """ list all variables @@ -171,9 +173,9 @@ class Tool(BaseModel, ABC): """ if not self.variables: return [] - + return self.variables.pool - + def list_default_image_variables(self) -> list[ToolRuntimeVariable]: """ list all image variables @@ -182,9 +184,9 @@ class Tool(BaseModel, ABC): """ if not self.variables: return [] - + result = [] - + for variable in self.variables.pool: if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value): result.append(variable) @@ -225,7 +227,7 @@ class Tool(BaseModel, ABC): @abstractmethod def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: pass - + def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None: """ validate the credentials @@ -244,7 +246,7 @@ class Tool(BaseModel, ABC): :return: the runtime parameters """ return self.parameters or [] - + def get_all_runtime_parameters(self) -> list[ToolParameter]: """ get all runtime parameters @@ -278,7 +280,7 @@ class Tool(BaseModel, ABC): parameters.append(parameter) return parameters - + def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage: """ create an image message @@ -286,18 +288,18 @@ class Tool(BaseModel, ABC): :param image: the url of the image :return: the image message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, - message=image, + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, + message=image, save_as=save_as) - - def create_file_var_message(self, file_var: FileVar) -> ToolInvokeMessage: + + def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage: return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE_VAR, message='', meta={ 'file_var': file_var }, save_as='') - + def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage: """ create a link message @@ -305,10 +307,10 @@ class Tool(BaseModel, ABC): :param link: the url of the link :return: the link message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, - message=link, + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, + message=link, save_as=save_as) - + def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage: """ create a text message @@ -321,7 +323,7 @@ class Tool(BaseModel, ABC): message=text, save_as=save_as ) - + def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage: """ create a blob message diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index ef9e5b67ae..564b9d3e14 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,7 +1,7 @@ import logging from mimetypes import guess_extension -from core.file.file_obj import FileTransferMethod, FileType, FileVar +from core.file.file_obj import FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager @@ -27,12 +27,12 @@ class ToolFileMessageTransformer: # try to download image try: file = ToolFileManager.create_file_by_url( - user_id=user_id, + user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_url=message.message ) - + url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' result.append(ToolInvokeMessage( @@ -55,14 +55,14 @@ class ToolFileMessageTransformer: # if message is str, encode it to bytes if isinstance(message.message, str): message.message = message.message.encode('utf-8') - + file = ToolFileManager.create_file_by_raw( user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_binary=message.message, mimetype=mimetype ) - + url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype)) # check if file is image @@ -81,7 +81,7 @@ class ToolFileMessageTransformer: meta=message.meta.copy() if message.meta is not None else {}, )) elif message.type == ToolInvokeMessage.MessageType.FILE_VAR: - file_var: FileVar = message.meta.get('file_var') + file_var = message.meta.get('file_var') if file_var: if file_var.transfer_method == FileTransferMethod.TOOL_FILE: url = cls.get_tool_file_url(file_var.related_id, file_var.extension) @@ -103,7 +103,7 @@ class ToolFileMessageTransformer: result.append(message) return result - + @classmethod def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str: - return f'/files/tools/{tool_file_id}{extension or ".bin"}' \ No newline at end of file + return f'/files/tools/{tool_file_id}{extension or ".bin"}' diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 97b64d4b05..c20e0d4506 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -1,14 +1,13 @@ import json from collections.abc import Generator from copy import deepcopy -from typing import Optional, cast +from typing import TYPE_CHECKING, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.file.file_obj import FileVar from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.llm_entities import LLMUsage @@ -39,6 +38,10 @@ from models.model import Conversation from models.provider import Provider, ProviderType from models.workflow import WorkflowNodeExecutionStatus +if TYPE_CHECKING: + from core.file.file_obj import FileVar + + class LLMNode(BaseNode): _node_data_cls = LLMNodeData @@ -71,7 +74,7 @@ class LLMNode(BaseNode): node_inputs = {} # fetch files - files: list[FileVar] = self._fetch_files(node_data, variable_pool) + files = self._fetch_files(node_data, variable_pool) if files: node_inputs['#files#'] = [file.to_dict() for file in files] @@ -322,7 +325,7 @@ class LLMNode(BaseNode): return inputs - def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]: + def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list["FileVar"]: """ Fetch files :param node_data: node data @@ -521,7 +524,7 @@ class LLMNode(BaseNode): query: Optional[str], query_prompt_template: Optional[str], inputs: dict[str, str], - files: list[FileVar], + files: list["FileVar"], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigWithCredentialsEntity) \