mirror of
https://github.com/langgenius/dify.git
synced 2024-11-15 19:22:36 +08:00
Compare commits
32 Commits
72c1b7d1c2
...
abacc3768f
Author | SHA1 | Date | |
---|---|---|---|
|
abacc3768f | ||
|
e31358219c | ||
|
4e360ec19a | ||
|
f68d6bd5e2 | ||
|
b860a893c8 | ||
|
0354c7813e | ||
|
94794d892e | ||
|
fb94d0b7cf | ||
|
02c39b2631 | ||
|
87f78ff582 | ||
|
6872b32c7d | ||
|
97fab7649b | ||
|
9f0f82cb1c | ||
|
ef08abafdf | ||
|
d6c9ab8554 | ||
|
bab989e3b3 | ||
|
ddc86503dc | ||
|
abad35f700 | ||
|
620b0e69f5 | ||
|
71cf4c7dbf | ||
|
47e8a5d4d1 | ||
|
93bbb194f2 | ||
|
2106fc5266 | ||
|
229b146525 | ||
|
d9fa6f79be | ||
|
4f89214d89 | ||
|
1fdaea29aa | ||
|
1397d0000d | ||
|
4b2abf8ac2 | ||
|
365cb4b368 | ||
|
c85bff235d | ||
|
ad16180b1a |
|
@ -27,7 +27,6 @@ class DifyConfig(
|
||||||
# read from dotenv format config file
|
# read from dotenv format config file
|
||||||
env_file=".env",
|
env_file=".env",
|
||||||
env_file_encoding="utf-8",
|
env_file_encoding="utf-8",
|
||||||
frozen=True,
|
|
||||||
# ignore extra attributes
|
# ignore extra attributes
|
||||||
extra="ignore",
|
extra="ignore",
|
||||||
)
|
)
|
||||||
|
|
|
@ -11,7 +11,7 @@ from core.provider_manager import ProviderManager
|
||||||
|
|
||||||
class ModelConfigConverter:
|
class ModelConfigConverter:
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
|
def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
|
||||||
"""
|
"""
|
||||||
Convert app model config dict to entity.
|
Convert app model config dict to entity.
|
||||||
:param app_config: app config
|
:param app_config: app config
|
||||||
|
@ -38,27 +38,23 @@ class ModelConfigConverter:
|
||||||
)
|
)
|
||||||
|
|
||||||
if model_credentials is None:
|
if model_credentials is None:
|
||||||
if not skip_check:
|
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
|
||||||
else:
|
|
||||||
model_credentials = {}
|
|
||||||
|
|
||||||
if not skip_check:
|
# check model
|
||||||
# check model
|
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
model=model_config.model, model_type=ModelType.LLM
|
||||||
model=model_config.model, model_type=ModelType.LLM
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if provider_model is None:
|
if provider_model is None:
|
||||||
model_name = model_config.model
|
model_name = model_config.model
|
||||||
raise ValueError(f"Model {model_name} not exist.")
|
raise ValueError(f"Model {model_name} not exist.")
|
||||||
|
|
||||||
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||||
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
||||||
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
|
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
|
||||||
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||||
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
|
||||||
|
|
||||||
# model config
|
# model config
|
||||||
completion_params = model_config.parameters
|
completion_params = model_config.parameters
|
||||||
|
@ -76,7 +72,7 @@ class ModelConfigConverter:
|
||||||
|
|
||||||
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
|
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
|
||||||
|
|
||||||
if not skip_check and not model_schema:
|
if not model_schema:
|
||||||
raise ValueError(f"Model {model_name} not exist.")
|
raise ValueError(f"Model {model_name} not exist.")
|
||||||
|
|
||||||
return ModelConfigWithCredentialsEntity(
|
return ModelConfigWithCredentialsEntity(
|
||||||
|
|
|
@ -217,6 +217,7 @@ class WorkflowCycleManage:
|
||||||
).total_seconds()
|
).total_seconds()
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
db.session.add(workflow_run)
|
||||||
db.session.refresh(workflow_run)
|
db.session.refresh(workflow_run)
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
|
|
|
@ -74,6 +74,8 @@ def to_prompt_message_content(
|
||||||
data = _to_url(f)
|
data = _to_url(f)
|
||||||
else:
|
else:
|
||||||
data = _to_base64_data_string(f)
|
data = _to_base64_data_string(f)
|
||||||
|
if f.extension is None:
|
||||||
|
raise ValueError("Missing file extension")
|
||||||
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
|
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
|
||||||
case _:
|
case _:
|
||||||
raise ValueError("file type f.type is not supported")
|
raise ValueError("file type f.type is not supported")
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from collections.abc import Sequence
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
|
@ -27,7 +28,7 @@ class TokenBufferMemory:
|
||||||
|
|
||||||
def get_history_prompt_messages(
|
def get_history_prompt_messages(
|
||||||
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
|
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
|
||||||
) -> list[PromptMessage]:
|
) -> Sequence[PromptMessage]:
|
||||||
"""
|
"""
|
||||||
Get history prompt messages.
|
Get history prompt messages.
|
||||||
:param max_token_limit: max token limit
|
:param max_token_limit: max token limit
|
||||||
|
|
|
@ -100,10 +100,10 @@ class ModelInstance:
|
||||||
|
|
||||||
def invoke_llm(
|
def invoke_llm(
|
||||||
self,
|
self,
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: Sequence[PromptMessage],
|
||||||
model_parameters: Optional[dict] = None,
|
model_parameters: Optional[dict] = None,
|
||||||
tools: Sequence[PromptMessageTool] | None = None,
|
tools: Sequence[PromptMessageTool] | None = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Sequence
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||||
|
@ -31,7 +32,7 @@ class Callback(ABC):
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -60,7 +61,7 @@ class Callback(ABC):
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
@ -90,7 +91,7 @@ class Callback(ABC):
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -120,7 +121,7 @@ class Callback(ABC):
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
from collections.abc import Sequence
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
@ -57,6 +58,7 @@ class PromptMessageContentType(Enum):
|
||||||
IMAGE = "image"
|
IMAGE = "image"
|
||||||
AUDIO = "audio"
|
AUDIO = "audio"
|
||||||
VIDEO = "video"
|
VIDEO = "video"
|
||||||
|
DOCUMENT = "document"
|
||||||
|
|
||||||
|
|
||||||
class PromptMessageContent(BaseModel):
|
class PromptMessageContent(BaseModel):
|
||||||
|
@ -107,7 +109,7 @@ class PromptMessage(ABC, BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
role: PromptMessageRole
|
role: PromptMessageRole
|
||||||
content: Optional[str | list[PromptMessageContent]] = None
|
content: Optional[str | Sequence[PromptMessageContent]] = None
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
|
|
||||||
def is_empty(self) -> bool:
|
def is_empty(self) -> bool:
|
||||||
|
|
|
@ -87,6 +87,9 @@ class ModelFeature(Enum):
|
||||||
AGENT_THOUGHT = "agent-thought"
|
AGENT_THOUGHT = "agent-thought"
|
||||||
VISION = "vision"
|
VISION = "vision"
|
||||||
STREAM_TOOL_CALL = "stream-tool-call"
|
STREAM_TOOL_CALL = "stream-tool-call"
|
||||||
|
DOCUMENT = "document"
|
||||||
|
VIDEO = "video"
|
||||||
|
AUDIO = "audio"
|
||||||
|
|
||||||
|
|
||||||
class DefaultParameterName(str, Enum):
|
class DefaultParameterName(str, Enum):
|
||||||
|
|
|
@ -2,7 +2,7 @@ import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
|
@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel):
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: Optional[dict] = None,
|
model_parameters: Optional[dict] = None,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
@ -169,7 +169,7 @@ class LargeLanguageModel(AIModel):
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
@ -212,7 +212,7 @@ if you are not sure about the structure.
|
||||||
)
|
)
|
||||||
|
|
||||||
model_parameters.pop("response_format")
|
model_parameters.pop("response_format")
|
||||||
stop = stop or []
|
stop = list(stop) if stop is not None else []
|
||||||
stop.extend(["\n```", "```\n"])
|
stop.extend(["\n```", "```\n"])
|
||||||
block_prompts = block_prompts.replace("{{block}}", code_block)
|
block_prompts = block_prompts.replace("{{block}}", code_block)
|
||||||
|
|
||||||
|
@ -408,7 +408,7 @@ if you are not sure about the structure.
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
@ -479,7 +479,7 @@ if you are not sure about the structure.
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
) -> Union[LLMResult, Generator]:
|
) -> Union[LLMResult, Generator]:
|
||||||
|
@ -601,7 +601,7 @@ if you are not sure about the structure.
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
@ -647,7 +647,7 @@ if you are not sure about the structure.
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
@ -694,7 +694,7 @@ if you are not sure about the structure.
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
@ -742,7 +742,7 @@ if you are not sure about the structure.
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
tools: Optional[list[PromptMessageTool]] = None,
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
callbacks: Optional[list[Callback]] = None,
|
callbacks: Optional[list[Callback]] = None,
|
||||||
|
|
|
@ -8,6 +8,7 @@ features:
|
||||||
- agent-thought
|
- agent-thought
|
||||||
- stream-tool-call
|
- stream-tool-call
|
||||||
- vision
|
- vision
|
||||||
|
- audio
|
||||||
model_properties:
|
model_properties:
|
||||||
mode: chat
|
mode: chat
|
||||||
context_size: 128000
|
context_size: 128000
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
|
from collections import UserDict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,9 +11,9 @@ class ChatRole:
|
||||||
FUNCTION = "function"
|
FUNCTION = "function"
|
||||||
|
|
||||||
|
|
||||||
class _Dict(dict):
|
class _Dict(UserDict):
|
||||||
__setattr__ = dict.__setitem__
|
__setattr__ = UserDict.__setitem__
|
||||||
__getattr__ = dict.__getitem__
|
__getattr__ = UserDict.__getitem__
|
||||||
|
|
||||||
def __missing__(self, key):
|
def __missing__(self, key):
|
||||||
return None
|
return None
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from collections.abc import Sequence
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from core.model_runtime.entities import (
|
from core.model_runtime.entities import (
|
||||||
|
@ -14,7 +15,7 @@ from core.prompt.simple_prompt_transform import ModelMode
|
||||||
|
|
||||||
class PromptMessageUtil:
|
class PromptMessageUtil:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]:
|
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Prompt messages to prompt for saving.
|
Prompt messages to prompt for saving.
|
||||||
:param model_mode: model mode
|
:param model_mode: model mode
|
||||||
|
|
52
api/core/tools/provider/builtin/fal/tools/wizper.py
Normal file
52
api/core/tools/provider/builtin/fal/tools/wizper.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import fal_client
|
||||||
|
|
||||||
|
from core.file.enums import FileAttribute, FileType
|
||||||
|
from core.file.file_manager import download, get_attr
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
|
|
||||||
|
|
||||||
|
class WizperTool(BuiltinTool):
|
||||||
|
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage:
|
||||||
|
audio_file = tool_parameters.get("audio_file")
|
||||||
|
task = tool_parameters.get("task", "transcribe")
|
||||||
|
language = tool_parameters.get("language", "en")
|
||||||
|
chunk_level = tool_parameters.get("chunk_level", "segment")
|
||||||
|
version = tool_parameters.get("version", "3")
|
||||||
|
|
||||||
|
if audio_file.type != FileType.AUDIO:
|
||||||
|
return [self.create_text_message("Not a valid audio file.")]
|
||||||
|
|
||||||
|
api_key = self.runtime.credentials["fal_api_key"]
|
||||||
|
|
||||||
|
os.environ["FAL_KEY"] = api_key
|
||||||
|
|
||||||
|
audio_binary = io.BytesIO(download(audio_file))
|
||||||
|
mime_type = get_attr(file=audio_file, attr=FileAttribute.MIME_TYPE)
|
||||||
|
file_data = audio_binary.getvalue()
|
||||||
|
|
||||||
|
try:
|
||||||
|
audio_url = fal_client.upload(file_data, mime_type)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return [self.create_text_message(f"Error uploading audio file: {str(e)}")]
|
||||||
|
|
||||||
|
arguments = {
|
||||||
|
"audio_url": audio_url,
|
||||||
|
"task": task,
|
||||||
|
"language": language,
|
||||||
|
"chunk_level": chunk_level,
|
||||||
|
"version": version,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = fal_client.subscribe(
|
||||||
|
"fal-ai/wizper",
|
||||||
|
arguments=arguments,
|
||||||
|
with_logs=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.create_json_message(result)
|
489
api/core/tools/provider/builtin/fal/tools/wizper.yaml
Normal file
489
api/core/tools/provider/builtin/fal/tools/wizper.yaml
Normal file
|
@ -0,0 +1,489 @@
|
||||||
|
identity:
|
||||||
|
name: wizper
|
||||||
|
author: Kalo Chin
|
||||||
|
label:
|
||||||
|
en_US: Wizper
|
||||||
|
zh_Hans: Wizper
|
||||||
|
description:
|
||||||
|
human:
|
||||||
|
en_US: Transcribe an audio file using the Whisper model.
|
||||||
|
zh_Hans: 使用 Whisper 模型转录音频文件。
|
||||||
|
llm: Transcribe an audio file using the Whisper model.
|
||||||
|
parameters:
|
||||||
|
- name: audio_file
|
||||||
|
type: file
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: Audio File
|
||||||
|
zh_Hans: 音频文件
|
||||||
|
human_description:
|
||||||
|
en_US: "Upload an audio file to transcribe. Supports mp3, mp4, mpeg, mpga, m4a, wav, or webm formats."
|
||||||
|
zh_Hans: "上传要转录的音频文件。支持 mp3、mp4、mpeg、mpga、m4a、wav 或 webm 格式。"
|
||||||
|
llm_description: "Audio file to transcribe. Supported formats: mp3, mp4, mpeg, mpga, m4a, wav, or webm."
|
||||||
|
form: llm
|
||||||
|
- name: task
|
||||||
|
type: select
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: Task
|
||||||
|
zh_Hans: 任务
|
||||||
|
human_description:
|
||||||
|
en_US: "Choose whether to transcribe the audio in its original language or translate it to English"
|
||||||
|
zh_Hans: "选择是以原始语言转录音频还是将其翻译成英语"
|
||||||
|
llm_description: "Task to perform on the audio file. Either transcribe or translate. Default value: 'transcribe'. If 'translate' is selected as the task, the audio will be translated to English, regardless of the language selected."
|
||||||
|
form: form
|
||||||
|
default: transcribe
|
||||||
|
options:
|
||||||
|
- value: transcribe
|
||||||
|
label:
|
||||||
|
en_US: Transcribe
|
||||||
|
zh_Hans: 转录
|
||||||
|
- value: translate
|
||||||
|
label:
|
||||||
|
en_US: Translate
|
||||||
|
zh_Hans: 翻译
|
||||||
|
- name: language
|
||||||
|
type: select
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: Language
|
||||||
|
zh_Hans: 语言
|
||||||
|
human_description:
|
||||||
|
en_US: "Select the primary language spoken in the audio file"
|
||||||
|
zh_Hans: "选择音频文件中使用的主要语言"
|
||||||
|
llm_description: "Language of the audio file."
|
||||||
|
form: form
|
||||||
|
default: en
|
||||||
|
options:
|
||||||
|
- value: af
|
||||||
|
label:
|
||||||
|
en_US: Afrikaans
|
||||||
|
zh_Hans: 南非语
|
||||||
|
- value: am
|
||||||
|
label:
|
||||||
|
en_US: Amharic
|
||||||
|
zh_Hans: 阿姆哈拉语
|
||||||
|
- value: ar
|
||||||
|
label:
|
||||||
|
en_US: Arabic
|
||||||
|
zh_Hans: 阿拉伯语
|
||||||
|
- value: as
|
||||||
|
label:
|
||||||
|
en_US: Assamese
|
||||||
|
zh_Hans: 阿萨姆语
|
||||||
|
- value: az
|
||||||
|
label:
|
||||||
|
en_US: Azerbaijani
|
||||||
|
zh_Hans: 阿塞拜疆语
|
||||||
|
- value: ba
|
||||||
|
label:
|
||||||
|
en_US: Bashkir
|
||||||
|
zh_Hans: 巴什基尔语
|
||||||
|
- value: be
|
||||||
|
label:
|
||||||
|
en_US: Belarusian
|
||||||
|
zh_Hans: 白俄罗斯语
|
||||||
|
- value: bg
|
||||||
|
label:
|
||||||
|
en_US: Bulgarian
|
||||||
|
zh_Hans: 保加利亚语
|
||||||
|
- value: bn
|
||||||
|
label:
|
||||||
|
en_US: Bengali
|
||||||
|
zh_Hans: 孟加拉语
|
||||||
|
- value: bo
|
||||||
|
label:
|
||||||
|
en_US: Tibetan
|
||||||
|
zh_Hans: 藏语
|
||||||
|
- value: br
|
||||||
|
label:
|
||||||
|
en_US: Breton
|
||||||
|
zh_Hans: 布列塔尼语
|
||||||
|
- value: bs
|
||||||
|
label:
|
||||||
|
en_US: Bosnian
|
||||||
|
zh_Hans: 波斯尼亚语
|
||||||
|
- value: ca
|
||||||
|
label:
|
||||||
|
en_US: Catalan
|
||||||
|
zh_Hans: 加泰罗尼亚语
|
||||||
|
- value: cs
|
||||||
|
label:
|
||||||
|
en_US: Czech
|
||||||
|
zh_Hans: 捷克语
|
||||||
|
- value: cy
|
||||||
|
label:
|
||||||
|
en_US: Welsh
|
||||||
|
zh_Hans: 威尔士语
|
||||||
|
- value: da
|
||||||
|
label:
|
||||||
|
en_US: Danish
|
||||||
|
zh_Hans: 丹麦语
|
||||||
|
- value: de
|
||||||
|
label:
|
||||||
|
en_US: German
|
||||||
|
zh_Hans: 德语
|
||||||
|
- value: el
|
||||||
|
label:
|
||||||
|
en_US: Greek
|
||||||
|
zh_Hans: 希腊语
|
||||||
|
- value: en
|
||||||
|
label:
|
||||||
|
en_US: English
|
||||||
|
zh_Hans: 英语
|
||||||
|
- value: es
|
||||||
|
label:
|
||||||
|
en_US: Spanish
|
||||||
|
zh_Hans: 西班牙语
|
||||||
|
- value: et
|
||||||
|
label:
|
||||||
|
en_US: Estonian
|
||||||
|
zh_Hans: 爱沙尼亚语
|
||||||
|
- value: eu
|
||||||
|
label:
|
||||||
|
en_US: Basque
|
||||||
|
zh_Hans: 巴斯克语
|
||||||
|
- value: fa
|
||||||
|
label:
|
||||||
|
en_US: Persian
|
||||||
|
zh_Hans: 波斯语
|
||||||
|
- value: fi
|
||||||
|
label:
|
||||||
|
en_US: Finnish
|
||||||
|
zh_Hans: 芬兰语
|
||||||
|
- value: fo
|
||||||
|
label:
|
||||||
|
en_US: Faroese
|
||||||
|
zh_Hans: 法罗语
|
||||||
|
- value: fr
|
||||||
|
label:
|
||||||
|
en_US: French
|
||||||
|
zh_Hans: 法语
|
||||||
|
- value: gl
|
||||||
|
label:
|
||||||
|
en_US: Galician
|
||||||
|
zh_Hans: 加利西亚语
|
||||||
|
- value: gu
|
||||||
|
label:
|
||||||
|
en_US: Gujarati
|
||||||
|
zh_Hans: 古吉拉特语
|
||||||
|
- value: ha
|
||||||
|
label:
|
||||||
|
en_US: Hausa
|
||||||
|
zh_Hans: 毫萨语
|
||||||
|
- value: haw
|
||||||
|
label:
|
||||||
|
en_US: Hawaiian
|
||||||
|
zh_Hans: 夏威夷语
|
||||||
|
- value: he
|
||||||
|
label:
|
||||||
|
en_US: Hebrew
|
||||||
|
zh_Hans: 希伯来语
|
||||||
|
- value: hi
|
||||||
|
label:
|
||||||
|
en_US: Hindi
|
||||||
|
zh_Hans: 印地语
|
||||||
|
- value: hr
|
||||||
|
label:
|
||||||
|
en_US: Croatian
|
||||||
|
zh_Hans: 克罗地亚语
|
||||||
|
- value: ht
|
||||||
|
label:
|
||||||
|
en_US: Haitian Creole
|
||||||
|
zh_Hans: 海地克里奥尔语
|
||||||
|
- value: hu
|
||||||
|
label:
|
||||||
|
en_US: Hungarian
|
||||||
|
zh_Hans: 匈牙利语
|
||||||
|
- value: hy
|
||||||
|
label:
|
||||||
|
en_US: Armenian
|
||||||
|
zh_Hans: 亚美尼亚语
|
||||||
|
- value: id
|
||||||
|
label:
|
||||||
|
en_US: Indonesian
|
||||||
|
zh_Hans: 印度尼西亚语
|
||||||
|
- value: is
|
||||||
|
label:
|
||||||
|
en_US: Icelandic
|
||||||
|
zh_Hans: 冰岛语
|
||||||
|
- value: it
|
||||||
|
label:
|
||||||
|
en_US: Italian
|
||||||
|
zh_Hans: 意大利语
|
||||||
|
- value: ja
|
||||||
|
label:
|
||||||
|
en_US: Japanese
|
||||||
|
zh_Hans: 日语
|
||||||
|
- value: jw
|
||||||
|
label:
|
||||||
|
en_US: Javanese
|
||||||
|
zh_Hans: 爪哇语
|
||||||
|
- value: ka
|
||||||
|
label:
|
||||||
|
en_US: Georgian
|
||||||
|
zh_Hans: 格鲁吉亚语
|
||||||
|
- value: kk
|
||||||
|
label:
|
||||||
|
en_US: Kazakh
|
||||||
|
zh_Hans: 哈萨克语
|
||||||
|
- value: km
|
||||||
|
label:
|
||||||
|
en_US: Khmer
|
||||||
|
zh_Hans: 高棉语
|
||||||
|
- value: kn
|
||||||
|
label:
|
||||||
|
en_US: Kannada
|
||||||
|
zh_Hans: 卡纳达语
|
||||||
|
- value: ko
|
||||||
|
label:
|
||||||
|
en_US: Korean
|
||||||
|
zh_Hans: 韩语
|
||||||
|
- value: la
|
||||||
|
label:
|
||||||
|
en_US: Latin
|
||||||
|
zh_Hans: 拉丁语
|
||||||
|
- value: lb
|
||||||
|
label:
|
||||||
|
en_US: Luxembourgish
|
||||||
|
zh_Hans: 卢森堡语
|
||||||
|
- value: ln
|
||||||
|
label:
|
||||||
|
en_US: Lingala
|
||||||
|
zh_Hans: 林加拉语
|
||||||
|
- value: lo
|
||||||
|
label:
|
||||||
|
en_US: Lao
|
||||||
|
zh_Hans: 老挝语
|
||||||
|
- value: lt
|
||||||
|
label:
|
||||||
|
en_US: Lithuanian
|
||||||
|
zh_Hans: 立陶宛语
|
||||||
|
- value: lv
|
||||||
|
label:
|
||||||
|
en_US: Latvian
|
||||||
|
zh_Hans: 拉脱维亚语
|
||||||
|
- value: mg
|
||||||
|
label:
|
||||||
|
en_US: Malagasy
|
||||||
|
zh_Hans: 马尔加什语
|
||||||
|
- value: mi
|
||||||
|
label:
|
||||||
|
en_US: Maori
|
||||||
|
zh_Hans: 毛利语
|
||||||
|
- value: mk
|
||||||
|
label:
|
||||||
|
en_US: Macedonian
|
||||||
|
zh_Hans: 马其顿语
|
||||||
|
- value: ml
|
||||||
|
label:
|
||||||
|
en_US: Malayalam
|
||||||
|
zh_Hans: 马拉雅拉姆语
|
||||||
|
- value: mn
|
||||||
|
label:
|
||||||
|
en_US: Mongolian
|
||||||
|
zh_Hans: 蒙古语
|
||||||
|
- value: mr
|
||||||
|
label:
|
||||||
|
en_US: Marathi
|
||||||
|
zh_Hans: 马拉地语
|
||||||
|
- value: ms
|
||||||
|
label:
|
||||||
|
en_US: Malay
|
||||||
|
zh_Hans: 马来语
|
||||||
|
- value: mt
|
||||||
|
label:
|
||||||
|
en_US: Maltese
|
||||||
|
zh_Hans: 马耳他语
|
||||||
|
- value: my
|
||||||
|
label:
|
||||||
|
en_US: Burmese
|
||||||
|
zh_Hans: 缅甸语
|
||||||
|
- value: ne
|
||||||
|
label:
|
||||||
|
en_US: Nepali
|
||||||
|
zh_Hans: 尼泊尔语
|
||||||
|
- value: nl
|
||||||
|
label:
|
||||||
|
en_US: Dutch
|
||||||
|
zh_Hans: 荷兰语
|
||||||
|
- value: nn
|
||||||
|
label:
|
||||||
|
en_US: Norwegian Nynorsk
|
||||||
|
zh_Hans: 新挪威语
|
||||||
|
- value: no
|
||||||
|
label:
|
||||||
|
en_US: Norwegian
|
||||||
|
zh_Hans: 挪威语
|
||||||
|
- value: oc
|
||||||
|
label:
|
||||||
|
en_US: Occitan
|
||||||
|
zh_Hans: 奥克语
|
||||||
|
- value: pa
|
||||||
|
label:
|
||||||
|
en_US: Punjabi
|
||||||
|
zh_Hans: 旁遮普语
|
||||||
|
- value: pl
|
||||||
|
label:
|
||||||
|
en_US: Polish
|
||||||
|
zh_Hans: 波兰语
|
||||||
|
- value: ps
|
||||||
|
label:
|
||||||
|
en_US: Pashto
|
||||||
|
zh_Hans: 普什图语
|
||||||
|
- value: pt
|
||||||
|
label:
|
||||||
|
en_US: Portuguese
|
||||||
|
zh_Hans: 葡萄牙语
|
||||||
|
- value: ro
|
||||||
|
label:
|
||||||
|
en_US: Romanian
|
||||||
|
zh_Hans: 罗马尼亚语
|
||||||
|
- value: ru
|
||||||
|
label:
|
||||||
|
en_US: Russian
|
||||||
|
zh_Hans: 俄语
|
||||||
|
- value: sa
|
||||||
|
label:
|
||||||
|
en_US: Sanskrit
|
||||||
|
zh_Hans: 梵语
|
||||||
|
- value: sd
|
||||||
|
label:
|
||||||
|
en_US: Sindhi
|
||||||
|
zh_Hans: 信德语
|
||||||
|
- value: si
|
||||||
|
label:
|
||||||
|
en_US: Sinhala
|
||||||
|
zh_Hans: 僧伽罗语
|
||||||
|
- value: sk
|
||||||
|
label:
|
||||||
|
en_US: Slovak
|
||||||
|
zh_Hans: 斯洛伐克语
|
||||||
|
- value: sl
|
||||||
|
label:
|
||||||
|
en_US: Slovenian
|
||||||
|
zh_Hans: 斯洛文尼亚语
|
||||||
|
- value: sn
|
||||||
|
label:
|
||||||
|
en_US: Shona
|
||||||
|
zh_Hans: 修纳语
|
||||||
|
- value: so
|
||||||
|
label:
|
||||||
|
en_US: Somali
|
||||||
|
zh_Hans: 索马里语
|
||||||
|
- value: sq
|
||||||
|
label:
|
||||||
|
en_US: Albanian
|
||||||
|
zh_Hans: 阿尔巴尼亚语
|
||||||
|
- value: sr
|
||||||
|
label:
|
||||||
|
en_US: Serbian
|
||||||
|
zh_Hans: 塞尔维亚语
|
||||||
|
- value: su
|
||||||
|
label:
|
||||||
|
en_US: Sundanese
|
||||||
|
zh_Hans: 巽他语
|
||||||
|
- value: sv
|
||||||
|
label:
|
||||||
|
en_US: Swedish
|
||||||
|
zh_Hans: 瑞典语
|
||||||
|
- value: sw
|
||||||
|
label:
|
||||||
|
en_US: Swahili
|
||||||
|
zh_Hans: 斯瓦希里语
|
||||||
|
- value: ta
|
||||||
|
label:
|
||||||
|
en_US: Tamil
|
||||||
|
zh_Hans: 泰米尔语
|
||||||
|
- value: te
|
||||||
|
label:
|
||||||
|
en_US: Telugu
|
||||||
|
zh_Hans: 泰卢固语
|
||||||
|
- value: tg
|
||||||
|
label:
|
||||||
|
en_US: Tajik
|
||||||
|
zh_Hans: 塔吉克语
|
||||||
|
- value: th
|
||||||
|
label:
|
||||||
|
en_US: Thai
|
||||||
|
zh_Hans: 泰语
|
||||||
|
- value: tk
|
||||||
|
label:
|
||||||
|
en_US: Turkmen
|
||||||
|
zh_Hans: 土库曼语
|
||||||
|
- value: tl
|
||||||
|
label:
|
||||||
|
en_US: Tagalog
|
||||||
|
zh_Hans: 他加禄语
|
||||||
|
- value: tr
|
||||||
|
label:
|
||||||
|
en_US: Turkish
|
||||||
|
zh_Hans: 土耳其语
|
||||||
|
- value: tt
|
||||||
|
label:
|
||||||
|
en_US: Tatar
|
||||||
|
zh_Hans: 鞑靼语
|
||||||
|
- value: uk
|
||||||
|
label:
|
||||||
|
en_US: Ukrainian
|
||||||
|
zh_Hans: 乌克兰语
|
||||||
|
- value: ur
|
||||||
|
label:
|
||||||
|
en_US: Urdu
|
||||||
|
zh_Hans: 乌尔都语
|
||||||
|
- value: uz
|
||||||
|
label:
|
||||||
|
en_US: Uzbek
|
||||||
|
zh_Hans: 乌兹别克语
|
||||||
|
- value: vi
|
||||||
|
label:
|
||||||
|
en_US: Vietnamese
|
||||||
|
zh_Hans: 越南语
|
||||||
|
- value: yi
|
||||||
|
label:
|
||||||
|
en_US: Yiddish
|
||||||
|
zh_Hans: 意第绪语
|
||||||
|
- value: yo
|
||||||
|
label:
|
||||||
|
en_US: Yoruba
|
||||||
|
zh_Hans: 约鲁巴语
|
||||||
|
- value: yue
|
||||||
|
label:
|
||||||
|
en_US: Cantonese
|
||||||
|
zh_Hans: 粤语
|
||||||
|
- value: zh
|
||||||
|
label:
|
||||||
|
en_US: Chinese
|
||||||
|
zh_Hans: 中文
|
||||||
|
- name: chunk_level
|
||||||
|
type: select
|
||||||
|
label:
|
||||||
|
en_US: Chunk Level
|
||||||
|
zh_Hans: 分块级别
|
||||||
|
human_description:
|
||||||
|
en_US: "Choose how the transcription should be divided into chunks"
|
||||||
|
zh_Hans: "选择如何将转录内容分成块"
|
||||||
|
llm_description: "Level of the chunks to return."
|
||||||
|
form: form
|
||||||
|
default: segment
|
||||||
|
options:
|
||||||
|
- value: segment
|
||||||
|
label:
|
||||||
|
en_US: Segment
|
||||||
|
zh_Hans: 段
|
||||||
|
- name: version
|
||||||
|
type: select
|
||||||
|
label:
|
||||||
|
en_US: Version
|
||||||
|
zh_Hans: 版本
|
||||||
|
human_description:
|
||||||
|
en_US: "Select which version of the Whisper large model to use"
|
||||||
|
zh_Hans: "选择要使用的 Whisper large 模型版本"
|
||||||
|
llm_description: "Version of the model to use. All of the models are the Whisper large variant."
|
||||||
|
form: form
|
||||||
|
default: "3"
|
||||||
|
options:
|
||||||
|
- value: "3"
|
||||||
|
label:
|
||||||
|
en_US: Version 3
|
||||||
|
zh_Hans: 版本 3
|
|
@ -118,11 +118,11 @@ class FileSegment(Segment):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def log(self) -> str:
|
def log(self) -> str:
|
||||||
return str(self.value)
|
return ""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self) -> str:
|
def text(self) -> str:
|
||||||
return str(self.value)
|
return ""
|
||||||
|
|
||||||
|
|
||||||
class ArrayAnySegment(ArraySegment):
|
class ArrayAnySegment(ArraySegment):
|
||||||
|
@ -155,3 +155,11 @@ class ArrayFileSegment(ArraySegment):
|
||||||
for item in self.value:
|
for item in self.value:
|
||||||
items.append(item.markdown)
|
items.append(item.markdown)
|
||||||
return "\n".join(items)
|
return "\n".join(items)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def log(self) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self) -> str:
|
||||||
|
return ""
|
||||||
|
|
|
@ -39,7 +39,14 @@ class VisionConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class PromptConfig(BaseModel):
|
class PromptConfig(BaseModel):
|
||||||
jinja2_variables: Optional[list[VariableSelector]] = None
|
jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list)
|
||||||
|
|
||||||
|
@field_validator("jinja2_variables", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def convert_none_jinja2_variables(cls, v: Any):
|
||||||
|
if v is None:
|
||||||
|
return []
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
class LLMNodeChatModelMessage(ChatModelMessage):
|
class LLMNodeChatModelMessage(ChatModelMessage):
|
||||||
|
@ -53,7 +60,14 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||||
class LLMNodeData(BaseNodeData):
|
class LLMNodeData(BaseNodeData):
|
||||||
model: ModelConfig
|
model: ModelConfig
|
||||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||||
prompt_config: Optional[PromptConfig] = None
|
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
|
||||||
memory: Optional[MemoryConfig] = None
|
memory: Optional[MemoryConfig] = None
|
||||||
context: ContextConfig
|
context: ContextConfig
|
||||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||||
|
|
||||||
|
@field_validator("prompt_config", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def convert_none_prompt_config(cls, v: Any):
|
||||||
|
if v is None:
|
||||||
|
return PromptConfig()
|
||||||
|
return v
|
||||||
|
|
|
@ -24,3 +24,11 @@ class LLMModeRequiredError(LLMNodeError):
|
||||||
|
|
||||||
class NoPromptFoundError(LLMNodeError):
|
class NoPromptFoundError(LLMNodeError):
|
||||||
"""Raised when no prompt is found in the LLM configuration."""
|
"""Raised when no prompt is found in the LLM configuration."""
|
||||||
|
|
||||||
|
|
||||||
|
class NotSupportedPromptTypeError(LLMNodeError):
|
||||||
|
"""Raised when the prompt type is not supported."""
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRolePrefixRequiredError(LLMNodeError):
|
||||||
|
"""Raised when memory role prefix is required for completion model."""
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||||
|
|
||||||
|
@ -6,21 +7,26 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
|
||||||
from core.entities.model_entities import ModelStatus
|
from core.entities.model_entities import ModelStatus
|
||||||
from core.entities.provider_entities import QuotaUnit
|
from core.entities.provider_entities import QuotaUnit
|
||||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||||
|
from core.file import FileType, file_manager
|
||||||
|
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance, ModelManager
|
from core.model_manager import ModelInstance, ModelManager
|
||||||
from core.model_runtime.entities import (
|
from core.model_runtime.entities import (
|
||||||
AudioPromptMessageContent,
|
|
||||||
ImagePromptMessageContent,
|
ImagePromptMessageContent,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
PromptMessageContentType,
|
PromptMessageContentType,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
VideoPromptMessageContent,
|
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
PromptMessageRole,
|
||||||
|
SystemPromptMessage,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
|
||||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||||
from core.variables import (
|
from core.variables import (
|
||||||
|
@ -32,8 +38,9 @@ from core.variables import (
|
||||||
ObjectSegment,
|
ObjectSegment,
|
||||||
StringSegment,
|
StringSegment,
|
||||||
)
|
)
|
||||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||||
|
from core.workflow.entities.variable_entities import VariableSelector
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||||
from core.workflow.nodes.base import BaseNode
|
from core.workflow.nodes.base import BaseNode
|
||||||
|
@ -62,14 +69,18 @@ from .exc import (
|
||||||
InvalidVariableTypeError,
|
InvalidVariableTypeError,
|
||||||
LLMModeRequiredError,
|
LLMModeRequiredError,
|
||||||
LLMNodeError,
|
LLMNodeError,
|
||||||
|
MemoryRolePrefixRequiredError,
|
||||||
ModelNotExistError,
|
ModelNotExistError,
|
||||||
NoPromptFoundError,
|
NoPromptFoundError,
|
||||||
|
NotSupportedPromptTypeError,
|
||||||
VariableNotFoundError,
|
VariableNotFoundError,
|
||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LLMNode(BaseNode[LLMNodeData]):
|
class LLMNode(BaseNode[LLMNodeData]):
|
||||||
_node_data_cls = LLMNodeData
|
_node_data_cls = LLMNodeData
|
||||||
|
@ -123,17 +134,13 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||||
|
|
||||||
# fetch prompt messages
|
# fetch prompt messages
|
||||||
if self.node_data.memory:
|
if self.node_data.memory:
|
||||||
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
|
query = self.node_data.memory.query_prompt_template
|
||||||
if not query:
|
|
||||||
raise VariableNotFoundError("Query not found")
|
|
||||||
query = query.text
|
|
||||||
else:
|
else:
|
||||||
query = None
|
query = None
|
||||||
|
|
||||||
prompt_messages, stop = self._fetch_prompt_messages(
|
prompt_messages, stop = self._fetch_prompt_messages(
|
||||||
system_query=query,
|
user_query=query,
|
||||||
inputs=inputs,
|
user_files=files,
|
||||||
files=files,
|
|
||||||
context=context,
|
context=context,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
|
@ -141,6 +148,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||||
memory_config=self.node_data.memory,
|
memory_config=self.node_data.memory,
|
||||||
vision_enabled=self.node_data.vision.enabled,
|
vision_enabled=self.node_data.vision.enabled,
|
||||||
vision_detail=self.node_data.vision.configs.detail,
|
vision_detail=self.node_data.vision.configs.detail,
|
||||||
|
variable_pool=self.graph_runtime_state.variable_pool,
|
||||||
|
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||||
)
|
)
|
||||||
|
|
||||||
process_data = {
|
process_data = {
|
||||||
|
@ -181,6 +190,17 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Node {self.node_id} failed to run: {e}")
|
||||||
|
yield RunCompletedEvent(
|
||||||
|
run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
error=str(e),
|
||||||
|
inputs=node_inputs,
|
||||||
|
process_data=process_data,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
|
||||||
|
|
||||||
|
@ -203,8 +223,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||||
self,
|
self,
|
||||||
node_data_model: ModelConfig,
|
node_data_model: ModelConfig,
|
||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: Sequence[PromptMessage],
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[Sequence[str]] = None,
|
||||||
) -> Generator[NodeEvent, None, None]:
|
) -> Generator[NodeEvent, None, None]:
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
|
@ -519,9 +539,8 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||||
def _fetch_prompt_messages(
|
def _fetch_prompt_messages(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
system_query: str | None = None,
|
user_query: str | None = None,
|
||||||
inputs: dict[str, str] | None = None,
|
user_files: Sequence["File"],
|
||||||
files: Sequence["File"],
|
|
||||||
context: str | None = None,
|
context: str | None = None,
|
||||||
memory: TokenBufferMemory | None = None,
|
memory: TokenBufferMemory | None = None,
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
|
@ -529,58 +548,146 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||||
memory_config: MemoryConfig | None = None,
|
memory_config: MemoryConfig | None = None,
|
||||||
vision_enabled: bool = False,
|
vision_enabled: bool = False,
|
||||||
vision_detail: ImagePromptMessageContent.DETAIL,
|
vision_detail: ImagePromptMessageContent.DETAIL,
|
||||||
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
variable_pool: VariablePool,
|
||||||
inputs = inputs or {}
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
|
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
|
||||||
|
prompt_messages = []
|
||||||
|
|
||||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
if isinstance(prompt_template, list):
|
||||||
prompt_messages = prompt_transform.get_prompt(
|
# For chat model
|
||||||
prompt_template=prompt_template,
|
prompt_messages.extend(
|
||||||
inputs=inputs,
|
_handle_list_messages(
|
||||||
query=system_query or "",
|
messages=prompt_template,
|
||||||
files=files,
|
context=context,
|
||||||
context=context,
|
jinja2_variables=jinja2_variables,
|
||||||
memory_config=memory_config,
|
variable_pool=variable_pool,
|
||||||
memory=memory,
|
vision_detail_config=vision_detail,
|
||||||
model_config=model_config,
|
)
|
||||||
)
|
)
|
||||||
stop = model_config.stop
|
|
||||||
|
# Get memory messages for chat mode
|
||||||
|
memory_messages = _handle_memory_chat_mode(
|
||||||
|
memory=memory,
|
||||||
|
memory_config=memory_config,
|
||||||
|
model_config=model_config,
|
||||||
|
)
|
||||||
|
# Extend prompt_messages with memory messages
|
||||||
|
prompt_messages.extend(memory_messages)
|
||||||
|
|
||||||
|
# Add current query to the prompt messages
|
||||||
|
if user_query:
|
||||||
|
message = LLMNodeChatModelMessage(
|
||||||
|
text=user_query,
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
)
|
||||||
|
prompt_messages.extend(
|
||||||
|
_handle_list_messages(
|
||||||
|
messages=[message],
|
||||||
|
context="",
|
||||||
|
jinja2_variables=[],
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
vision_detail_config=vision_detail,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
|
||||||
|
# For completion model
|
||||||
|
prompt_messages.extend(
|
||||||
|
_handle_completion_template(
|
||||||
|
template=prompt_template,
|
||||||
|
context=context,
|
||||||
|
jinja2_variables=jinja2_variables,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get memory text for completion model
|
||||||
|
memory_text = _handle_memory_completion_mode(
|
||||||
|
memory=memory,
|
||||||
|
memory_config=memory_config,
|
||||||
|
model_config=model_config,
|
||||||
|
)
|
||||||
|
# Insert histories into the prompt
|
||||||
|
prompt_content = prompt_messages[0].content
|
||||||
|
if "#histories#" in prompt_content:
|
||||||
|
prompt_content = prompt_content.replace("#histories#", memory_text)
|
||||||
|
else:
|
||||||
|
prompt_content = memory_text + "\n" + prompt_content
|
||||||
|
prompt_messages[0].content = prompt_content
|
||||||
|
|
||||||
|
# Add current query to the prompt message
|
||||||
|
if user_query:
|
||||||
|
prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
|
||||||
|
prompt_messages[0].content = prompt_content
|
||||||
|
else:
|
||||||
|
errmsg = f"Prompt type {type(prompt_template)} is not supported"
|
||||||
|
logger.warning(errmsg)
|
||||||
|
raise NotSupportedPromptTypeError(errmsg)
|
||||||
|
|
||||||
|
if vision_enabled and user_files:
|
||||||
|
file_prompts = []
|
||||||
|
for file in user_files:
|
||||||
|
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
|
||||||
|
file_prompts.append(file_prompt)
|
||||||
|
if (
|
||||||
|
len(prompt_messages) > 0
|
||||||
|
and isinstance(prompt_messages[-1], UserPromptMessage)
|
||||||
|
and isinstance(prompt_messages[-1].content, list)
|
||||||
|
):
|
||||||
|
prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
|
||||||
|
else:
|
||||||
|
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||||
|
|
||||||
|
# Filter prompt messages
|
||||||
filtered_prompt_messages = []
|
filtered_prompt_messages = []
|
||||||
for prompt_message in prompt_messages:
|
for prompt_message in prompt_messages:
|
||||||
if prompt_message.is_empty():
|
if isinstance(prompt_message.content, list):
|
||||||
continue
|
|
||||||
|
|
||||||
if not isinstance(prompt_message.content, str):
|
|
||||||
prompt_message_content = []
|
prompt_message_content = []
|
||||||
for content_item in prompt_message.content or []:
|
for content_item in prompt_message.content:
|
||||||
# Skip image if vision is disabled
|
# Skip content if features are not defined
|
||||||
if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE:
|
if not model_config.model_schema.features:
|
||||||
|
if content_item.type != PromptMessageContentType.TEXT:
|
||||||
|
continue
|
||||||
|
prompt_message_content.append(content_item)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(content_item, ImagePromptMessageContent):
|
# Skip content if corresponding feature is not supported
|
||||||
# Override vision config if LLM node has vision config,
|
if (
|
||||||
# cuz vision detail is related to the configuration from FileUpload feature.
|
(
|
||||||
content_item.detail = vision_detail
|
content_item.type == PromptMessageContentType.IMAGE
|
||||||
prompt_message_content.append(content_item)
|
and ModelFeature.VISION not in model_config.model_schema.features
|
||||||
elif isinstance(
|
)
|
||||||
content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
|
or (
|
||||||
|
content_item.type == PromptMessageContentType.DOCUMENT
|
||||||
|
and ModelFeature.DOCUMENT not in model_config.model_schema.features
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
content_item.type == PromptMessageContentType.VIDEO
|
||||||
|
and ModelFeature.VIDEO not in model_config.model_schema.features
|
||||||
|
)
|
||||||
|
or (
|
||||||
|
content_item.type == PromptMessageContentType.AUDIO
|
||||||
|
and ModelFeature.AUDIO not in model_config.model_schema.features
|
||||||
|
)
|
||||||
):
|
):
|
||||||
prompt_message_content.append(content_item)
|
continue
|
||||||
|
prompt_message_content.append(content_item)
|
||||||
if len(prompt_message_content) > 1:
|
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
|
||||||
prompt_message.content = prompt_message_content
|
|
||||||
elif (
|
|
||||||
len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT
|
|
||||||
):
|
|
||||||
prompt_message.content = prompt_message_content[0].data
|
prompt_message.content = prompt_message_content[0].data
|
||||||
|
else:
|
||||||
|
prompt_message.content = prompt_message_content
|
||||||
|
if prompt_message.is_empty():
|
||||||
|
continue
|
||||||
filtered_prompt_messages.append(prompt_message)
|
filtered_prompt_messages.append(prompt_message)
|
||||||
|
|
||||||
if not filtered_prompt_messages:
|
if len(filtered_prompt_messages) == 0:
|
||||||
raise NoPromptFoundError(
|
raise NoPromptFoundError(
|
||||||
"No prompt found in the LLM configuration. "
|
"No prompt found in the LLM configuration. "
|
||||||
"Please ensure a prompt is properly configured before proceeding."
|
"Please ensure a prompt is properly configured before proceeding."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
stop = model_config.stop
|
||||||
return filtered_prompt_messages, stop
|
return filtered_prompt_messages, stop
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -715,3 +822,198 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
|
||||||
|
match role:
|
||||||
|
case PromptMessageRole.USER:
|
||||||
|
return UserPromptMessage(content=[TextPromptMessageContent(data=text)])
|
||||||
|
case PromptMessageRole.ASSISTANT:
|
||||||
|
return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)])
|
||||||
|
case PromptMessageRole.SYSTEM:
|
||||||
|
return SystemPromptMessage(content=[TextPromptMessageContent(data=text)])
|
||||||
|
raise NotImplementedError(f"Role {role} is not supported")
|
||||||
|
|
||||||
|
|
||||||
|
def _render_jinja2_message(
|
||||||
|
*,
|
||||||
|
template: str,
|
||||||
|
jinjia2_variables: Sequence[VariableSelector],
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
):
|
||||||
|
if not template:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
jinjia2_inputs = {}
|
||||||
|
for jinja2_variable in jinjia2_variables:
|
||||||
|
variable = variable_pool.get(jinja2_variable.value_selector)
|
||||||
|
jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
|
||||||
|
code_execute_resp = CodeExecutor.execute_workflow_code_template(
|
||||||
|
language=CodeLanguage.JINJA2,
|
||||||
|
code=template,
|
||||||
|
inputs=jinjia2_inputs,
|
||||||
|
)
|
||||||
|
result_text = code_execute_resp["result"]
|
||||||
|
return result_text
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_list_messages(
|
||||||
|
*,
|
||||||
|
messages: Sequence[LLMNodeChatModelMessage],
|
||||||
|
context: Optional[str],
|
||||||
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
vision_detail_config: ImagePromptMessageContent.DETAIL,
|
||||||
|
) -> Sequence[PromptMessage]:
|
||||||
|
prompt_messages = []
|
||||||
|
for message in messages:
|
||||||
|
if message.edition_type == "jinja2":
|
||||||
|
result_text = _render_jinja2_message(
|
||||||
|
template=message.jinja2_text or "",
|
||||||
|
jinjia2_variables=jinja2_variables,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
|
||||||
|
prompt_messages.append(prompt_message)
|
||||||
|
else:
|
||||||
|
# Get segment group from basic message
|
||||||
|
if context:
|
||||||
|
template = message.text.replace("{#context#}", context)
|
||||||
|
else:
|
||||||
|
template = message.text
|
||||||
|
segment_group = variable_pool.convert_template(template)
|
||||||
|
|
||||||
|
# Process segments for images
|
||||||
|
file_contents = []
|
||||||
|
for segment in segment_group.value:
|
||||||
|
if isinstance(segment, ArrayFileSegment):
|
||||||
|
for file in segment.value:
|
||||||
|
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
|
||||||
|
file_content = file_manager.to_prompt_message_content(
|
||||||
|
file, image_detail_config=vision_detail_config
|
||||||
|
)
|
||||||
|
file_contents.append(file_content)
|
||||||
|
if isinstance(segment, FileSegment):
|
||||||
|
file = segment.value
|
||||||
|
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO}:
|
||||||
|
file_content = file_manager.to_prompt_message_content(
|
||||||
|
file, image_detail_config=vision_detail_config
|
||||||
|
)
|
||||||
|
file_contents.append(file_content)
|
||||||
|
|
||||||
|
# Create message with text from all segments
|
||||||
|
plain_text = segment_group.text
|
||||||
|
if plain_text:
|
||||||
|
prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
|
||||||
|
prompt_messages.append(prompt_message)
|
||||||
|
|
||||||
|
if file_contents:
|
||||||
|
# Create message with image contents
|
||||||
|
prompt_message = UserPromptMessage(content=file_contents)
|
||||||
|
prompt_messages.append(prompt_message)
|
||||||
|
|
||||||
|
return prompt_messages
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_rest_token(
|
||||||
|
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
|
||||||
|
) -> int:
|
||||||
|
rest_tokens = 2000
|
||||||
|
|
||||||
|
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||||
|
if model_context_tokens:
|
||||||
|
model_instance = ModelInstance(
|
||||||
|
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
|
||||||
|
)
|
||||||
|
|
||||||
|
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||||
|
|
||||||
|
max_tokens = 0
|
||||||
|
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||||
|
if parameter_rule.name == "max_tokens" or (
|
||||||
|
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
||||||
|
):
|
||||||
|
max_tokens = (
|
||||||
|
model_config.parameters.get(parameter_rule.name)
|
||||||
|
or model_config.parameters.get(str(parameter_rule.use_template))
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
|
||||||
|
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
|
||||||
|
rest_tokens = max(rest_tokens, 0)
|
||||||
|
|
||||||
|
return rest_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_memory_chat_mode(
|
||||||
|
*,
|
||||||
|
memory: TokenBufferMemory | None,
|
||||||
|
memory_config: MemoryConfig | None,
|
||||||
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
|
) -> Sequence[PromptMessage]:
|
||||||
|
memory_messages = []
|
||||||
|
# Get messages from memory for chat model
|
||||||
|
if memory and memory_config:
|
||||||
|
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
|
||||||
|
memory_messages = memory.get_history_prompt_messages(
|
||||||
|
max_token_limit=rest_tokens,
|
||||||
|
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
||||||
|
)
|
||||||
|
return memory_messages
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_memory_completion_mode(
|
||||||
|
*,
|
||||||
|
memory: TokenBufferMemory | None,
|
||||||
|
memory_config: MemoryConfig | None,
|
||||||
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
|
) -> str:
|
||||||
|
memory_text = ""
|
||||||
|
# Get history text from memory for completion model
|
||||||
|
if memory and memory_config:
|
||||||
|
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
|
||||||
|
if not memory_config.role_prefix:
|
||||||
|
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
|
||||||
|
memory_text = memory.get_history_prompt_text(
|
||||||
|
max_token_limit=rest_tokens,
|
||||||
|
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
||||||
|
human_prefix=memory_config.role_prefix.user,
|
||||||
|
ai_prefix=memory_config.role_prefix.assistant,
|
||||||
|
)
|
||||||
|
return memory_text
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_completion_template(
|
||||||
|
*,
|
||||||
|
template: LLMNodeCompletionModelPromptTemplate,
|
||||||
|
context: Optional[str],
|
||||||
|
jinja2_variables: Sequence[VariableSelector],
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
) -> Sequence[PromptMessage]:
|
||||||
|
"""Handle completion template processing outside of LLMNode class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template: The completion model prompt template
|
||||||
|
context: Optional context string
|
||||||
|
jinja2_variables: Variables for jinja2 template rendering
|
||||||
|
variable_pool: Variable pool for template conversion
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sequence of prompt messages
|
||||||
|
"""
|
||||||
|
prompt_messages = []
|
||||||
|
if template.edition_type == "jinja2":
|
||||||
|
result_text = _render_jinja2_message(
|
||||||
|
template=template.jinja2_text or "",
|
||||||
|
jinjia2_variables=jinja2_variables,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if context:
|
||||||
|
template_text = template.text.replace("{#context#}", context)
|
||||||
|
else:
|
||||||
|
template_text = template.text
|
||||||
|
result_text = variable_pool.convert_template(template_text).text
|
||||||
|
prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
|
||||||
|
prompt_messages.append(prompt_message)
|
||||||
|
return prompt_messages
|
||||||
|
|
|
@ -86,12 +86,14 @@ class QuestionClassifierNode(LLMNode):
|
||||||
)
|
)
|
||||||
prompt_messages, stop = self._fetch_prompt_messages(
|
prompt_messages, stop = self._fetch_prompt_messages(
|
||||||
prompt_template=prompt_template,
|
prompt_template=prompt_template,
|
||||||
system_query=query,
|
user_query=query,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
files=files,
|
user_files=files,
|
||||||
vision_enabled=node_data.vision.enabled,
|
vision_enabled=node_data.vision.enabled,
|
||||||
vision_detail=node_data.vision.configs.detail,
|
vision_detail=node_data.vision.configs.detail,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
jinja2_variables=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
# handle invoke result
|
# handle invoke result
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from os import path
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
@ -180,7 +179,6 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||||
for response in tool_response:
|
for response in tool_response:
|
||||||
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||||
url = str(response.message) if response.message else None
|
url = str(response.message) if response.message else None
|
||||||
ext = path.splitext(url)[1] if url else ".bin"
|
|
||||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||||
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||||
|
|
||||||
|
@ -202,7 +200,6 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||||
)
|
)
|
||||||
result.append(file)
|
result.append(file)
|
||||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||||
# get tool file id
|
|
||||||
tool_file_id = str(response.message).split("/")[-1].split(".")[0]
|
tool_file_id = str(response.message).split("/")[-1].split(".")[0]
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||||
|
@ -211,7 +208,6 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||||
raise ValueError(f"tool file {tool_file_id} not exists")
|
raise ValueError(f"tool file {tool_file_id} not exists")
|
||||||
mapping = {
|
mapping = {
|
||||||
"tool_file_id": tool_file_id,
|
"tool_file_id": tool_file_id,
|
||||||
"type": FileType.IMAGE,
|
|
||||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||||
}
|
}
|
||||||
file = file_factory.build_from_mapping(
|
file = file_factory.build_from_mapping(
|
||||||
|
@ -228,13 +224,8 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||||
tool_file = session.scalar(stmt)
|
tool_file = session.scalar(stmt)
|
||||||
if tool_file is None:
|
if tool_file is None:
|
||||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
||||||
if "." in url:
|
|
||||||
extension = "." + url.split("/")[-1].split(".")[1]
|
|
||||||
else:
|
|
||||||
extension = ".bin"
|
|
||||||
mapping = {
|
mapping = {
|
||||||
"tool_file_id": tool_file_id,
|
"tool_file_id": tool_file_id,
|
||||||
"type": FileType.IMAGE,
|
|
||||||
"transfer_method": transfer_method,
|
"transfer_method": transfer_method,
|
||||||
"url": url,
|
"url": url,
|
||||||
}
|
}
|
||||||
|
|
|
@ -180,6 +180,20 @@ def _get_remote_file_info(url: str):
|
||||||
return mime_type, filename, file_size
|
return mime_type, filename, file_size
|
||||||
|
|
||||||
|
|
||||||
|
def _get_file_type_by_mimetype(mime_type: str) -> FileType:
|
||||||
|
if "image" in mime_type:
|
||||||
|
file_type = FileType.IMAGE
|
||||||
|
elif "video" in mime_type:
|
||||||
|
file_type = FileType.VIDEO
|
||||||
|
elif "audio" in mime_type:
|
||||||
|
file_type = FileType.AUDIO
|
||||||
|
elif "text" in mime_type or "pdf" in mime_type:
|
||||||
|
file_type = FileType.DOCUMENT
|
||||||
|
else:
|
||||||
|
file_type = FileType.CUSTOM
|
||||||
|
return file_type
|
||||||
|
|
||||||
|
|
||||||
def _build_from_tool_file(
|
def _build_from_tool_file(
|
||||||
*,
|
*,
|
||||||
mapping: Mapping[str, Any],
|
mapping: Mapping[str, Any],
|
||||||
|
@ -199,12 +213,13 @@ def _build_from_tool_file(
|
||||||
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
|
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
|
||||||
|
|
||||||
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
||||||
|
file_type = mapping.get("type", _get_file_type_by_mimetype(tool_file.mimetype))
|
||||||
|
|
||||||
return File(
|
return File(
|
||||||
id=mapping.get("id"),
|
id=mapping.get("id"),
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
filename=tool_file.name,
|
filename=tool_file.name,
|
||||||
type=FileType.value_of(mapping.get("type")),
|
type=file_type,
|
||||||
transfer_method=transfer_method,
|
transfer_method=transfer_method,
|
||||||
remote_url=tool_file.original_url,
|
remote_url=tool_file.original_url,
|
||||||
related_id=tool_file.id,
|
related_id=tool_file.id,
|
||||||
|
|
86
api/poetry.lock
generated
86
api/poetry.lock
generated
|
@ -2411,6 +2411,41 @@ files = [
|
||||||
[package.extras]
|
[package.extras]
|
||||||
test = ["pytest (>=6)"]
|
test = ["pytest (>=6)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "faker"
|
||||||
|
version = "32.1.0"
|
||||||
|
description = "Faker is a Python package that generates fake data for you."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "Faker-32.1.0-py3-none-any.whl", hash = "sha256:c77522577863c264bdc9dad3a2a750ad3f7ee43ff8185072e482992288898814"},
|
||||||
|
{file = "faker-32.1.0.tar.gz", hash = "sha256:aac536ba04e6b7beb2332c67df78485fc29c1880ff723beac6d1efd45e2f10f5"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
python-dateutil = ">=2.4"
|
||||||
|
typing-extensions = "*"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fal-client"
|
||||||
|
version = "0.5.6"
|
||||||
|
description = "Python client for fal.ai"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "fal_client-0.5.6-py3-none-any.whl", hash = "sha256:631fd857a3c44753ee46a2eea1e7276471453aca58faac9c3702f744c7c84050"},
|
||||||
|
{file = "fal_client-0.5.6.tar.gz", hash = "sha256:d3afc4b6250023d0ee8437ec504558231d3b106d7aabc12cda8c39883faddecb"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
httpx = ">=0.21.0,<1"
|
||||||
|
httpx-sse = ">=0.4.0,<0.5"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
dev = ["fal-client[docs,test]"]
|
||||||
|
docs = ["sphinx", "sphinx-autodoc-typehints", "sphinx-rtd-theme"]
|
||||||
|
test = ["pillow", "pytest", "pytest-asyncio"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastapi"
|
name = "fastapi"
|
||||||
version = "0.115.4"
|
version = "0.115.4"
|
||||||
|
@ -4049,6 +4084,17 @@ http2 = ["h2 (>=3,<5)"]
|
||||||
socks = ["socksio (==1.*)"]
|
socks = ["socksio (==1.*)"]
|
||||||
zstd = ["zstandard (>=0.18.0)"]
|
zstd = ["zstandard (>=0.18.0)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "httpx-sse"
|
||||||
|
version = "0.4.0"
|
||||||
|
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"},
|
||||||
|
{file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "huggingface-hub"
|
name = "huggingface-hub"
|
||||||
version = "0.16.4"
|
version = "0.16.4"
|
||||||
|
@ -8466,29 +8512,29 @@ pyasn1 = ">=0.1.3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruff"
|
name = "ruff"
|
||||||
version = "0.6.9"
|
version = "0.7.3"
|
||||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "ruff-0.6.9-py3-none-linux_armv6l.whl", hash = "sha256:064df58d84ccc0ac0fcd63bc3090b251d90e2a372558c0f057c3f75ed73e1ccd"},
|
{file = "ruff-0.7.3-py3-none-linux_armv6l.whl", hash = "sha256:34f2339dc22687ec7e7002792d1f50712bf84a13d5152e75712ac08be565d344"},
|
||||||
{file = "ruff-0.6.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:140d4b5c9f5fc7a7b074908a78ab8d384dd7f6510402267bc76c37195c02a7ec"},
|
{file = "ruff-0.7.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:fb397332a1879b9764a3455a0bb1087bda876c2db8aca3a3cbb67b3dbce8cda0"},
|
||||||
{file = "ruff-0.6.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:53fd8ca5e82bdee8da7f506d7b03a261f24cd43d090ea9db9a1dc59d9313914c"},
|
{file = "ruff-0.7.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:37d0b619546103274e7f62643d14e1adcbccb242efda4e4bdb9544d7764782e9"},
|
||||||
{file = "ruff-0.6.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:645d7d8761f915e48a00d4ecc3686969761df69fb561dd914a773c1a8266e14e"},
|
{file = "ruff-0.7.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d59f0c3ee4d1a6787614e7135b72e21024875266101142a09a61439cb6e38a5"},
|
||||||
{file = "ruff-0.6.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eae02b700763e3847595b9d2891488989cac00214da7f845f4bcf2989007d577"},
|
{file = "ruff-0.7.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:44eb93c2499a169d49fafd07bc62ac89b1bc800b197e50ff4633aed212569299"},
|
||||||
{file = "ruff-0.6.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d5ccc9e58112441de8ad4b29dcb7a86dc25c5f770e3c06a9d57e0e5eba48829"},
|
{file = "ruff-0.7.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6d0242ce53f3a576c35ee32d907475a8d569944c0407f91d207c8af5be5dae4e"},
|
||||||
{file = "ruff-0.6.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:417b81aa1c9b60b2f8edc463c58363075412866ae4e2b9ab0f690dc1e87ac1b5"},
|
{file = "ruff-0.7.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:6b6224af8b5e09772c2ecb8dc9f3f344c1aa48201c7f07e7315367f6dd90ac29"},
|
||||||
{file = "ruff-0.6.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c866b631f5fbce896a74a6e4383407ba7507b815ccc52bcedabb6810fdb3ef7"},
|
{file = "ruff-0.7.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c50f95a82b94421c964fae4c27c0242890a20fe67d203d127e84fbb8013855f5"},
|
||||||
{file = "ruff-0.6.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7b118afbb3202f5911486ad52da86d1d52305b59e7ef2031cea3425142b97d6f"},
|
{file = "ruff-0.7.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7f3eff9961b5d2644bcf1616c606e93baa2d6b349e8aa8b035f654df252c8c67"},
|
||||||
{file = "ruff-0.6.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a67267654edc23c97335586774790cde402fb6bbdb3c2314f1fc087dee320bfa"},
|
{file = "ruff-0.7.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8963cab06d130c4df2fd52c84e9f10d297826d2e8169ae0c798b6221be1d1d2"},
|
||||||
{file = "ruff-0.6.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3ef0cc774b00fec123f635ce5c547dac263f6ee9fb9cc83437c5904183b55ceb"},
|
{file = "ruff-0.7.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:61b46049d6edc0e4317fb14b33bd693245281a3007288b68a3f5b74a22a0746d"},
|
||||||
{file = "ruff-0.6.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:12edd2af0c60fa61ff31cefb90aef4288ac4d372b4962c2864aeea3a1a2460c0"},
|
{file = "ruff-0.7.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:10ebce7696afe4644e8c1a23b3cf8c0f2193a310c18387c06e583ae9ef284de2"},
|
||||||
{file = "ruff-0.6.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:55bb01caeaf3a60b2b2bba07308a02fca6ab56233302406ed5245180a05c5625"},
|
{file = "ruff-0.7.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3f36d56326b3aef8eeee150b700e519880d1aab92f471eefdef656fd57492aa2"},
|
||||||
{file = "ruff-0.6.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:925d26471fa24b0ce5a6cdfab1bb526fb4159952385f386bdcc643813d472039"},
|
{file = "ruff-0.7.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5d024301109a0007b78d57ab0ba190087b43dce852e552734ebf0b0b85e4fb16"},
|
||||||
{file = "ruff-0.6.9-py3-none-win32.whl", hash = "sha256:eb61ec9bdb2506cffd492e05ac40e5bc6284873aceb605503d8494180d6fc84d"},
|
{file = "ruff-0.7.3-py3-none-win32.whl", hash = "sha256:4ba81a5f0c5478aa61674c5a2194de8b02652f17addf8dfc40c8937e6e7d79fc"},
|
||||||
{file = "ruff-0.6.9-py3-none-win_amd64.whl", hash = "sha256:785d31851c1ae91f45b3d8fe23b8ae4b5170089021fbb42402d811135f0b7117"},
|
{file = "ruff-0.7.3-py3-none-win_amd64.whl", hash = "sha256:588a9ff2fecf01025ed065fe28809cd5a53b43505f48b69a1ac7707b1b7e4088"},
|
||||||
{file = "ruff-0.6.9-py3-none-win_arm64.whl", hash = "sha256:a9641e31476d601f83cd602608739a0840e348bda93fec9f1ee816f8b6798b93"},
|
{file = "ruff-0.7.3-py3-none-win_arm64.whl", hash = "sha256:1713e2c5545863cdbfe2cbce21f69ffaf37b813bfd1fb3b90dc9a6f1963f5a8c"},
|
||||||
{file = "ruff-0.6.9.tar.gz", hash = "sha256:b076ef717a8e5bc819514ee1d602bbdca5b4420ae13a9cf61a0c0a4f53a2baa2"},
|
{file = "ruff-0.7.3.tar.gz", hash = "sha256:e1d1ba2e40b6e71a61b063354d04be669ab0d39c352461f3d789cac68b54a313"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -11005,4 +11051,4 @@ cffi = ["cffi (>=1.11)"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.10,<3.13"
|
python-versions = ">=3.10,<3.13"
|
||||||
content-hash = "f20bd678044926913dbbc24bd0cf22503a75817aa55f59457ff7822032139b77"
|
content-hash = "cf4e0467f622e58b51411ee1d784928962f52dbf877b8ee013c810909a1f07db"
|
||||||
|
|
|
@ -122,6 +122,7 @@ celery = "~5.4.0"
|
||||||
chardet = "~5.1.0"
|
chardet = "~5.1.0"
|
||||||
cohere = "~5.2.4"
|
cohere = "~5.2.4"
|
||||||
dashscope = { version = "~1.17.0", extras = ["tokenizer"] }
|
dashscope = { version = "~1.17.0", extras = ["tokenizer"] }
|
||||||
|
fal-client = "0.5.6"
|
||||||
flask = "~3.0.1"
|
flask = "~3.0.1"
|
||||||
flask-compress = "~1.14"
|
flask-compress = "~1.14"
|
||||||
flask-cors = "~4.0.0"
|
flask-cors = "~4.0.0"
|
||||||
|
@ -265,6 +266,7 @@ weaviate-client = "~3.21.0"
|
||||||
optional = true
|
optional = true
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
coverage = "~7.2.4"
|
coverage = "~7.2.4"
|
||||||
|
faker = "~32.1.0"
|
||||||
pytest = "~8.3.2"
|
pytest = "~8.3.2"
|
||||||
pytest-benchmark = "~4.0.0"
|
pytest-benchmark = "~4.0.0"
|
||||||
pytest-env = "~1.1.3"
|
pytest-env = "~1.1.3"
|
||||||
|
@ -278,4 +280,4 @@ pytest-mock = "~3.14.0"
|
||||||
optional = true
|
optional = true
|
||||||
[tool.poetry.group.lint.dependencies]
|
[tool.poetry.group.lint.dependencies]
|
||||||
dotenv-linter = "~0.5.0"
|
dotenv-linter = "~0.5.0"
|
||||||
ruff = "~0.6.9"
|
ruff = "~0.7.3"
|
||||||
|
|
|
@ -11,7 +11,6 @@ from core.model_runtime.entities.message_entities import (
|
||||||
)
|
)
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel
|
from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel
|
||||||
from tests.integration_tests.model_runtime.__mock.azure_ai_studio import setup_azure_ai_studio_mock
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)
|
@pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True)
|
||||||
|
|
|
@ -4,29 +4,21 @@ import pytest
|
||||||
|
|
||||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureAIStudioRerankModel
|
from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureRerankModel
|
||||||
|
|
||||||
|
|
||||||
def test_validate_credentials():
|
def test_validate_credentials():
|
||||||
model = AzureAIStudioRerankModel()
|
model = AzureRerankModel()
|
||||||
|
|
||||||
with pytest.raises(CredentialsValidateFailedError):
|
with pytest.raises(CredentialsValidateFailedError):
|
||||||
model.validate_credentials(
|
model.validate_credentials(
|
||||||
model="azure-ai-studio-rerank-v1",
|
model="azure-ai-studio-rerank-v1",
|
||||||
credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")},
|
credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")},
|
||||||
query="What is the capital of the United States?",
|
|
||||||
docs=[
|
|
||||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
|
||||||
"Census, Carson City had a population of 55,274.",
|
|
||||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
|
||||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
|
||||||
],
|
|
||||||
score_threshold=0.8,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_invoke_model():
|
def test_invoke_model():
|
||||||
model = AzureAIStudioRerankModel()
|
model = AzureRerankModel()
|
||||||
|
|
||||||
result = model.invoke(
|
result = model.invoke(
|
||||||
model="azure-ai-studio-rerank-v1",
|
model="azure-ai-studio-rerank-v1",
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
from collections import UserDict
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -11,7 +12,7 @@ from pymochow.model.table import Table
|
||||||
from requests.adapters import HTTPAdapter
|
from requests.adapters import HTTPAdapter
|
||||||
|
|
||||||
|
|
||||||
class AttrDict(dict):
|
class AttrDict(UserDict):
|
||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
return self.get(item)
|
return self.get(item)
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
from collections import UserDict
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -50,7 +51,7 @@ class MockIndex:
|
||||||
return AttrDict({"dimension": 1024})
|
return AttrDict({"dimension": 1024})
|
||||||
|
|
||||||
|
|
||||||
class AttrDict(dict):
|
class AttrDict(UserDict):
|
||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
return self.get(item)
|
return self.get(item)
|
||||||
|
|
||||||
|
|
|
@ -1,125 +1,484 @@
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from configs import dify_config
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
||||||
|
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||||
|
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
|
||||||
from core.file import File, FileTransferMethod, FileType
|
from core.file import File, FileTransferMethod, FileType
|
||||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
ImagePromptMessageContent,
|
||||||
|
PromptMessage,
|
||||||
|
PromptMessageRole,
|
||||||
|
SystemPromptMessage,
|
||||||
|
TextPromptMessageContent,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
|
||||||
|
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||||
|
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||||
|
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||||
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
|
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||||
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
||||||
from core.workflow.nodes.end import EndStreamParam
|
from core.workflow.nodes.end import EndStreamParam
|
||||||
from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions
|
from core.workflow.nodes.llm.entities import (
|
||||||
|
ContextConfig,
|
||||||
|
LLMNodeChatModelMessage,
|
||||||
|
LLMNodeData,
|
||||||
|
ModelConfig,
|
||||||
|
VisionConfig,
|
||||||
|
VisionConfigOptions,
|
||||||
|
)
|
||||||
from core.workflow.nodes.llm.node import LLMNode
|
from core.workflow.nodes.llm.node import LLMNode
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
|
from models.provider import ProviderType
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
|
from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario
|
||||||
|
|
||||||
|
|
||||||
class TestLLMNode:
|
class MockTokenBufferMemory:
|
||||||
@pytest.fixture
|
def __init__(self, history_messages=None):
|
||||||
def llm_node(self):
|
self.history_messages = history_messages or []
|
||||||
data = LLMNodeData(
|
|
||||||
title="Test LLM",
|
|
||||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
|
||||||
prompt_template=[],
|
|
||||||
memory=None,
|
|
||||||
context=ContextConfig(enabled=False),
|
|
||||||
vision=VisionConfig(
|
|
||||||
enabled=True,
|
|
||||||
configs=VisionConfigOptions(
|
|
||||||
variable_selector=["sys", "files"],
|
|
||||||
detail=ImagePromptMessageContent.DETAIL.HIGH,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
variable_pool = VariablePool(
|
|
||||||
system_variables={},
|
|
||||||
user_inputs={},
|
|
||||||
)
|
|
||||||
node = LLMNode(
|
|
||||||
id="1",
|
|
||||||
config={
|
|
||||||
"id": "1",
|
|
||||||
"data": data.model_dump(),
|
|
||||||
},
|
|
||||||
graph_init_params=GraphInitParams(
|
|
||||||
tenant_id="1",
|
|
||||||
app_id="1",
|
|
||||||
workflow_type=WorkflowType.WORKFLOW,
|
|
||||||
workflow_id="1",
|
|
||||||
graph_config={},
|
|
||||||
user_id="1",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.SERVICE_API,
|
|
||||||
call_depth=0,
|
|
||||||
),
|
|
||||||
graph=Graph(
|
|
||||||
root_node_id="1",
|
|
||||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
|
||||||
answer_dependencies={},
|
|
||||||
answer_generate_route={},
|
|
||||||
),
|
|
||||||
end_stream_param=EndStreamParam(
|
|
||||||
end_dependencies={},
|
|
||||||
end_stream_variable_selector_mapping={},
|
|
||||||
),
|
|
||||||
),
|
|
||||||
graph_runtime_state=GraphRuntimeState(
|
|
||||||
variable_pool=variable_pool,
|
|
||||||
start_at=0,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return node
|
|
||||||
|
|
||||||
def test_fetch_files_with_file_segment(self, llm_node):
|
def get_history_prompt_messages(
|
||||||
file = File(
|
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
|
||||||
|
) -> Sequence[PromptMessage]:
|
||||||
|
if message_limit is not None:
|
||||||
|
return self.history_messages[-message_limit * 2 :]
|
||||||
|
return self.history_messages
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def llm_node():
|
||||||
|
data = LLMNodeData(
|
||||||
|
title="Test LLM",
|
||||||
|
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
||||||
|
prompt_template=[],
|
||||||
|
memory=None,
|
||||||
|
context=ContextConfig(enabled=False),
|
||||||
|
vision=VisionConfig(
|
||||||
|
enabled=True,
|
||||||
|
configs=VisionConfigOptions(
|
||||||
|
variable_selector=["sys", "files"],
|
||||||
|
detail=ImagePromptMessageContent.DETAIL.HIGH,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
variable_pool = VariablePool(
|
||||||
|
system_variables={},
|
||||||
|
user_inputs={},
|
||||||
|
)
|
||||||
|
node = LLMNode(
|
||||||
|
id="1",
|
||||||
|
config={
|
||||||
|
"id": "1",
|
||||||
|
"data": data.model_dump(),
|
||||||
|
},
|
||||||
|
graph_init_params=GraphInitParams(
|
||||||
|
tenant_id="1",
|
||||||
|
app_id="1",
|
||||||
|
workflow_type=WorkflowType.WORKFLOW,
|
||||||
|
workflow_id="1",
|
||||||
|
graph_config={},
|
||||||
|
user_id="1",
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
|
call_depth=0,
|
||||||
|
),
|
||||||
|
graph=Graph(
|
||||||
|
root_node_id="1",
|
||||||
|
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||||
|
answer_dependencies={},
|
||||||
|
answer_generate_route={},
|
||||||
|
),
|
||||||
|
end_stream_param=EndStreamParam(
|
||||||
|
end_dependencies={},
|
||||||
|
end_stream_variable_selector_mapping={},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
graph_runtime_state=GraphRuntimeState(
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
start_at=0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model_config():
|
||||||
|
# Create actual provider and model type instances
|
||||||
|
model_provider_factory = ModelProviderFactory()
|
||||||
|
provider_instance = model_provider_factory.get_provider_instance("openai")
|
||||||
|
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
|
||||||
|
|
||||||
|
# Create a ProviderModelBundle
|
||||||
|
provider_model_bundle = ProviderModelBundle(
|
||||||
|
configuration=ProviderConfiguration(
|
||||||
|
tenant_id="1",
|
||||||
|
provider=provider_instance.get_provider_schema(),
|
||||||
|
preferred_provider_type=ProviderType.CUSTOM,
|
||||||
|
using_provider_type=ProviderType.CUSTOM,
|
||||||
|
system_configuration=SystemConfiguration(enabled=False),
|
||||||
|
custom_configuration=CustomConfiguration(provider=None),
|
||||||
|
model_settings=[],
|
||||||
|
),
|
||||||
|
provider_instance=provider_instance,
|
||||||
|
model_type_instance=model_type_instance,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create and return a ModelConfigWithCredentialsEntity
|
||||||
|
return ModelConfigWithCredentialsEntity(
|
||||||
|
provider="openai",
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
model_schema=AIModelEntity(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
label=I18nObject(en_US="GPT-3.5 Turbo"),
|
||||||
|
model_type=ModelType.LLM,
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_properties={},
|
||||||
|
),
|
||||||
|
mode="chat",
|
||||||
|
credentials={},
|
||||||
|
parameters={},
|
||||||
|
provider_model_bundle=provider_model_bundle,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_files_with_file_segment(llm_node):
|
||||||
|
file = File(
|
||||||
|
id="1",
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
filename="test.jpg",
|
||||||
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
|
related_id="1",
|
||||||
|
)
|
||||||
|
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
|
||||||
|
|
||||||
|
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||||
|
assert result == [file]
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_files_with_array_file_segment(llm_node):
|
||||||
|
files = [
|
||||||
|
File(
|
||||||
id="1",
|
id="1",
|
||||||
tenant_id="test",
|
tenant_id="test",
|
||||||
type=FileType.IMAGE,
|
type=FileType.IMAGE,
|
||||||
filename="test.jpg",
|
filename="test1.jpg",
|
||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
related_id="1",
|
related_id="1",
|
||||||
|
),
|
||||||
|
File(
|
||||||
|
id="2",
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
filename="test2.jpg",
|
||||||
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
|
related_id="2",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
|
||||||
|
|
||||||
|
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||||
|
assert result == files
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_files_with_none_segment(llm_node):
|
||||||
|
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
|
||||||
|
|
||||||
|
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_files_with_array_any_segment(llm_node):
|
||||||
|
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
|
||||||
|
|
||||||
|
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_files_with_non_existent_variable(llm_node):
|
||||||
|
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
|
||||||
|
prompt_template = []
|
||||||
|
llm_node.node_data.prompt_template = prompt_template
|
||||||
|
|
||||||
|
fake_vision_detail = faker.random_element(
|
||||||
|
[ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
|
||||||
|
)
|
||||||
|
fake_remote_url = faker.url()
|
||||||
|
files = [
|
||||||
|
File(
|
||||||
|
id="1",
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
filename="test1.jpg",
|
||||||
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
|
remote_url=fake_remote_url,
|
||||||
)
|
)
|
||||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
|
]
|
||||||
|
|
||||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
fake_query = faker.sentence()
|
||||||
assert result == [file]
|
|
||||||
|
|
||||||
def test_fetch_files_with_array_file_segment(self, llm_node):
|
prompt_messages, _ = llm_node._fetch_prompt_messages(
|
||||||
files = [
|
user_query=fake_query,
|
||||||
File(
|
user_files=files,
|
||||||
id="1",
|
context=None,
|
||||||
tenant_id="test",
|
memory=None,
|
||||||
type=FileType.IMAGE,
|
model_config=model_config,
|
||||||
filename="test1.jpg",
|
prompt_template=prompt_template,
|
||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
memory_config=None,
|
||||||
related_id="1",
|
vision_enabled=False,
|
||||||
),
|
vision_detail=fake_vision_detail,
|
||||||
File(
|
variable_pool=llm_node.graph_runtime_state.variable_pool,
|
||||||
id="2",
|
jinja2_variables=[],
|
||||||
tenant_id="test",
|
)
|
||||||
type=FileType.IMAGE,
|
|
||||||
filename="test2.jpg",
|
|
||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
|
||||||
related_id="2",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
|
|
||||||
|
|
||||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
assert prompt_messages == [UserPromptMessage(content=fake_query)]
|
||||||
assert result == files
|
|
||||||
|
|
||||||
def test_fetch_files_with_none_segment(self, llm_node):
|
|
||||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
|
|
||||||
|
|
||||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||||
assert result == []
|
# Setup dify config
|
||||||
|
dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
|
||||||
|
dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
|
||||||
|
|
||||||
def test_fetch_files_with_array_any_segment(self, llm_node):
|
# Generate fake values for prompt template
|
||||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
|
fake_assistant_prompt = faker.sentence()
|
||||||
|
fake_query = faker.sentence()
|
||||||
|
fake_context = faker.sentence()
|
||||||
|
fake_window_size = faker.random_int(min=1, max=3)
|
||||||
|
fake_vision_detail = faker.random_element(
|
||||||
|
[ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
|
||||||
|
)
|
||||||
|
fake_remote_url = faker.url()
|
||||||
|
|
||||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
# Setup mock memory with history messages
|
||||||
assert result == []
|
mock_history = [
|
||||||
|
UserPromptMessage(content=faker.sentence()),
|
||||||
|
AssistantPromptMessage(content=faker.sentence()),
|
||||||
|
UserPromptMessage(content=faker.sentence()),
|
||||||
|
AssistantPromptMessage(content=faker.sentence()),
|
||||||
|
UserPromptMessage(content=faker.sentence()),
|
||||||
|
AssistantPromptMessage(content=faker.sentence()),
|
||||||
|
]
|
||||||
|
|
||||||
def test_fetch_files_with_non_existent_variable(self, llm_node):
|
# Setup memory configuration
|
||||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
memory_config = MemoryConfig(
|
||||||
assert result == []
|
role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
|
||||||
|
window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size),
|
||||||
|
query_prompt_template=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
memory = MockTokenBufferMemory(history_messages=mock_history)
|
||||||
|
|
||||||
|
# Test scenarios covering different file input combinations
|
||||||
|
test_scenarios = [
|
||||||
|
LLMNodeTestScenario(
|
||||||
|
description="No files",
|
||||||
|
user_query=fake_query,
|
||||||
|
user_files=[],
|
||||||
|
features=[],
|
||||||
|
vision_enabled=False,
|
||||||
|
vision_detail=None,
|
||||||
|
window_size=fake_window_size,
|
||||||
|
prompt_template=[
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text=fake_context,
|
||||||
|
role=PromptMessageRole.SYSTEM,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text="{#context#}",
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text=fake_assistant_prompt,
|
||||||
|
role=PromptMessageRole.ASSISTANT,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
expected_messages=[
|
||||||
|
SystemPromptMessage(content=fake_context),
|
||||||
|
UserPromptMessage(content=fake_context),
|
||||||
|
AssistantPromptMessage(content=fake_assistant_prompt),
|
||||||
|
]
|
||||||
|
+ mock_history[fake_window_size * -2 :]
|
||||||
|
+ [
|
||||||
|
UserPromptMessage(content=fake_query),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
LLMNodeTestScenario(
|
||||||
|
description="User files",
|
||||||
|
user_query=fake_query,
|
||||||
|
user_files=[
|
||||||
|
File(
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
filename="test1.jpg",
|
||||||
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
|
remote_url=fake_remote_url,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
vision_enabled=True,
|
||||||
|
vision_detail=fake_vision_detail,
|
||||||
|
features=[ModelFeature.VISION],
|
||||||
|
window_size=fake_window_size,
|
||||||
|
prompt_template=[
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text=fake_context,
|
||||||
|
role=PromptMessageRole.SYSTEM,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text="{#context#}",
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text=fake_assistant_prompt,
|
||||||
|
role=PromptMessageRole.ASSISTANT,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
expected_messages=[
|
||||||
|
SystemPromptMessage(content=fake_context),
|
||||||
|
UserPromptMessage(content=fake_context),
|
||||||
|
AssistantPromptMessage(content=fake_assistant_prompt),
|
||||||
|
]
|
||||||
|
+ mock_history[fake_window_size * -2 :]
|
||||||
|
+ [
|
||||||
|
UserPromptMessage(
|
||||||
|
content=[
|
||||||
|
TextPromptMessageContent(data=fake_query),
|
||||||
|
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
LLMNodeTestScenario(
|
||||||
|
description="Prompt template with variable selector of File",
|
||||||
|
user_query=fake_query,
|
||||||
|
user_files=[],
|
||||||
|
vision_enabled=False,
|
||||||
|
vision_detail=fake_vision_detail,
|
||||||
|
features=[ModelFeature.VISION],
|
||||||
|
window_size=fake_window_size,
|
||||||
|
prompt_template=[
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text="{{#input.image#}}",
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
expected_messages=[
|
||||||
|
UserPromptMessage(
|
||||||
|
content=[
|
||||||
|
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
+ mock_history[fake_window_size * -2 :]
|
||||||
|
+ [UserPromptMessage(content=fake_query)],
|
||||||
|
file_variables={
|
||||||
|
"input.image": File(
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
filename="test1.jpg",
|
||||||
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
|
remote_url=fake_remote_url,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
LLMNodeTestScenario(
|
||||||
|
description="Prompt template with variable selector of File without vision feature",
|
||||||
|
user_query=fake_query,
|
||||||
|
user_files=[],
|
||||||
|
vision_enabled=True,
|
||||||
|
vision_detail=fake_vision_detail,
|
||||||
|
features=[],
|
||||||
|
window_size=fake_window_size,
|
||||||
|
prompt_template=[
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text="{{#input.image#}}",
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
|
||||||
|
file_variables={
|
||||||
|
"input.image": File(
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
filename="test1.jpg",
|
||||||
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
|
remote_url=fake_remote_url,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
LLMNodeTestScenario(
|
||||||
|
description="Prompt template with variable selector of File with video file and vision feature",
|
||||||
|
user_query=fake_query,
|
||||||
|
user_files=[],
|
||||||
|
vision_enabled=True,
|
||||||
|
vision_detail=fake_vision_detail,
|
||||||
|
features=[ModelFeature.VISION],
|
||||||
|
window_size=fake_window_size,
|
||||||
|
prompt_template=[
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text="{{#input.image#}}",
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
|
||||||
|
file_variables={
|
||||||
|
"input.image": File(
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.VIDEO,
|
||||||
|
filename="test1.mp4",
|
||||||
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
|
remote_url=fake_remote_url,
|
||||||
|
extension="mp4",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
for scenario in test_scenarios:
|
||||||
|
model_config.model_schema.features = scenario.features
|
||||||
|
|
||||||
|
for k, v in scenario.file_variables.items():
|
||||||
|
selector = k.split(".")
|
||||||
|
llm_node.graph_runtime_state.variable_pool.add(selector, v)
|
||||||
|
|
||||||
|
# Call the method under test
|
||||||
|
prompt_messages, _ = llm_node._fetch_prompt_messages(
|
||||||
|
user_query=scenario.user_query,
|
||||||
|
user_files=scenario.user_files,
|
||||||
|
context=fake_context,
|
||||||
|
memory=memory,
|
||||||
|
model_config=model_config,
|
||||||
|
prompt_template=scenario.prompt_template,
|
||||||
|
memory_config=memory_config,
|
||||||
|
vision_enabled=scenario.vision_enabled,
|
||||||
|
vision_detail=scenario.vision_detail,
|
||||||
|
variable_pool=llm_node.graph_runtime_state.variable_pool,
|
||||||
|
jinja2_variables=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}"
|
||||||
|
assert (
|
||||||
|
prompt_messages == scenario.expected_messages
|
||||||
|
), f"Message content mismatch in scenario: {scenario.description}"
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.file import File
|
||||||
|
from core.model_runtime.entities.message_entities import PromptMessage
|
||||||
|
from core.model_runtime.entities.model_entities import ModelFeature
|
||||||
|
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
|
||||||
|
|
||||||
|
|
||||||
|
class LLMNodeTestScenario(BaseModel):
|
||||||
|
"""Test scenario for LLM node testing."""
|
||||||
|
|
||||||
|
description: str = Field(..., description="Description of the test scenario")
|
||||||
|
user_query: str = Field(..., description="User query input")
|
||||||
|
user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
|
||||||
|
vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
|
||||||
|
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
|
||||||
|
features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
|
||||||
|
window_size: int = Field(..., description="Window size for memory")
|
||||||
|
prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
|
||||||
|
file_variables: Mapping[str, File | Sequence[File]] = Field(
|
||||||
|
default_factory=dict, description="List of file variables"
|
||||||
|
)
|
||||||
|
expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing")
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
from collections import UserDict
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -14,7 +15,7 @@ from tests.unit_tests.oss.__mock.base import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AttrDict(dict):
|
class AttrDict(UserDict):
|
||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
return self.get(item)
|
return self.get(item)
|
||||||
|
|
||||||
|
|
|
@ -44,12 +44,6 @@ export const fileUpload: FileUpload = ({
|
||||||
}
|
}
|
||||||
|
|
||||||
export const getFileExtension = (fileName: string, fileMimetype: string, isRemote?: boolean) => {
|
export const getFileExtension = (fileName: string, fileMimetype: string, isRemote?: boolean) => {
|
||||||
if (fileMimetype)
|
|
||||||
return mime.getExtension(fileMimetype) || ''
|
|
||||||
|
|
||||||
if (isRemote)
|
|
||||||
return ''
|
|
||||||
|
|
||||||
if (fileName) {
|
if (fileName) {
|
||||||
const fileNamePair = fileName.split('.')
|
const fileNamePair = fileName.split('.')
|
||||||
const fileNamePairLength = fileNamePair.length
|
const fileNamePairLength = fileNamePair.length
|
||||||
|
@ -58,6 +52,12 @@ export const getFileExtension = (fileName: string, fileMimetype: string, isRemot
|
||||||
return fileNamePair[fileNamePairLength - 1]
|
return fileNamePair[fileNamePairLength - 1]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (fileMimetype)
|
||||||
|
return mime.getExtension(fileMimetype) || ''
|
||||||
|
|
||||||
|
if (isRemote)
|
||||||
|
return ''
|
||||||
|
|
||||||
return ''
|
return ''
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -144,6 +144,7 @@ const ConfigPromptItem: FC<Props> = ({
|
||||||
onEditionTypeChange={onEditionTypeChange}
|
onEditionTypeChange={onEditionTypeChange}
|
||||||
varList={varList}
|
varList={varList}
|
||||||
handleAddVariable={handleAddVariable}
|
handleAddVariable={handleAddVariable}
|
||||||
|
isSupportFileVar
|
||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,6 +67,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
|
||||||
handleStop,
|
handleStop,
|
||||||
varInputs,
|
varInputs,
|
||||||
runResult,
|
runResult,
|
||||||
|
filterJinjia2InputVar,
|
||||||
} = useConfig(id, data)
|
} = useConfig(id, data)
|
||||||
|
|
||||||
const model = inputs.model
|
const model = inputs.model
|
||||||
|
@ -194,7 +195,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
|
||||||
list={inputs.prompt_config?.jinja2_variables || []}
|
list={inputs.prompt_config?.jinja2_variables || []}
|
||||||
onChange={handleVarListChange}
|
onChange={handleVarListChange}
|
||||||
onVarNameChange={handleVarNameChange}
|
onVarNameChange={handleVarNameChange}
|
||||||
filterVar={filterVar}
|
filterVar={filterJinjia2InputVar}
|
||||||
/>
|
/>
|
||||||
</Field>
|
</Field>
|
||||||
)}
|
)}
|
||||||
|
@ -233,6 +234,7 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
|
||||||
hasSetBlockStatus={hasSetBlockStatus}
|
hasSetBlockStatus={hasSetBlockStatus}
|
||||||
nodesOutputVars={availableVars}
|
nodesOutputVars={availableVars}
|
||||||
availableNodes={availableNodesWithParent}
|
availableNodes={availableNodesWithParent}
|
||||||
|
isSupportFileVar
|
||||||
/>
|
/>
|
||||||
|
|
||||||
{inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && (
|
{inputs.memory.query_prompt_template && !inputs.memory.query_prompt_template.includes('{{#sys.query#}}') && (
|
||||||
|
|
|
@ -278,11 +278,15 @@ const useConfig = (id: string, payload: LLMNodeType) => {
|
||||||
}, [inputs, setInputs])
|
}, [inputs, setInputs])
|
||||||
|
|
||||||
const filterInputVar = useCallback((varPayload: Var) => {
|
const filterInputVar = useCallback((varPayload: Var) => {
|
||||||
|
return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type)
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
const filterJinjia2InputVar = useCallback((varPayload: Var) => {
|
||||||
return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
|
return [VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
|
||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
const filterMemoryPromptVar = useCallback((varPayload: Var) => {
|
const filterMemoryPromptVar = useCallback((varPayload: Var) => {
|
||||||
return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber].includes(varPayload.type)
|
return [VarType.arrayObject, VarType.array, VarType.number, VarType.string, VarType.secret, VarType.arrayString, VarType.arrayNumber, VarType.arrayFile].includes(varPayload.type)
|
||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
const {
|
const {
|
||||||
|
@ -406,6 +410,7 @@ const useConfig = (id: string, payload: LLMNodeType) => {
|
||||||
handleRun,
|
handleRun,
|
||||||
handleStop,
|
handleStop,
|
||||||
runResult,
|
runResult,
|
||||||
|
filterJinjia2InputVar,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -86,8 +86,8 @@ const translation = {
|
||||||
agenteLogDetail: {
|
agenteLogDetail: {
|
||||||
agentMode: 'Modo Agente',
|
agentMode: 'Modo Agente',
|
||||||
toolUsed: 'Ferramenta usada',
|
toolUsed: 'Ferramenta usada',
|
||||||
iterações: 'Iterações',
|
iterations: 'Iterações',
|
||||||
iteração: 'Iteração',
|
iteration: 'Iteração',
|
||||||
finalProcessing: 'Processamento Final',
|
finalProcessing: 'Processamento Final',
|
||||||
},
|
},
|
||||||
agentLogDetail: {
|
agentLogDetail: {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user