mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
Compare commits
32 Commits
72c1b7d1c2
...
abacc3768f
Author | SHA1 | Date | |
---|---|---|---|
|
abacc3768f | ||
|
e31358219c | ||
|
4e360ec19a | ||
|
f68d6bd5e2 | ||
|
b860a893c8 | ||
|
0354c7813e | ||
|
94794d892e | ||
|
fb94d0b7cf | ||
|
02c39b2631 | ||
|
87f78ff582 | ||
|
6872b32c7d | ||
|
97fab7649b | ||
|
9f0f82cb1c | ||
|
ef08abafdf | ||
|
d6c9ab8554 | ||
|
bab989e3b3 | ||
|
ddc86503dc | ||
|
abad35f700 | ||
|
620b0e69f5 | ||
|
71cf4c7dbf | ||
|
47e8a5d4d1 | ||
|
93bbb194f2 | ||
|
2106fc5266 | ||
|
229b146525 | ||
|
d9fa6f79be | ||
|
4f89214d89 | ||
|
1fdaea29aa | ||
|
1397d0000d | ||
|
4b2abf8ac2 | ||
|
365cb4b368 | ||
|
c85bff235d | ||
|
ad16180b1a |
|
@ -27,7 +27,6 @@ class DifyConfig(
|
|||
# read from dotenv format config file
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
frozen=True,
|
||||
# ignore extra attributes
|
||||
extra="ignore",
|
||||
)
|
||||
|
|
|
@ -11,7 +11,7 @@ from core.provider_manager import ProviderManager
|
|||
|
||||
class ModelConfigConverter:
|
||||
@classmethod
|
||||
def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
|
||||
def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
|
||||
"""
|
||||
Convert app model config dict to entity.
|
||||
:param app_config: app config
|
||||
|
@ -38,27 +38,23 @@ class ModelConfigConverter:
|
|||
)
|
||||
|
||||
if model_credentials is None:
|
||||
if not skip_check:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
else:
|
||||
model_credentials = {}
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
|
||||
if not skip_check:
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=model_config.model, model_type=ModelType.LLM
|
||||
)
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=model_config.model, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
model_name = model_config.model
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
if provider_model is None:
|
||||
model_name = model_config.model
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
||||
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
||||
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
|
||||
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
||||
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
||||
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
|
||||
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
||||
|
||||
# model config
|
||||
completion_params = model_config.parameters
|
||||
|
@ -76,7 +72,7 @@ class ModelConfigConverter:
|
|||
|
||||
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
|
||||
|
||||
if not skip_check and not model_schema:
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
|
||||
return ModelConfigWithCredentialsEntity(
|
||||
|
|
|
@ -217,6 +217,7 @@ class WorkflowCycleManage:
|
|||
).total_seconds()
|
||||
db.session.commit()
|
||||
|
||||
db.session.add(workflow_run)
|
||||
db.session.refresh(workflow_run)
|
||||
db.session.close()
|
||||
|
||||
|
|
|
@ -74,6 +74,8 @@ def to_prompt_message_content(
|
|||
data = _to_url(f)
|
||||
else:
|
||||
data = _to_base64_data_string(f)
|
||||
if f.extension is None:
|
||||
raise ValueError("Missing file extension")
|
||||
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
|
||||
case _:
|
||||
raise ValueError("file type f.type is not supported")
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
|
@ -27,7 +28,7 @@ class TokenBufferMemory:
|
|||
|
||||
def get_history_prompt_messages(
|
||||
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
|
||||
) -> list[PromptMessage]:
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Get history prompt messages.
|
||||
:param max_token_limit: max token limit
|
||||
|
|
|
@ -100,10 +100,10 @@ class ModelInstance:
|
|||
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
|
@ -31,7 +32,7 @@ class Callback(ABC):
|
|||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> None:
|
||||
|
@ -60,7 +61,7 @@ class Callback(ABC):
|
|||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
):
|
||||
|
@ -90,7 +91,7 @@ class Callback(ABC):
|
|||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> None:
|
||||
|
@ -120,7 +121,7 @@ class Callback(ABC):
|
|||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> None:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from abc import ABC
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
|
@ -57,6 +58,7 @@ class PromptMessageContentType(Enum):
|
|||
IMAGE = "image"
|
||||
AUDIO = "audio"
|
||||
VIDEO = "video"
|
||||
DOCUMENT = "document"
|
||||
|
||||
|
||||
class PromptMessageContent(BaseModel):
|
||||
|
@ -107,7 +109,7 @@ class PromptMessage(ABC, BaseModel):
|
|||
"""
|
||||
|
||||
role: PromptMessageRole
|
||||
content: Optional[str | list[PromptMessageContent]] = None
|
||||
content: Optional[str | Sequence[PromptMessageContent]] = None
|
||||
name: Optional[str] = None
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
|
|
|
@ -87,6 +87,9 @@ class ModelFeature(Enum):
|
|||
AGENT_THOUGHT = "agent-thought"
|
||||
VISION = "vision"
|
||||
STREAM_TOOL_CALL = "stream-tool-call"
|
||||
DOCUMENT = "document"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
|
||||
|
||||
class DefaultParameterName(str, Enum):
|
||||
|
|
|
@ -2,7 +2,7 @@ import logging
|
|||
import re
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel):
|
|||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
|
@ -169,7 +169,7 @@ class LargeLanguageModel(AIModel):
|
|||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
|
@ -212,7 +212,7 @@ if you are not sure about the structure.
|
|||
)
|
||||
|
||||
model_parameters.pop("response_format")
|
||||
stop = stop or []
|
||||
stop = list(stop) if stop is not None else []
|
||||
stop.extend(["\n```", "```\n"])
|
||||
block_prompts = block_prompts.replace("{{block}}", code_block)
|
||||
|
||||
|
@ -408,7 +408,7 @@ if you are not sure about the structure.
|
|||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
|
@ -479,7 +479,7 @@ if you are not sure about the structure.
|
|||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
|
@ -601,7 +601,7 @@ if you are not sure about the structure.
|
|||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
|
@ -647,7 +647,7 @@ if you are not sure about the structure.
|
|||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
|
@ -694,7 +694,7 @@ if you are not sure about the structure.
|
|||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
|
@ -742,7 +742,7 @@ if you are not sure about the structure.
|
|||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
|
|
|
@ -8,6 +8,7 @@ features:
|
|||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
import random
|
||||
from collections import UserDict
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
|
@ -10,9 +11,9 @@ class ChatRole:
|
|||
FUNCTION = "function"
|
||||
|
||||
|
||||
class _Dict(dict):
|
||||
__setattr__ = dict.__setitem__
|
||||
__getattr__ = dict.__getitem__
|
||||
class _Dict(UserDict):
|
||||
__setattr__ = UserDict.__setitem__
|
||||
__getattr__ = UserDict.__getitem__
|
||||
|
||||
def __missing__(self, key):
|
||||
return None
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from core.model_runtime.entities import (
|
||||
|
@ -14,7 +15,7 @@ from core.prompt.simple_prompt_transform import ModelMode
|
|||
|
||||
class PromptMessageUtil:
|
||||
@staticmethod
|
||||
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]:
|
||||
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]:
|
||||
"""
|
||||
Prompt messages to prompt for saving.
|
||||
:param model_mode: model mode
|
||||
|
|
52
api/core/tools/provider/builtin/fal/tools/wizper.py
Normal file
52
api/core/tools/provider/builtin/fal/tools/wizper.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
import io
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import fal_client
|
||||
|
||||
from core.file.enums import FileAttribute, FileType
|
||||
from core.file.file_manager import download, get_attr
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class WizperTool(BuiltinTool):
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||
audio_file = tool_parameters.get("audio_file")
|
||||
task = tool_parameters.get("task", "transcribe")
|
||||
language = tool_parameters.get("language", "en")
|
||||
chunk_level = tool_parameters.get("chunk_level", "segment")
|
||||
version = tool_parameters.get("version", "3")
|
||||
|
||||
if audio_file.type != FileType.AUDIO:
|
||||
return [self.create_text_message("Not a valid audio file.")]
|
||||
|
||||
api_key = self.runtime.credentials["fal_api_key"]
|
||||
|
||||
os.environ["FAL_KEY"] = api_key
|
||||
|
||||
audio_binary = io.BytesIO(download(audio_file))
|
||||
mime_type = get_attr(file=audio_file, attr=FileAttribute.MIME_TYPE)
|
||||
file_data = audio_binary.getvalue()
|
||||
|
||||
try:
|
||||
audio_url = fal_client.upload(file_data, mime_type)
|
||||
|
||||
except Exception as e:
|
||||
return [self.create_text_message(f"Error uploading audio file: {str(e)}")]
|
||||
|
||||
arguments = {
|
||||
"audio_url": audio_url,
|
||||
"task": task,
|
||||
"language": language,
|
||||
"chunk_level": chunk_level,
|
||||
"version": version,
|
||||
}
|
||||
|
||||
result = fal_client.subscribe(
|
||||
"fal-ai/wizper",
|
||||
arguments=arguments,
|
||||
with_logs=False,
|
||||
)
|
||||
|
||||
return self.create_json_message(result)
|
489
api/core/tools/provider/builtin/fal/tools/wizper.yaml
Normal file
489
api/core/tools/provider/builtin/fal/tools/wizper.yaml
Normal file
|
@ -0,0 +1,489 @@
|
|||
identity:
|
||||
name: wizper
|
||||
author: Kalo Chin
|
||||
label:
|
||||
en_US: Wizper
|
||||
zh_Hans: Wizper
|
||||
description:
|
||||
human:
|
||||
en_US: Transcribe an audio file using the Whisper model.
|
||||
zh_Hans: 使用 Whisper 模型转录音频文件。
|
||||
llm: Transcribe an audio file using the Whisper model.
|
||||
parameters:
|
||||
- name: audio_file
|
||||
type: file
|
||||
required: true
|
||||
label:
|
||||
en_US: Audio File
|
||||
zh_Hans: 音频文件
|
||||
human_description:
|
||||
en_US: "Upload an audio file to transcribe. Supports mp3, mp4, mpeg, mpga, m4a, wav, or webm formats."
|
||||
zh_Hans: "上传要转录的音频文件。支持 mp3、mp4、mpeg、mpga、m4a、wav 或 webm 格式。"
|
||||
llm_description: "Audio file to transcribe. Supported formats: mp3, mp4, mpeg, mpga, m4a, wav, or webm."
|
||||
form: llm
|
||||
- name: task
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: Task
|
||||
zh_Hans: 任务
|
||||
human_description:
|
||||
en_US: "Choose whether to transcribe the audio in its original language or translate it to English"
|
||||
zh_Hans: "选择是以原始语言转录音频还是将其翻译成英语"
|
||||
llm_description: "Task to perform on the audio file. Either transcribe or translate. Default value: 'transcribe'. If 'translate' is selected as the task, the audio will be translated to English, regardless of the language selected."
|
||||
form: form
|
||||
default: transcribe
|
||||
options:
|
||||
- value: transcribe
|
||||
label:
|
||||
en_US: Transcribe
|
||||
zh_Hans: 转录
|
||||
- value: translate
|
||||
label:
|
||||
en_US: Translate
|
||||
zh_Hans: 翻译
|
||||
- name: language
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: Language
|
||||
zh_Hans: 语言
|
||||
human_description:
|
||||
en_US: "Select the primary language spoken in the audio file"
|
||||
zh_Hans: "选择音频文件中使用的主要语言"
|
||||
llm_description: "Language of the audio file."
|
||||
form: form
|
||||
default: en
|
||||
options:
|
||||
- value: af
|
||||
label:
|
||||
en_US: Afrikaans
|
||||
zh_Hans: 南非语
|
||||
- value: am
|
||||
label:
|
||||
en_US: Amharic
|
||||
zh_Hans: 阿姆哈拉语
|
||||
- value: ar
|
||||
label:
|
||||
en_US: Arabic
|
||||
zh_Hans: 阿拉伯语
|
||||
- value: as
|
||||
label:
|
||||
en_US: Assamese
|
||||
zh_Hans: 阿萨姆语
|
||||
- value: az
|
||||
label:
|
||||
en_US: Azerbaijani
|
||||
zh_Hans: 阿塞拜疆语
|
||||
- value: ba
|
||||
label:
|
||||
en_US: Bashkir
|
||||
zh_Hans: 巴什基尔语
|
||||
- value: be
|
||||
label:
|
||||
en_US: Belarusian
|
||||
zh_Hans: 白俄罗斯语
|
||||
- value: bg
|
||||
label:
|
||||
en_US: Bulgarian
|
||||
zh_Hans: 保加利亚语
|
||||
- value: bn
|
||||
label:
|
||||
en_US: Bengali
|
||||
zh_Hans: 孟加拉语
|
||||
- value: bo
|
||||
label:
|
||||
en_US: Tibetan
|
||||
zh_Hans: 藏语
|
||||
- value: br
|
||||
label:
|
||||
en_US: Breton
|
||||
zh_Hans: 布列塔尼语
|
||||
- value: bs
|
||||
label:
|
||||
en_US: Bosnian
|
||||
zh_Hans: 波斯尼亚语
|
||||
- value: ca
|
||||
label:
|
||||
en_US: Catalan
|
||||
zh_Hans: 加泰罗尼亚语
|
||||
- value: cs
|
||||
label:
|
||||
en_US: Czech
|
||||
zh_Hans: 捷克语
|
||||
- value: cy
|
||||
label:
|
||||
en_US: Welsh
|
||||
zh_Hans: 威尔士语
|
||||
- value: da
|
||||
label:
|
||||
en_US: Danish
|
||||
zh_Hans: 丹麦语
|
||||
- value: de
|
||||
label:
|
||||
en_US: German
|
||||
zh_Hans: 德语
|
||||
- value: el
|
||||
label:
|
||||
en_US: Greek
|
||||
zh_Hans: 希腊语
|
||||
- value: en
|
||||
label:
|
||||
en_US: English
|
||||
zh_Hans: 英语
|
||||
- value: es
|
||||
label:
|
||||
en_US: Spanish
|
||||
zh_Hans: 西班牙语
|
||||
- value: et
|
||||
label:
|
||||
en_US: Estonian
|
||||
zh_Hans: 爱沙尼亚语
|
||||
- value: eu
|
||||
label:
|
||||
en_US: Basque
|
||||
zh_Hans: 巴斯克语
|
||||
- value: fa
|
||||
label:
|
||||
en_US: Persian
|
||||
zh_Hans: 波斯语
|
||||
- value: fi
|
||||
label:
|
||||
en_US: Finnish
|
||||
zh_Hans: 芬兰语
|
||||
- value: fo
|
||||
label:
|
||||
en_US: Faroese
|
||||
zh_Hans: 法罗语
|
||||
- value: fr
|
||||
label:
|
||||
en_US: French
|
||||
zh_Hans: 法语
|
||||
- value: gl
|
||||
label:
|
||||
en_US: Galician
|
||||
zh_Hans: 加利西亚语
|
||||
- value: gu
|
||||
label:
|
||||
en_US: Gujarati
|
||||
zh_Hans: 古吉拉特语
|
||||
- value: ha
|
||||
label:
|
||||
en_US: Hausa
|
||||
zh_Hans: 毫萨语
|
||||
- value: haw
|
||||
label:
|
||||
en_US: Hawaiian
|
||||
zh_Hans: 夏威夷语
|
||||
- value: he
|
||||
label:
|
||||
en_US: Hebrew
|
||||
zh_Hans: 希伯来语
|
||||
- value: hi
|
||||
label:
|
||||
en_US: Hindi
|
||||
zh_Hans: 印地语
|
||||
- value: hr
|
||||
label:
|
||||
en_US: Croatian
|
||||
zh_Hans: 克罗地亚语
|
||||
- value: ht
|
||||
label:
|
||||
en_US: Haitian Creole
|
||||
zh_Hans: 海地克里奥尔语
|
||||
- value: hu
|
||||
label:
|
||||
en_US: Hungarian
|
||||
zh_Hans: 匈牙利语
|
||||
- value: hy
|
||||
label:
|
||||
en_US: Armenian
|
||||
zh_Hans: 亚美尼亚语
|
||||
- value: id
|
||||
label:
|
||||
en_US: Indonesian
|
||||
zh_Hans: 印度尼西亚语
|
||||
- value: is
|
||||
label:
|
||||
en_US: Icelandic
|
||||
zh_Hans: 冰岛语
|
||||
- value: it
|
||||
label:
|
||||
en_US: Italian
|
||||
zh_Hans: 意大利语
|
||||
- value: ja
|
||||
label:
|
||||
en_US: Japanese
|
||||
zh_Hans: 日语
|
||||
- value: jw
|
||||
label:
|
||||
en_US: Javanese
|
||||
zh_Hans: 爪哇语
|
||||
- value: ka
|
||||
label:
|
||||
en_US: Georgian
|
||||
zh_Hans: 格鲁吉亚语
|
||||
- value: kk
|
||||
label:
|
||||
en_US: Kazakh
|
||||
zh_Hans: 哈萨克语
|
||||
- value: km
|
||||
label:
|
||||
en_US: Khmer
|
||||
zh_Hans: 高棉语
|
||||
- value: kn
|
||||
label:
|
||||
en_US: Kannada
|
||||
zh_Hans: 卡纳达语
|
||||
- value: ko
|
||||
label:
|
||||
en_US: Korean
|
||||
zh_Hans: 韩语
|
||||
- value: la
|
||||
label:
|
||||
en_US: Latin
|
||||
zh_Hans: 拉丁语
|
||||
- value: lb
|
||||
label:
|
||||
en_US: Luxembourgish
|
||||
zh_Hans: 卢森堡语
|
||||
- value: ln
|
||||
label:
|
||||
en_US: Lingala
|
||||
zh_Hans: 林加拉语
|
||||
- value: lo
|
||||
label:
|
||||
en_US: Lao
|
||||
zh_Hans: 老挝语
|
||||
- value: lt
|
||||
label:
|
||||
en_US: Lithuanian
|
||||
zh_Hans: 立陶宛语
|
||||
- value: lv
|
||||
label:
|
||||
en_US: Latvian
|
||||
zh_Hans: 拉脱维亚语
|
||||
- value: mg
|
||||
label:
|
||||
en_US: Malagasy
|
||||
zh_Hans: 马尔加什语
|
||||
- value: mi
|
||||
label:
|
||||
en_US: Maori
|
||||
zh_Hans: 毛利语
|
||||
- value: mk
|
||||
label:
|
||||
en_US: Macedonian
|
||||
zh_Hans: 马其顿语
|
||||
- value: ml
|
||||
label:
|
||||
en_US: Malayalam
|
||||
zh_Hans: 马拉雅拉姆语
|
||||
- value: mn
|
||||
label:
|
||||
en_US: Mongolian
|
||||
zh_Hans: 蒙古语
|
||||
- value: mr
|
||||
label:
|
||||
en_US: Marathi
|
||||
zh_Hans: 马拉地语
|
||||
- value: ms
|
||||
label:
|
||||
en_US: Malay
|
||||
zh_Hans: 马来语
|
||||
- value: mt
|
||||
label:
|
||||
en_US: Maltese
|
||||
zh_Hans: 马耳他语
|
||||
- value: my
|
||||
label:
|
||||
en_US: Burmese
|
||||
zh_Hans: 缅甸语
|
||||
- value: ne
|
||||
label:
|
||||
en_US: Nepali
|
||||
zh_Hans: 尼泊尔语
|
||||
- value: nl
|
||||
label:
|
||||
en_US: Dutch
|
||||
zh_Hans: 荷兰语
|
||||
- value: nn
|
||||
label:
|
||||
en_US: Norwegian Nynorsk
|
||||
zh_Hans: 新挪威语
|
||||
- value: no
|
||||
label:
|
||||
en_US: Norwegian
|
||||
zh_Hans: 挪威语
|
||||
- value: oc
|
||||
label:
|
||||
en_US: Occitan
|
||||
zh_Hans: 奥克语
|
||||
- value: pa
|
||||
label:
|
||||
en_US: Punjabi
|
||||
zh_Hans: 旁遮普语
|
||||
- value: pl
|
||||
label:
|
||||
en_US: Polish
|
||||
zh_Hans: 波兰语
|
||||
- value: ps
|
||||
label:
|
||||
en_US: Pashto
|
||||
zh_Hans: 普什图语
|
||||
- value: pt
|
||||
label:
|
||||
en_US: Portuguese
|
||||
zh_Hans: 葡萄牙语
|
||||
- value: ro
|
||||
label:
|
||||
en_US: Romanian
|
||||
zh_Hans: 罗马尼亚语
|
||||
- value: ru
|
||||
label:
|
||||
en_US: Russian
|
||||
zh_Hans: 俄语
|
||||
- value: sa
|
||||
label:
|
||||
en_US: Sanskrit
|
||||
zh_Hans: 梵语
|
||||
- value: sd
|
||||
label:
|
||||
en_US: Sindhi
|
||||
zh_Hans: 信德语
|
||||
- value: si
|
||||
label:
|
||||
en_US: Sinhala
|
||||
zh_Hans: 僧伽罗语
|
||||
- value: sk
|
||||
label:
|
||||
en_US: Slovak
|
||||
zh_Hans: 斯洛伐克语
|
||||
- value: sl
|
||||
label:
|
||||
en_US: Slovenian
|
||||
zh_Hans: 斯洛文尼亚语
|
||||
- value: sn
|
||||
label:
|
||||
en_US: Shona
|
||||
zh_Hans: 修纳语
|
||||
- value: so
|
||||
label:
|
||||
en_US: Somali
|
||||
zh_Hans: 索马里语
|
||||
- value: sq
|
||||
label:
|
||||
en_US: Albanian
|
||||
zh_Hans: 阿尔巴尼亚语
|
||||
- value: sr
|
||||
label:
|
||||
en_US: Serbian
|
||||
zh_Hans: 塞尔维亚语
|
||||
- value: su
|
||||
label:
|
||||
en_US: Sundanese
|
||||
zh_Hans: 巽他语
|
||||
- value: sv
|
||||
label:
|
||||
en_US: Swedish
|
||||
zh_Hans: 瑞典语
|
||||
- value: sw
|
||||
label:
|
||||
en_US: Swahili
|
||||
zh_Hans: 斯瓦希里语
|
||||
- value: ta
|
||||
label:
|
||||
en_US: Tamil
|
||||
zh_Hans: 泰米尔语
|
||||
- value: te
|
||||
label:
|
||||
en_US: Telugu
|
||||
zh_Hans: 泰卢固语
|
||||
- value: tg
|
||||
label:
|
||||
en_US: Tajik
|
||||
zh_Hans: 塔吉克语
|
||||
- value: th
|
||||
label:
|
||||
en_US: Thai
|
||||
zh_Hans: 泰语
|
||||
- value: tk
|
||||
label:
|
||||
en_US: Turkmen
|
||||
zh_Hans: 土库曼语
|
||||
- value: tl
|
||||
label:
|
||||
en_US: Tagalog
|
||||
zh_Hans: 他加禄语
|
||||
- value: tr
|
||||
label:
|
||||
en_US: Turkish
|
||||
zh_Hans: 土耳其语
|
||||
- value: tt
|
||||
label:
|
||||
en_US: Tatar
|
||||
zh_Hans: 鞑靼语
|
||||
- value: uk
|
||||
label:
|
||||
en_US: Ukrainian
|
||||
zh_Hans: 乌克兰语
|
||||
- value: ur
|
||||
label:
|
||||
en_US: Urdu
|
||||
zh_Hans: 乌尔都语
|
||||
- value: uz
|
||||
label:
|
||||
en_US: Uzbek
|
||||
zh_Hans: 乌兹别克语
|
||||
- value: vi
|
||||
label:
|
||||
en_US: Vietnamese
|
||||
zh_Hans: 越南语
|
||||
- value: yi
|
||||
label:
|
||||
en_US: Yiddish
|
||||
zh_Hans: 意第绪语
|
||||
- value: yo
|
||||
label:
|
||||
en_US: Yoruba
|
||||
zh_Hans: 约鲁巴语
|
||||
- value: yue
|
||||
label:
|
||||
en_US: Cantonese
|
||||
zh_Hans: 粤语
|
||||
- value: zh
|
||||
label:
|
||||
en_US: Chinese
|
||||
zh_Hans: 中文
|
||||
- name: chunk_level
|
||||
type: select
|
||||
label:
|
||||
en_US: Chunk Level
|
||||
zh_Hans: 分块级别
|
||||
human_description:
|
||||
en_US: "Choose how the transcription should be divided into chunks"
|
||||
zh_Hans: "选择如何将转录内容分成块"
|
||||
llm_description: "Level of the chunks to return."
|
||||
form: form
|
||||
default: segment
|
||||
options:
|
||||
- value: segment
|
||||
label:
|
||||
en_US: Segment
|
||||
zh_Hans: 段
|
||||
- name: version
|
||||
type: select
|
||||
label:
|
||||
en_US: Version
|
||||
zh_Hans: 版本
|
||||
human_description:
|
||||
en_US: "Select which version of the Whisper large model to use"
|
||||
zh_Hans: "选择要使用的 Whisper large 模型版本"
|
||||
llm_description: "Version of the model to use. All of the models are the Whisper large variant."
|
||||
form: form
|
||||
default: "3"
|
||||
options:
|
||||
- value: "3"
|
||||
label:
|
||||
en_US: Version 3
|
||||
zh_Hans: 版本 3
|
|
@ -118,11 +118,11 @@ class FileSegment(Segment):
|
|||
|
||||
@property
|
||||
def log(self) -> str:
|
||||
return str(self.value)
|
||||
return ""
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
return str(self.value)
|
||||
return ""
|
||||
|
||||
|
||||
class ArrayAnySegment(ArraySegment):
|
||||
|
@ -155,3 +155,11 @@ class ArrayFileSegment(ArraySegment):
|
|||
for item in self.value:
|
||||
items.append(item.markdown)
|
||||
return "\n".join(items)
|
||||
|
||||
@property
|
||||
def log(self) -> str:
|
||||
return ""
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
return ""
|
||||
|
|
|
@ -39,7 +39,14 @@ class VisionConfig(BaseModel):
|
|||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
jinja2_variables: Optional[list[VariableSelector]] = None
|
||||
jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list)
|
||||
|
||||
@field_validator("jinja2_variables", mode="before")
|
||||
@classmethod
|
||||
def convert_none_jinja2_variables(cls, v: Any):
|
||||
if v is None:
|
||||
return []
|
||||
return v
|
||||
|
||||
|
||||
class LLMNodeChatModelMessage(ChatModelMessage):
|
||||
|
@ -53,7 +60,14 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
|||
class LLMNodeData(BaseNodeData):
|
||||
model: ModelConfig
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
prompt_config: Optional[PromptConfig] = None
|
||||
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
|
||||
memory: Optional[MemoryConfig] = None
|
||||
context: ContextConfig
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
|
||||
@field_validator("prompt_config", mode="before")
|
||||
@classmethod
|
||||
def convert_none_prompt_config(cls, v: Any):
|
||||
if v is None:
|
||||
return PromptConfig()
|
||||
return v
|
||||
|
|
|
@ -24,3 +24,11 @@ class LLMModeRequiredError(LLMNodeError):
|
|||
|
||||
class NoPromptFoundError(LLMNodeError):
|
||||
"""Raised when no prompt is found in the LLM configuration."""
|
||||
|
||||
|
||||
class NotSupportedPromptTypeError(LLMNodeError):
|
||||
"""Raised when the prompt type is not supported."""
|
||||
|
||||
|
||||
class MemoryRolePrefixRequiredError(LLMNodeError):
|
||||
"""Raised when memory role prefix is required for completion model."""
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
|
@ -6,21 +7,26 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
|
|||
from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.file import FileType, file_manager
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities import (
|
||||
AudioPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
TextPromptMessageContent,
|
||||
VideoPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.variables import (
|
||||
|
@ -32,8 +38,9 @@ from core.variables import (
|
|||
ObjectSegment,
|
||||
StringSegment,
|
||||
)
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
|
@ -62,14 +69,18 @@ from .exc import (
|
|||
InvalidVariableTypeError,
|
||||
LLMModeRequiredError,
|
||||
LLMNodeError,
|
||||
MemoryRolePrefixRequiredError,
|
||||
ModelNotExistError,
|
||||
NoPromptFoundError,
|
||||
NotSupportedPromptTypeError,
|
||||
VariableNotFoundError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMNode(BaseNode[LLMNodeData]):
|
||||
_node_data_cls = LLMNodeData
|
||||
|
@ -123,17 +134,13 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
|
||||
# fetch prompt messages
|
||||
if self.node_data.memory:
|
||||
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
|
||||
if not query:
|
||||
raise VariableNotFoundError("Query not found")
|
||||
query = query.text
|
||||
query = self.node_data.memory.query_prompt_template
|
||||
else:
|
||||
query = None
|
||||
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
system_query=query,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
user_query=query,
|
||||
user_files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
|
@ -141,6 +148,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
memory_config=self.node_data.memory,
|
||||
vision_enabled=self.node_data.vision.enabled,
|
||||
vision_detail=self.node_data.vision.configs.detail,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||
)
|
||||
|
||||
process_data = {
|
||||
|
@ -181,6 +190,17 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
)
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Node {self.node_id} failed to run: {e}")
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
inputs=node_inputs,
|
||||
process_data=process_data,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
||||
|
||||
|
@ -203,8 +223,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
self,
|
||||
node_data_model: ModelConfig,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
) -> Generator[NodeEvent, None, None]:
|
||||
db.session.close()
|
||||
|
||||
|
@ -519,9 +539,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
def _fetch_prompt_messages(
|
||||
self,
|
||||
*,
|
||||
system_query: str | None = None,
|
||||
inputs: dict[str, str] | None = None,
|
||||
files: Sequence["File"],
|
||||
user_query: str | None = None,
|
||||
user_files: Sequence["File"],
|
||||
context: str | None = None,
|
||||
memory: TokenBufferMemory | None = None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
|
@ -529,58 +548,146 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
memory_config: MemoryConfig | None = None,
|
||||
vision_enabled: bool = False,
|
||||
vision_detail: ImagePromptMessageContent.DETAIL,
|
||||
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||
inputs = inputs or {}
|
||||
variable_pool: VariablePool,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
|
||||
prompt_messages = []
|
||||
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs=inputs,
|
||||
query=system_query or "",
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
)
|
||||
stop = model_config.stop
|
||||
if isinstance(prompt_template, list):
|
||||
# For chat model
|
||||
prompt_messages.extend(
|
||||
_handle_list_messages(
|
||||
messages=prompt_template,
|
||||
context=context,
|
||||
jinja2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
vision_detail_config=vision_detail,
|
||||
)
|
||||
)
|
||||
|
||||
# Get memory messages for chat mode
|
||||
memory_messages = _handle_memory_chat_mode(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
model_config=model_config,
|
||||
)
|
||||
# Extend prompt_messages with memory messages
|
||||
prompt_messages.extend(memory_messages)
|
||||
|
||||
# Add current query to the prompt messages
|
||||
if user_query:
|
||||
message = LLMNodeChatModelMessage(
|
||||
text=user_query,
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
prompt_messages.extend(
|
||||
_handle_list_messages(
|
||||
messages=[message],
|
||||
context="",
|
||||
jinja2_variables=[],
|
||||
variable_pool=variable_pool,
|
||||
vision_detail_config=vision_detail,
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
||||
# For completion model
|
||||
prompt_messages.extend(
|
||||
_handle_completion_template(
|
||||
template=prompt_template,
|
||||
context=context,
|
||||
jinja2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
)
|
||||
|
||||
# Get memory text for completion model
|
||||
memory_text = _handle_memory_completion_mode(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
model_config=model_config,
|
||||
)
|
||||
# Insert histories into the prompt
|
||||
prompt_content = prompt_messages[0].content
|
||||
if "#histories#" in prompt_content:
|
||||
prompt_content = prompt_content.replace("#histories#", memory_text)
|
||||
else:
|
||||
prompt_content = memory_text + "\n" + prompt_content
|
||||
prompt_messages[0].content = prompt_content
|
||||
|
||||
# Add current query to the prompt message
|
||||
if user_query:
|
||||
prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
|
||||
prompt_messages[0].content = prompt_content
|
||||
else:
|
||||
errmsg = f"Prompt type {type(prompt_template)} is not supported"
|
||||
logger.warning(errmsg)
|
||||
raise NotSupportedPromptTypeError(errmsg)
|
||||
|
||||
if vision_enabled and user_files:
|
||||
file_prompts = []
|
||||
for file in user_files:
|
||||
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
|
||||
file_prompts.append(file_prompt)
|
||||
if (
|
||||
len(prompt_messages) > 0
|
||||
and isinstance(prompt_messages[-1], UserPromptMessage)
|
||||
and isinstance(prompt_messages[-1].content, list)
|
||||
):
|
||||
prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||
|
||||
# Filter prompt messages
|
||||
filtered_prompt_messages = []
|
||||
for prompt_message in prompt_messages:
|
||||
if prompt_message.is_empty():
|
||||
continue
|
||||
|
||||
if not isinstance(prompt_message.content, str):
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message_content = []
|
||||
for content_item in prompt_message.content or []:
|
||||
# Skip image if vision is disabled
|
||||
if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE:
|
||||
for content_item in prompt_message.content:
|
||||
# Skip content if features are not defined
|
||||
if not model_config.model_schema.features:
|
||||
if content_item.type != PromptMessageContentType.TEXT:
|
||||
continue
|
||||
prompt_message_content.append(content_item)
|
||||
continue
|
||||
|
||||
if isinstance(content_item, ImagePromptMessageContent):
|
||||
# Override vision config if LLM node has vision config,
|
||||
# cuz vision detail is related to the configuration from FileUpload feature.
|
||||
content_item.detail = vision_detail
|
||||
prompt_message_content.append(content_item)
|
||||
elif isinstance(
|
||||
content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
|
||||
# Skip content if corresponding feature is not supported
|
||||
if (
|
||||
(
|
||||
content_item.type == PromptMessageContentType.IMAGE
|
||||
and ModelFeature.VISION not in model_config.model_schema.features
|
||||
)
|
||||
or (
|
||||
content_item.type == PromptMessageContentType.DOCUMENT
|
||||
and ModelFeature.DOCUMENT not in model_config.model_schema.features
|
||||
)
|
||||
or (
|
||||
content_item.type == PromptMessageContentType.VIDEO
|
||||
and ModelFeature.VIDEO not in model_config.model_schema.features
|
||||
)
|
||||
or (
|
||||
content_item.type == PromptMessageContentType.AUDIO
|
||||
and ModelFeature.AUDIO not in model_config.model_schema.features
|
||||
)
|
||||
):
|
||||
prompt_message_content.append(content_item)
|
||||
|
||||
if len(prompt_message_content) > 1:
|
||||
prompt_message.content = prompt_message_content
|
||||
elif (
|
||||
len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT
|
||||
):
|
||||
continue
|
||||
prompt_message_content.append(content_item)
|
||||
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
|
||||
prompt_message.content = prompt_message_content[0].data
|
||||
|
||||
else:
|
||||
prompt_message.content = prompt_message_content
|
||||
if prompt_message.is_empty():
|
||||
continue
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
|
||||
if not filtered_prompt_messages:
|
||||
if len(filtered_prompt_messages) == 0:
|
||||
raise NoPromptFoundError(
|
||||
"No prompt found in the LLM configuration. "
|
||||
"Please ensure a prompt is properly configured before proceeding."
|
||||
)
|
||||
|
||||
stop = model_config.stop
|
||||
return filtered_prompt_messages, stop
|
||||
|
||||
@classmethod
|
||||
|
@ -715,3 +822,198 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
|
||||
match role:
|
||||
case PromptMessageRole.USER:
|
||||
return UserPromptMessage(content=[TextPromptMessageContent(data=text)])
|
||||
case PromptMessageRole.ASSISTANT:
|
||||
return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)])
|
||||
case PromptMessageRole.SYSTEM:
|
||||
return SystemPromptMessage(content=[TextPromptMessageContent(data=text)])
|
||||
raise NotImplementedError(f"Role {role} is not supported")
|
||||
|
||||
|
||||
def _render_jinja2_message(
|
||||
*,
|
||||
template: str,
|
||||
jinjia2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
):
|
||||
if not template:
|
||||
return ""
|
||||
|
||||
jinjia2_inputs = {}
|
||||
for jinja2_variable in jinjia2_variables:
|
||||
variable = variable_pool.get(jinja2_variable.value_selector)
|
||||
jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
|
||||
code_execute_resp = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2,
|
||||
code=template,
|
||||
inputs=jinjia2_inputs,
|
||||
)
|
||||
result_text = code_execute_resp["result"]
|
||||
return result_text
|
||||
|
||||
|
||||
def _handle_list_messages(
|
||||
*,
|
||||
messages: Sequence[LLMNodeChatModelMessage],
|
||||
context: Optional[str],
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
vision_detail_config: ImagePromptMessageContent.DETAIL,
|
||||
) -> Sequence[PromptMessage]:
|
||||
prompt_messages = []
|
||||
for message in messages:
|
||||
if message.edition_type == "jinja2":
|
||||
result_text = _render_jinja2_message(
|
||||
template=message.jinja2_text or "",
|
||||
jinjia2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
|
||||
prompt_messages.append(prompt_message)
|
||||
else:
|
||||
# Get segment group from basic message
|
||||
if context:
|
||||
template = message.text.replace("{#context#}", context)
|
||||
else:
|
||||
template = message.text
|
||||
segment_group = variable_pool.convert_template(template)
|
||||
|
||||
# Process segments for images
|
||||
file_contents = []
|
||||
for segment in segment_group.value:
|
||||
if isinstance(segment, ArrayFileSegment):
|
||||
for file in segment.value:
|
||||
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
|
||||
file_content = file_manager.to_prompt_message_content(
|
||||
file, image_detail_config=vision_detail_config
|
||||
)
|
||||
file_contents.append(file_content)
|
||||
if isinstance(segment, FileSegment):
|
||||
file = segment.value
|
||||
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
|
||||
file_content = file_manager.to_prompt_message_content(
|
||||
file, image_detail_config=vision_detail_config
|
||||
)
|
||||
file_contents.append(file_content)
|
||||
|
||||
# Create message with text from all segments
|
||||
plain_text = segment_group.text
|
||||
if plain_text:
|
||||
prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
|
||||
prompt_messages.append(prompt_message)
|
||||
|
||||
if file_contents:
|
||||
# Create message with image contents
|
||||
prompt_message = UserPromptMessage(content=file_contents)
|
||||
prompt_messages.append(prompt_message)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
|
||||
def _calculate_rest_token(
|
||||
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
|
||||
) -> int:
|
||||
rest_tokens = 2000
|
||||
|
||||
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
if model_context_tokens:
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
|
||||
)
|
||||
|
||||
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
if parameter_rule.name == "max_tokens" or (
|
||||
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
||||
):
|
||||
max_tokens = (
|
||||
model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(str(parameter_rule.use_template))
|
||||
or 0
|
||||
)
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
|
||||
return rest_tokens
|
||||
|
||||
|
||||
def _handle_memory_chat_mode(
|
||||
*,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory_config: MemoryConfig | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
) -> Sequence[PromptMessage]:
|
||||
memory_messages = []
|
||||
# Get messages from memory for chat model
|
||||
if memory and memory_config:
|
||||
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
|
||||
memory_messages = memory.get_history_prompt_messages(
|
||||
max_token_limit=rest_tokens,
|
||||
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
||||
)
|
||||
return memory_messages
|
||||
|
||||
|
||||
def _handle_memory_completion_mode(
|
||||
*,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory_config: MemoryConfig | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
) -> str:
|
||||
memory_text = ""
|
||||
# Get history text from memory for completion model
|
||||
if memory and memory_config:
|
||||
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
|
||||
if not memory_config.role_prefix:
|
||||
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
|
||||
memory_text = memory.get_history_prompt_text(
|
||||
max_token_limit=rest_tokens,
|
||||
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
||||
human_prefix=memory_config.role_prefix.user,
|
||||
ai_prefix=memory_config.role_prefix.assistant,
|
||||
)
|
||||
return memory_text
|
||||
|
||||
|
||||
def _handle_completion_template(
|
||||
*,
|
||||
template: LLMNodeCompletionModelPromptTemplate,
|
||||
context: Optional[str],
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
variable_pool: VariablePool,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""Handle completion template processing outside of LLMNode class.
|
||||
|
||||
Args:
|
||||
template: The completion model prompt template
|
||||
context: Optional context string
|
||||
jinja2_variables: Variables for jinja2 template rendering
|
||||
variable_pool: Variable pool for template conversion
|
||||
|
||||
Returns:
|
||||
Sequence of prompt messages
|
||||
"""
|
||||
prompt_messages = []
|
||||
if template.edition_type == "jinja2":
|
||||
result_text = _render_jinja2_message(
|
||||
template=template.jinja2_text or "",
|
||||
jinjia2_variables=jinja2_variables,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
else:
|
||||
if context:
|
||||
template_text = template.text.replace("{#context#}", context)
|
||||
else:
|
||||
template_text = template.text
|
||||
result_text = variable_pool.convert_template(template_text).text
|
||||
prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
|
||||
prompt_messages.append(prompt_message)
|
||||
return prompt_messages
|
||||
|
|
|
@ -86,12 +86,14 @@ class QuestionClassifierNode(LLMNode):
|
|||
)
|
||||
prompt_messages, stop = self._fetch_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
system_query=query,
|
||||
user_query=query,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
files=files,
|
||||
user_files=files,
|
||||
vision_enabled=node_data.vision.enabled,
|
||||
vision_detail=node_data.vision.configs.detail,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=[],
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from os import path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
@ -180,7 +179,6 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
for response in tool_response:
|
||||
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||
url = str(response.message) if response.message else None
|
||||
ext = path.splitext(url)[1] if url else ".bin"
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
|
||||
|
@ -202,7 +200,6 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
)
|
||||
result.append(file)
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
tool_file_id = str(response.message).split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
|
@ -211,7 +208,6 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
raise ValueError(f"tool file {tool_file_id} not exists")
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": FileType.IMAGE,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
|
@ -228,13 +224,8 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
||||
if "." in url:
|
||||
extension = "." + url.split("/")[-1].split(".")[1]
|
||||
else:
|
||||
extension = ".bin"
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": FileType.IMAGE,
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
|
|
|
@ -180,6 +180,20 @@ def _get_remote_file_info(url: str):
|
|||
return mime_type, filename, file_size
|
||||
|
||||
|
||||
def _get_file_type_by_mimetype(mime_type: str) -> FileType:
|
||||
if "image" in mime_type:
|
||||
file_type = FileType.IMAGE
|
||||
elif "video" in mime_type:
|
||||
file_type = FileType.VIDEO
|
||||
elif "audio" in mime_type:
|
||||
file_type = FileType.AUDIO
|
||||
elif "text" in mime_type or "pdf" in mime_type:
|
||||
file_type = FileType.DOCUMENT
|
||||
else:
|
||||
file_type = FileType.CUSTOM
|
||||
return file_type
|
||||
|
||||
|
||||
def _build_from_tool_file(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
|
@ -199,12 +213,13 @@ def _build_from_tool_file(
|
|||
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
|
||||
|
||||
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
||||
file_type = mapping.get("type", _get_file_type_by_mimetype(tool_file.mimetype))
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
tenant_id=tenant_id,
|
||||
filename=tool_file.name,
|
||||
type=FileType.value_of(mapping.get("type")),
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=tool_file.original_url,
|
||||
related_id=tool_file.id,
|
||||
|
|
86
api/poetry.lock
generated
86
api/poetry.lock
generated
|
@ -2411,6 +2411,41 @@ files = [
|
|||
[package.extras]
|
||||
test = ["pytest (>=6)"]
|
||||
|
||||
[[package]]
|
||||
name = "faker"
|
||||
version = "32.1.0"
|
||||
description = "Faker is a Python package that generates fake data for you."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "Faker-32.1.0-py3-none-any.whl", hash = "sha256:c77522577863c264bdc9dad3a2a750ad3f7ee43ff8185072e482992288898814"},
|
||||
{file = "faker-32.1.0.tar.gz", hash = "sha256:aac536ba04e6b7beb2332c67df78485fc29c1880ff723beac6d1efd45e2f10f5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
python-dateutil = ">=2.4"
|
||||
typing-extensions = "*"
|
||||
|
||||
[[package]]
|
||||
name = "fal-client"
|
||||
version = "0.5.6"
|
||||
description = "Python client for fal.ai"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "fal_client-0.5.6-py3-none-any.whl", hash = "sha256:631fd857a3c44753ee46a2eea1e7276471453aca58faac9c3702f744c7c84050"},
|
||||
{file = "fal_client-0.5.6.tar.gz", hash = "sha256:d3afc4b6250023d0ee8437ec504558231d3b106d7aabc12cda8c39883faddecb"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = ">=0.21.0,<1"
|
||||
httpx-sse = ">=0.4.0,<0.5"
|
||||
|
||||
[package.extras]
|
||||
dev = ["fal-client[docs,test]"]
|
||||
docs = ["sphinx", "sphinx-autodoc-typehints", "sphinx-rtd-theme"]
|
||||
test = ["pillow", "pytest", "pytest-asyncio"]
|
||||
|
||||
[[package]]
|
||||
name = "fastapi"
|
||||
version = "0.115.4"
|
||||
|
@ -4049,6 +4084,17 @@ http2 = ["h2 (>=3,<5)"]
|
|||
socks = ["socksio (==1.*)"]
|
||||
zstd = ["zstandard (>=0.18.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "httpx-sse"
|
||||
version = "0.4.0"
|
||||
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"},
|
||||
{file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "0.16.4"
|
||||
|
@ -8466,29 +8512,29 @@ pyasn1 = ">=0.1.3"
|
|||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.6.9"
|
||||
version = "0.7.3"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "ruff-0.6.9-py3-none-linux_armv6l.whl", hash = "sha256:064df58d84ccc0ac0fcd63bc3090b251d90e2a372558c0f057c3f75ed73e1ccd"},
|
||||
{file = "ruff-0.6.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:140d4b5c9f5fc7a7b074908a78ab8d384dd7f6510402267bc76c37195c02a7ec"},
|
||||
{file = "ruff-0.6.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:53fd8ca5e82bdee8da7f506d7b03a261f24cd43d090ea9db9a1dc59d9313914c"},
|
||||
{file = "ruff-0.6.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:645d7d8761f915e48a00d4ecc3686969761df69fb561dd914a773c1a8266e14e"},
|
||||
{file = "ruff-0.6.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eae02b700763e3847595b9d2891488989cac00214da7f845f4bcf2989007d577"},
|
||||
{file = "ruff-0.6.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d5ccc9e58112441de8ad4b29dcb7a86dc25c5f770e3c06a9d57e0e5eba48829"},
|
||||
{file = "ruff-0.6.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:417b81aa1c9b60b2f8edc463c58363075412866ae4e2b9ab0f690dc1e87ac1b5"},
|
||||
{file = "ruff-0.6.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c866b631f5fbce896a74a6e4383407ba7507b815ccc52bcedabb6810fdb3ef7"},
|
||||
{file = "ruff-0.6.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7b118afbb3202f5911486ad52da86d1d52305b59e7ef2031cea3425142b97d6f"},
|
||||
{file = "ruff-0.6.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a67267654edc23c97335586774790cde402fb6bbdb3c2314f1fc087dee320bfa"},
|
||||
{file = "ruff-0.6.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3ef0cc774b00fec123f635ce5c547dac263f6ee9fb9cc83437c5904183b55ceb"},
|
||||
{file = "ruff-0.6.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:12edd2af0c60fa61ff31cefb90aef4288ac4d372b4962c2864aeea3a1a2460c0"},
|
||||
{file = "ruff-0.6.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:55bb01caeaf3a60b2b2bba07308a02fca6ab56233302406ed5245180a05c5625"},
|
||||
{file = "ruff-0.6.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:925d26471fa24b0ce5a6cdfab1bb526fb4159952385f386bdcc643813d472039"},
|
||||
{file = "ruff-0.6.9-py3-none-win32.whl", hash = "sha256:eb61ec9bdb2506cffd492e05ac40e5bc6284873aceb605503d8494180d6fc84d"},
|
||||
{file = "ruff-0.6.9-py3-none-win_amd64.whl", hash = "sha256:785d31851c1ae91f45b3d8fe23b8ae4b5170089021fbb42402d811135f0b7117"},
|
||||
{file = "ruff-0.6.9-py3-none-win_arm64.whl", hash = "sha256:a9641e31476d601f83cd602608739a0840e348bda93fec9f1ee816f8b6798b93"},
|
||||
{file = "ruff-0.6.9.tar.gz", hash = "sha256:b076ef717a8e5bc819514ee1d602bbdca5b4420ae13a9cf61a0c0a4f53a2baa2"},
|
||||
{file = "ruff-0.7.3-py3-none-linux_armv6l.whl", hash = "sha256:34f2339dc22687ec7e7002792d1f50712bf84a13d5152e75712ac08be565d344"},
|
||||
{file = "ruff-0.7.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:fb397332a1879b9764a3455a0bb1087bda876c2db8aca3a3cbb67b3dbce8cda0"},
|
||||
{file = "ruff-0.7.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:37d0b619546103274e7f62643d14e1adcbccb242efda4e4bdb9544d7764782e9"},
|
||||
{file = "ruff-0.7.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d59f0c3ee4d1a6787614e7135b72e21024875266101142a09a61439cb6e38a5"},
|
||||
{file = "ruff-0.7.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:44eb93c2499a169d49fafd07bc62ac89b1bc800b197e50ff4633aed212569299"},
|
||||
{file = "ruff-0.7.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6d0242ce53f3a576c35ee32d907475a8d569944c0407f91d207c8af5be5dae4e"},
|
||||
{file = "ruff-0.7.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:6b6224af8b5e09772c2ecb8dc9f3f344c1aa48201c7f07e7315367f6dd90ac29"},
|
||||
{file = "ruff-0.7.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c50f95a82b94421c964fae4c27c0242890a20fe67d203d127e84fbb8013855f5"},
|
||||
{file = "ruff-0.7.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7f3eff9961b5d2644bcf1616c606e93baa2d6b349e8aa8b035f654df252c8c67"},
|
||||
{file = "ruff-0.7.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8963cab06d130c4df2fd52c84e9f10d297826d2e8169ae0c798b6221be1d1d2"},
|
||||
{file = "ruff-0.7.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:61b46049d6edc0e4317fb14b33bd693245281a3007288b68a3f5b74a22a0746d"},
|
||||
{file = "ruff-0.7.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:10ebce7696afe4644e8c1a23b3cf8c0f2193a310c18387c06e583ae9ef284de2"},
|
||||
{file = "ruff-0.7.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3f36d56326b3aef8eeee150b700e519880d1aab92f471eefdef656fd57492aa2"},
|
||||
{file = "ruff-0.7.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5d024301109a0007b78d57ab0ba190087b43dce852e552734ebf0b0b85e4fb16"},
|
||||
{file = "ruff-0.7.3-py3-none-win32.whl", hash = "sha256:4ba81a5f0c5478aa61674c5a2194de8b02652f17addf8dfc40c8937e6e7d79fc"},
|
||||
{file = "ruff-0.7.3-py3-none-win_amd64.whl", hash = "sha256:588a9ff2fecf01025ed065fe28809cd5a53b43505f48b69a1ac7707b1b7e4088"},
|
||||
{file = "ruff-0.7.3-py3-none-win_arm64.whl", hash = "sha256:1713e2c5545863cdbfe2cbce21f69ffaf37b813bfd1fb3b90dc9a6f1963f5a8c"},
|
||||
{file = "ruff-0.7.3.tar.gz", hash = "sha256:e1d1ba2e40b6e71a61b063354d04be669ab0d39c352461f3d789cac68b54a313"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -11005,4 +11051,4 @@ cffi = ["cffi (>=1.11)"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "f20bd678044926913dbbc24bd0cf22503a75817aa55f59457ff7822032139b77"
|
||||
content-hash = "cf4e0467f622e58b51411ee1d784928962f52dbf877b8ee013c810909a1f07db"
|
||||
|
|
|
@ -122,6 +122,7 @@ celery = "~5.4.0"
|
|||
chardet = "~5.1.0"
|
||||
cohere = "~5.2.4"
|
||||
dashscope = { version = "~1.17.0", extras = ["tokenizer"] }
|
||||
fal-client = "0.5.6"
|
||||
flask = "~3.0.1"
|
||||
flask-compress = "~1.14"
|
||||
flask-cors = "~4.0.0"
|
||||
|
@ -265,6 +266,7 @@ weaviate-client = "~3.21.0"
|
|||
optional = true
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
coverage = "~7.2.4"
|
||||
faker = "~32.1.0"
|
||||
pytest = "~8.3.2"
|
||||
pytest-benchmark = "~4.0.0"
|
||||
pytest-env = "~1.1.3"
|
||||
|
@ -278,4 +280,4 @@ pytest-mock = "~3.14.0"
|
|||
optional = true
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
dotenv-linter = "~0.5.0"
|
||||
ruff = "~0.6.9"
|
||||
ruff = "~0.7.3"
|
||||
|
|
|
@ -11,7 +11,6 @@ from core.model_runtime.entities.message_entities import (
|
|||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel
|
||||
from tests.integration_tests.model_runtime.__mock.azure_ai_studio import setup_azure_ai_studio_mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)
|
||||
|
|
|
@ -4,29 +4,21 @@ import pytest
|
|||
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureAIStudioRerankModel
|
||||
from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureRerankModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = AzureAIStudioRerankModel()
|
||||
model = AzureRerankModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="azure-ai-studio-rerank-v1",
|
||||
credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")},
|
||||
query="What is the capital of the United States?",
|
||||
docs=[
|
||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||
"Census, Carson City had a population of 55,274.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||
],
|
||||
score_threshold=0.8,
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = AzureAIStudioRerankModel()
|
||||
model = AzureRerankModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="azure-ai-studio-rerank-v1",
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
from collections import UserDict
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
@ -11,7 +12,7 @@ from pymochow.model.table import Table
|
|||
from requests.adapters import HTTPAdapter
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
class AttrDict(UserDict):
|
||||
def __getattr__(self, item):
|
||||
return self.get(item)
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
from collections import UserDict
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
@ -50,7 +51,7 @@ class MockIndex:
|
|||
return AttrDict({"dimension": 1024})
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
class AttrDict(UserDict):
|
||||
def __getattr__(self, item):
|
||||
return self.get(item)
|
||||
|
||||
|
|
|
@ -1,125 +1,484 @@
|
|||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
|
||||
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
||||
from core.workflow.nodes.end import EndStreamParam
|
||||
from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
VisionConfig,
|
||||
VisionConfigOptions,
|
||||
)
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from models.enums import UserFrom
|
||||
from models.provider import ProviderType
|
||||
from models.workflow import WorkflowType
|
||||
from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario
|
||||
|
||||
|
||||
class TestLLMNode:
|
||||
@pytest.fixture
|
||||
def llm_node(self):
|
||||
data = LLMNodeData(
|
||||
title="Test LLM",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
||||
prompt_template=[],
|
||||
memory=None,
|
||||
context=ContextConfig(enabled=False),
|
||||
vision=VisionConfig(
|
||||
enabled=True,
|
||||
configs=VisionConfigOptions(
|
||||
variable_selector=["sys", "files"],
|
||||
detail=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
),
|
||||
),
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
node = LLMNode(
|
||||
id="1",
|
||||
config={
|
||||
"id": "1",
|
||||
"data": data.model_dump(),
|
||||
},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph=Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
)
|
||||
return node
|
||||
class MockTokenBufferMemory:
|
||||
def __init__(self, history_messages=None):
|
||||
self.history_messages = history_messages or []
|
||||
|
||||
def test_fetch_files_with_file_segment(self, llm_node):
|
||||
file = File(
|
||||
def get_history_prompt_messages(
|
||||
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
|
||||
) -> Sequence[PromptMessage]:
|
||||
if message_limit is not None:
|
||||
return self.history_messages[-message_limit * 2 :]
|
||||
return self.history_messages
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_node():
|
||||
data = LLMNodeData(
|
||||
title="Test LLM",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
||||
prompt_template=[],
|
||||
memory=None,
|
||||
context=ContextConfig(enabled=False),
|
||||
vision=VisionConfig(
|
||||
enabled=True,
|
||||
configs=VisionConfigOptions(
|
||||
variable_selector=["sys", "files"],
|
||||
detail=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
),
|
||||
),
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
node = LLMNode(
|
||||
id="1",
|
||||
config={
|
||||
"id": "1",
|
||||
"data": data.model_dump(),
|
||||
},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph=Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_config():
|
||||
# Create actual provider and model type instances
|
||||
model_provider_factory = ModelProviderFactory()
|
||||
provider_instance = model_provider_factory.get_provider_instance("openai")
|
||||
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Create a ProviderModelBundle
|
||||
provider_model_bundle = ProviderModelBundle(
|
||||
configuration=ProviderConfiguration(
|
||||
tenant_id="1",
|
||||
provider=provider_instance.get_provider_schema(),
|
||||
preferred_provider_type=ProviderType.CUSTOM,
|
||||
using_provider_type=ProviderType.CUSTOM,
|
||||
system_configuration=SystemConfiguration(enabled=False),
|
||||
custom_configuration=CustomConfiguration(provider=None),
|
||||
model_settings=[],
|
||||
),
|
||||
provider_instance=provider_instance,
|
||||
model_type_instance=model_type_instance,
|
||||
)
|
||||
|
||||
# Create and return a ModelConfigWithCredentialsEntity
|
||||
return ModelConfigWithCredentialsEntity(
|
||||
provider="openai",
|
||||
model="gpt-3.5-turbo",
|
||||
model_schema=AIModelEntity(
|
||||
model="gpt-3.5-turbo",
|
||||
label=I18nObject(en_US="GPT-3.5 Turbo"),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={},
|
||||
),
|
||||
mode="chat",
|
||||
credentials={},
|
||||
parameters={},
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
)
|
||||
|
||||
|
||||
def test_fetch_files_with_file_segment(llm_node):
|
||||
file = File(
|
||||
id="1",
|
||||
tenant_id="test",
|
||||
type=FileType.IMAGE,
|
||||
filename="test.jpg",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1",
|
||||
)
|
||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
|
||||
|
||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||
assert result == [file]
|
||||
|
||||
|
||||
def test_fetch_files_with_array_file_segment(llm_node):
|
||||
files = [
|
||||
File(
|
||||
id="1",
|
||||
tenant_id="test",
|
||||
type=FileType.IMAGE,
|
||||
filename="test.jpg",
|
||||
filename="test1.jpg",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1",
|
||||
),
|
||||
File(
|
||||
id="2",
|
||||
tenant_id="test",
|
||||
type=FileType.IMAGE,
|
||||
filename="test2.jpg",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="2",
|
||||
),
|
||||
]
|
||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
|
||||
|
||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||
assert result == files
|
||||
|
||||
|
||||
def test_fetch_files_with_none_segment(llm_node):
|
||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
|
||||
|
||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_fetch_files_with_array_any_segment(llm_node):
|
||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
|
||||
|
||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_fetch_files_with_non_existent_variable(llm_node):
|
||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
|
||||
prompt_template = []
|
||||
llm_node.node_data.prompt_template = prompt_template
|
||||
|
||||
fake_vision_detail = faker.random_element(
|
||||
[ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
|
||||
)
|
||||
fake_remote_url = faker.url()
|
||||
files = [
|
||||
File(
|
||||
id="1",
|
||||
tenant_id="test",
|
||||
type=FileType.IMAGE,
|
||||
filename="test1.jpg",
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url=fake_remote_url,
|
||||
)
|
||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
|
||||
]
|
||||
|
||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||
assert result == [file]
|
||||
fake_query = faker.sentence()
|
||||
|
||||
def test_fetch_files_with_array_file_segment(self, llm_node):
|
||||
files = [
|
||||
File(
|
||||
id="1",
|
||||
tenant_id="test",
|
||||
type=FileType.IMAGE,
|
||||
filename="test1.jpg",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1",
|
||||
),
|
||||
File(
|
||||
id="2",
|
||||
tenant_id="test",
|
||||
type=FileType.IMAGE,
|
||||
filename="test2.jpg",
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="2",
|
||||
),
|
||||
]
|
||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
|
||||
prompt_messages, _ = llm_node._fetch_prompt_messages(
|
||||
user_query=fake_query,
|
||||
user_files=files,
|
||||
context=None,
|
||||
memory=None,
|
||||
model_config=model_config,
|
||||
prompt_template=prompt_template,
|
||||
memory_config=None,
|
||||
vision_enabled=False,
|
||||
vision_detail=fake_vision_detail,
|
||||
variable_pool=llm_node.graph_runtime_state.variable_pool,
|
||||
jinja2_variables=[],
|
||||
)
|
||||
|
||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||
assert result == files
|
||||
assert prompt_messages == [UserPromptMessage(content=fake_query)]
|
||||
|
||||
def test_fetch_files_with_none_segment(self, llm_node):
|
||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
|
||||
|
||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||
assert result == []
|
||||
def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||
# Setup dify config
|
||||
dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
|
||||
dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
|
||||
|
||||
def test_fetch_files_with_array_any_segment(self, llm_node):
|
||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
|
||||
# Generate fake values for prompt template
|
||||
fake_assistant_prompt = faker.sentence()
|
||||
fake_query = faker.sentence()
|
||||
fake_context = faker.sentence()
|
||||
fake_window_size = faker.random_int(min=1, max=3)
|
||||
fake_vision_detail = faker.random_element(
|
||||
[ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
|
||||
)
|
||||
fake_remote_url = faker.url()
|
||||
|
||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||
assert result == []
|
||||
# Setup mock memory with history messages
|
||||
mock_history = [
|
||||
UserPromptMessage(content=faker.sentence()),
|
||||
AssistantPromptMessage(content=faker.sentence()),
|
||||
UserPromptMessage(content=faker.sentence()),
|
||||
AssistantPromptMessage(content=faker.sentence()),
|
||||
UserPromptMessage(content=faker.sentence()),
|
||||
AssistantPromptMessage(content=faker.sentence()),
|
||||
]
|
||||
|
||||
def test_fetch_files_with_non_existent_variable(self, llm_node):
|
||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||
assert result == []
|
||||
# Setup memory configuration
|
||||
memory_config = MemoryConfig(
|
||||
role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
|
||||
window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size),
|
||||
query_prompt_template=None,
|
||||
)
|
||||
|
||||
memory = MockTokenBufferMemory(history_messages=mock_history)
|
||||
|
||||
# Test scenarios covering different file input combinations
|
||||
test_scenarios = [
|
||||
LLMNodeTestScenario(
|
||||
description="No files",
|
||||
user_query=fake_query,
|
||||
user_files=[],
|
||||
features=[],
|
||||
vision_enabled=False,
|
||||
vision_detail=None,
|
||||
window_size=fake_window_size,
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text=fake_context,
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
edition_type="basic",
|
||||
),
|
||||
LLMNodeChatModelMessage(
|
||||
text="{#context#}",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
),
|
||||
LLMNodeChatModelMessage(
|
||||
text=fake_assistant_prompt,
|
||||
role=PromptMessageRole.ASSISTANT,
|
||||
edition_type="basic",
|
||||
),
|
||||
],
|
||||
expected_messages=[
|
||||
SystemPromptMessage(content=fake_context),
|
||||
UserPromptMessage(content=fake_context),
|
||||
AssistantPromptMessage(content=fake_assistant_prompt),
|
||||
]
|
||||
+ mock_history[fake_window_size * -2 :]
|
||||
+ [
|
||||
UserPromptMessage(content=fake_query),
|
||||
],
|
||||
),
|
||||
LLMNodeTestScenario(
|
||||
description="User files",
|
||||
user_query=fake_query,
|
||||
user_files=[
|
||||
File(
|
||||
tenant_id="test",
|
||||
type=FileType.IMAGE,
|
||||
filename="test1.jpg",
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url=fake_remote_url,
|
||||
)
|
||||
],
|
||||
vision_enabled=True,
|
||||
vision_detail=fake_vision_detail,
|
||||
features=[ModelFeature.VISION],
|
||||
window_size=fake_window_size,
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text=fake_context,
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
edition_type="basic",
|
||||
),
|
||||
LLMNodeChatModelMessage(
|
||||
text="{#context#}",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
),
|
||||
LLMNodeChatModelMessage(
|
||||
text=fake_assistant_prompt,
|
||||
role=PromptMessageRole.ASSISTANT,
|
||||
edition_type="basic",
|
||||
),
|
||||
],
|
||||
expected_messages=[
|
||||
SystemPromptMessage(content=fake_context),
|
||||
UserPromptMessage(content=fake_context),
|
||||
AssistantPromptMessage(content=fake_assistant_prompt),
|
||||
]
|
||||
+ mock_history[fake_window_size * -2 :]
|
||||
+ [
|
||||
UserPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data=fake_query),
|
||||
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
|
||||
]
|
||||
),
|
||||
],
|
||||
),
|
||||
LLMNodeTestScenario(
|
||||
description="Prompt template with variable selector of File",
|
||||
user_query=fake_query,
|
||||
user_files=[],
|
||||
vision_enabled=False,
|
||||
vision_detail=fake_vision_detail,
|
||||
features=[ModelFeature.VISION],
|
||||
window_size=fake_window_size,
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text="{{#input.image#}}",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
),
|
||||
],
|
||||
expected_messages=[
|
||||
UserPromptMessage(
|
||||
content=[
|
||||
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
|
||||
]
|
||||
),
|
||||
]
|
||||
+ mock_history[fake_window_size * -2 :]
|
||||
+ [UserPromptMessage(content=fake_query)],
|
||||
file_variables={
|
||||
"input.image": File(
|
||||
tenant_id="test",
|
||||
type=FileType.IMAGE,
|
||||
filename="test1.jpg",
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url=fake_remote_url,
|
||||
)
|
||||
},
|
||||
),
|
||||
LLMNodeTestScenario(
|
||||
description="Prompt template with variable selector of File without vision feature",
|
||||
user_query=fake_query,
|
||||
user_files=[],
|
||||
vision_enabled=True,
|
||||
vision_detail=fake_vision_detail,
|
||||
features=[],
|
||||
window_size=fake_window_size,
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text="{{#input.image#}}",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
),
|
||||
],
|
||||
expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
|
||||
file_variables={
|
||||
"input.image": File(
|
||||
tenant_id="test",
|
||||
type=FileType.IMAGE,
|
||||
filename="test1.jpg",
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url=fake_remote_url,
|
||||
)
|
||||
},
|
||||
),
|
||||
LLMNodeTestScenario(
|
||||
description="Prompt template with variable selector of File with video file and vision feature",
|
||||
user_query=fake_query,
|
||||
user_files=[],
|
||||
vision_enabled=True,
|
||||
vision_detail=fake_vision_detail,
|
||||
features=[ModelFeature.VISION],
|
||||
window_size=fake_window_size,
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text="{{#input.image#}}",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
),
|
||||
],
|
||||
expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
|
||||
file_variables={
|
||||
"input.image": File(
|
||||
tenant_id="test",
|
||||
type=FileType.VIDEO,
|
||||
filename="test1.mp4",
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url=fake_remote_url,
|
||||
extension="mp4",
|
||||
)
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
for scenario in test_scenarios:
|
||||
model_config.model_schema.features = scenario.features
|
||||
|
||||
for k, v in scenario.file_variables.items():
|
||||
selector = k.split(".")
|
||||
llm_node.graph_runtime_state.variable_pool.add(selector, v)
|
||||
|
||||
# Call the method under test
|
||||
prompt_messages, _ = llm_node._fetch_prompt_messages(
|
||||
user_query=scenario.user_query,
|
||||
user_files=scenario.user_files,
|
||||
context=fake_context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
prompt_template=scenario.prompt_template,
|
||||
memory_config=memory_config,
|
||||
vision_enabled=scenario.vision_enabled,
|
||||
vision_detail=scenario.vision_detail,
|
||||
variable_pool=llm_node.graph_runtime_state.variable_pool,
|
||||
jinja2_variables=[],
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}"
|
||||
assert (
|
||||
prompt_messages == scenario.expected_messages
|
||||
), f"Message content mismatch in scenario: {scenario.description}"
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.file import File
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
|
||||
|
||||
|
||||
class LLMNodeTestScenario(BaseModel):
|
||||
"""Test scenario for LLM node testing."""
|
||||
|
||||
description: str = Field(..., description="Description of the test scenario")
|
||||
user_query: str = Field(..., description="User query input")
|
||||
user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
|
||||
vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
|
||||
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
|
||||
features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
|
||||
window_size: int = Field(..., description="Window size for memory")
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
|
||||
file_variables: Mapping[str, File | Sequence[File]] = Field(
|
||||
default_factory=dict, description="List of file variables"
|
||||
)
|
||||
expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing")
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
from collections import UserDict
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
@ -14,7 +15,7 @@ from tests.unit_tests.oss.__mock.base import (
|
|||
)
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
class AttrDict(UserDict):
|
||||
def __getattr__(self, item):
|
||||
return self.get(item)
|
||||
|
||||
|
|
|
@ -44,12 +44,6 @@ export const fileUpload: FileUpload = ({
|
|||
}
|
||||
|
||||
export const getFileExtension = (fileName: string, fileMimetype: string, isRemote?: boolean) => {
|
||||
if (fileMimetype)
|
||||
return mime.getExtension(fileMimetype) || ''
|
||||
|
||||
if (isRemote)
|
||||
return ''
|
||||
|
||||
if (fileName) {
|
||||
const fileNamePair = fileName.split('.')
|
||||
const fileNamePairLength = fileNamePair.length
|
||||
|
@ -58,6 +52,12 @@ export const getFileExtension = (fileName: string, fileMimetype: string, isRemot
|
|||
return fileNamePair[fileNamePairLength - 1]
|
||||
}
|
||||
|
||||
if (fileMimetype)
|
||||
return mime.getExtension(fileMimetype) || ''
|
||||
|
||||
if (isRemote)
|
||||
return ''
|
||||
|
||||
return ''
|
||||
}
|
||||
|
||||
|
|
|
@ -144,6 +144,7 @@ const ConfigPromptItem: FC<Props> = ({
|
|||
onEditionTypeChange={onEditionTypeChange}
|
||||
varList={varList}
|
||||
handleAddVariable={handleAddVariable}
|
||||
isSupportFileVar
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
|
|
@ -67,6 +67,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
|
|||
handleStop,
|
||||
varInputs,
|
||||
runResult,
|
||||
filterJinjia2InputVar,
|
||||
} = useConfig(id, data)
|
||||
|
||||
const model = inputs.model
|
||||
|
@ -194,7 +195,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
|
|||
list={inputs.prompt_config?.jinja2_variables || []}
|
||||
onChange={handleVarListChange}
|
||||
onVarNameChange={handleVarNameChange}
|
||||
filterVar={filterVar}
|
||||
filterVar={filterJinjia2InputVar}
|
||||
/>
|
||||
</Field>
|
||||
)}
|
||||
|
@ -233,6 +234,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
|
|||
hasSetBlockStatus={hasSetBlockStatus}
|
||||
nodesOutputVars={availableVars}
|
||||
availableNodes={availableNodesWithParent}
|
||||
isSupportFileVar
|
||||
/>
|
||||
|
||||
{inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && (
|
||||
|
|
|
@ -278,11 +278,15 @@ const useConfig = (id: string, payload: LLMNodeType) => {
|
|||
}, [inputs, setInputs])
|
||||
|
||||
const filterInputVar = useCallback((varPayload: Var) => {
|
||||
return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type)
|
||||
}, [])
|
||||
|
||||
const filterJinjia2InputVar = useCallback((varPayload: Var) => {
|
||||
return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
|
||||
}, [])
|
||||
|
||||
const filterMemoryPromptVar = useCallback((varPayload: Var) => {
|
||||
return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
|
||||
return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type)
|
||||
}, [])
|
||||
|
||||
const {
|
||||
|
@ -406,6 +410,7 @@ const useConfig = (id: string, payload: LLMNodeType) => {
|
|||
handleRun,
|
||||
handleStop,
|
||||
runResult,
|
||||
filterJinjia2InputVar,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -86,8 +86,8 @@ const translation = {
|
|||
agenteLogDetail: {
|
||||
agentMode: 'Modo Agente',
|
||||
toolUsed: 'Ferramenta usada',
|
||||
iterações: 'Iterações',
|
||||
iteração: 'Iteração',
|
||||
iterations: 'Iterações',
|
||||
iteration: 'Iteração',
|
||||
finalProcessing: 'Processamento Final',
|
||||
},
|
||||
agentLogDetail: {
|
||||
|
|
Loading…
Reference in New Issue
Block a user