Merge branch 'fix/chore-fix' into dev/plugin-deploy
Some checks failed
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled

This commit is contained in:
Yeuoly 2024-11-06 18:30:02 +08:00
commit 56f2464a4f
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
163 changed files with 4642 additions and 984 deletions

View File

@ -81,7 +81,7 @@ Dify requires the following dependencies to build, make sure they're installed o
Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install.
Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/self-host-faq) for a list of common issues and steps to troubleshoot.
Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/install-faq) for a list of common issues and steps to troubleshoot.
### 5. Visit dify in your browser

View File

@ -79,7 +79,7 @@ Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đ
Dify bao gồm một backend và một frontend. Đi đến thư mục backend bằng lệnh `cd api/`, sau đó làm theo hướng dẫn trong [README của Backend](api/README.md) để cài đặt. Trong một terminal khác, đi đến thư mục frontend bằng lệnh `cd web/`, sau đó làm theo hướng dẫn trong [README của Frontend](web/README.md) để cài đặt.
Kiểm tra [FAQ về cài đặt](https://docs.dify.ai/learn-more/faq/self-host-faq) để xem danh sách các vấn đề thường gặp và các bước khắc phục.
Kiểm tra [FAQ về cài đặt](https://docs.dify.ai/learn-more/faq/install-faq) để xem danh sách các vấn đề thường gặp và các bước khắc phục.
### 5. Truy cập Dify trong trình duyệt của bạn

View File

@ -120,7 +120,8 @@ SUPABASE_URL=your-server-url
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm
VECTOR_STORE=weaviate
# Weaviate configuration
@ -263,6 +264,11 @@ VIKINGDB_SCHEMA=http
VIKINGDB_CONNECTION_TIMEOUT=30
VIKINGDB_SOCKET_TIMEOUT=30
# Lindorm configuration
LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
LINDORM_USERNAME=admin
LINDORM_PASSWORD=admin
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
OCEANBASE_VECTOR_PORT=2881
@ -271,6 +277,7 @@ OCEANBASE_VECTOR_PASSWORD=
OCEANBASE_VECTOR_DATABASE=test
OCEANBASE_MEMORY_LIMIT=6G
# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
@ -320,6 +327,9 @@ SSRF_DEFAULT_MAX_RETRIES=3
BATCH_UPLOAD_LIMIT=10
KEYWORD_DATA_SOURCE_TYPE=database
# Workflow file upload limit
WORKFLOW_FILE_UPLOAD_LIMIT=10
# CODE EXECUTION CONFIGURATION
CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194
CODE_EXECUTION_API_KEY=dify-sandbox

View File

@ -55,12 +55,7 @@ RUN apt-get update \
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
&& apt-get update \
# For Security
&& apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 \
&& if [ "$(dpkg --print-architecture)" = "amd64" ]; then \
apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1+b1; \
else \
apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1; \
fi \
&& apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \
# install a chinese font to support the use of tools like matplotlib
&& apt-get install -y fonts-noto-cjk \
&& apt-get autoremove -y \

View File

@ -269,6 +269,11 @@ class FileUploadConfig(BaseSettings):
default=20,
)
WORKFLOW_FILE_UPLOAD_LIMIT: PositiveInt = Field(
description="Maximum number of files allowed in a workflow upload operation",
default=10,
)
class HttpConfig(BaseSettings):
"""

View File

@ -20,6 +20,7 @@ from configs.middleware.vdb.baidu_vector_config import BaiduVectorDBConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.couchbase_config import CouchbaseConfig
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
from configs.middleware.vdb.lindorm_config import LindormConfig
from configs.middleware.vdb.milvus_config import MilvusConfig
from configs.middleware.vdb.myscale_config import MyScaleConfig
from configs.middleware.vdb.oceanbase_config import OceanBaseVectorConfig
@ -259,6 +260,7 @@ class MiddlewareConfig(
VikingDBConfig,
UpstashConfig,
TidbOnQdrantConfig,
LindormConfig,
OceanBaseVectorConfig,
BaiduVectorDBConfig,
):

View File

@ -0,0 +1,23 @@
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
class LindormConfig(BaseSettings):
"""
Lindorm configs
"""
LINDORM_URL: Optional[str] = Field(
description="Lindorm url",
default=None,
)
LINDORM_USERNAME: Optional[str] = Field(
description="Lindorm user",
default=None,
)
LINDORM_PASSWORD: Optional[str] = Field(
description="Lindorm password",
default=None,
)

View File

@ -0,0 +1,24 @@
from flask_restful import fields
parameters__system_parameters = {
"image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer,
"audio_file_size_limit": fields.Integer,
"file_size_limit": fields.Integer,
"workflow_file_upload_limit": fields.Integer,
}
parameters_fields = {
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
"suggested_questions_after_answer": fields.Raw,
"speech_to_text": fields.Raw,
"text_to_speech": fields.Raw,
"retriever_resource": fields.Raw,
"annotation_reply": fields.Raw,
"more_like_this": fields.Raw,
"user_input_form": fields.Raw,
"sensitive_word_avoidance": fields.Raw,
"file_upload": fields.Raw,
"system_parameters": fields.Nested(parameters__system_parameters),
}

View File

@ -2,11 +2,15 @@ import mimetypes
import os
import re
import urllib.parse
from collections.abc import Mapping
from typing import Any
from uuid import uuid4
import httpx
from pydantic import BaseModel
from configs import dify_config
class FileInfo(BaseModel):
filename: str
@ -56,3 +60,38 @@ def guess_file_info_from_response(response: httpx.Response):
mimetype=mimetype,
size=int(response.headers.get("Content-Length", -1)),
)
def get_parameters_from_feature_dict(*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]):
return {
"opening_statement": features_dict.get("opening_statement"),
"suggested_questions": features_dict.get("suggested_questions", []),
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
"user_input_form": user_input_form,
"sensitive_word_avoidance": features_dict.get(
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
),
"file_upload": features_dict.get(
"file_upload",
{
"image": {
"enabled": False,
"number_limits": 3,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
},
}

View File

@ -456,7 +456,7 @@ class DatasetIndexingEstimateApi(Resource):
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -620,6 +620,7 @@ class DatasetRetrievalSettingApi(Resource):
case (
VectorType.MILVUS
| VectorType.RELYT
| VectorType.PGVECTOR
| VectorType.TIDB_VECTOR
| VectorType.CHROMA
| VectorType.TENCENT
@ -640,6 +641,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.ELASTICSEARCH
| VectorType.PGVECTOR
| VectorType.TIDB_ON_QDRANT
| VectorType.LINDORM
| VectorType.COUCHBASE
):
return {
@ -682,6 +684,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.ELASTICSEARCH
| VectorType.COUCHBASE
| VectorType.PGVECTOR
| VectorType.LINDORM
):
return {
"retrieval_method": [

View File

@ -1,6 +1,7 @@
from flask_restful import fields, marshal_with
from flask_restful import marshal_with
from configs import dify_config
from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.console import api
from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.wraps import InstalledAppResource
@ -11,43 +12,14 @@ from services.app_service import AppService
class AppParameterApi(InstalledAppResource):
"""Resource for app variables."""
variable_fields = {
"key": fields.String,
"name": fields.String,
"description": fields.String,
"type": fields.String,
"default": fields.String,
"max_length": fields.Integer,
"options": fields.List(fields.String),
}
system_parameters_fields = {
"image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer,
"audio_file_size_limit": fields.Integer,
"file_size_limit": fields.Integer,
}
parameters_fields = {
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
"suggested_questions_after_answer": fields.Raw,
"speech_to_text": fields.Raw,
"text_to_speech": fields.Raw,
"retriever_resource": fields.Raw,
"annotation_reply": fields.Raw,
"more_like_this": fields.Raw,
"user_input_form": fields.Raw,
"sensitive_word_avoidance": fields.Raw,
"file_upload": fields.Raw,
"system_parameters": fields.Nested(system_parameters_fields),
}
@marshal_with(parameters_fields)
@marshal_with(fields.parameters_fields)
def get(self, installed_app: InstalledApp):
"""Retrieve app parameters."""
app_model = installed_app.app
if app_model is None:
raise AppUnavailableError()
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app_model.workflow
if workflow is None:
@ -57,43 +29,16 @@ class AppParameterApi(InstalledAppResource):
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
return {
"opening_statement": features_dict.get("opening_statement"),
"suggested_questions": features_dict.get("suggested_questions", []),
"suggested_questions_after_answer": features_dict.get(
"suggested_questions_after_answer", {"enabled": False}
),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
"user_input_form": user_input_form,
"sensitive_word_avoidance": features_dict.get(
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
),
"file_upload": features_dict.get(
"file_upload",
{
"image": {
"enabled": False,
"number_limits": 3,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
},
}
return controller_helpers.get_parameters_from_feature_dict(
features_dict=features_dict, user_input_form=user_input_form
)
class ExploreAppMetaApi(InstalledAppResource):

View File

@ -37,6 +37,7 @@ class FileApi(Resource):
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
}, 200
@setup_required

View File

@ -61,6 +61,19 @@ class ToolBuiltinProviderListToolsApi(Resource):
)
class ToolBuiltinProviderInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider))
class ToolBuiltinProviderDeleteApi(Resource):
@setup_required
@login_required
@ -604,6 +617,7 @@ api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
# builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource(

View File

@ -1,6 +1,7 @@
from flask_restful import Resource, fields, marshal_with
from flask_restful import Resource, marshal_with
from configs import dify_config
from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.service_api import api
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token
@ -11,40 +12,8 @@ from services.app_service import AppService
class AppParameterApi(Resource):
"""Resource for app variables."""
variable_fields = {
"key": fields.String,
"name": fields.String,
"description": fields.String,
"type": fields.String,
"default": fields.String,
"max_length": fields.Integer,
"options": fields.List(fields.String),
}
system_parameters_fields = {
"image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer,
"audio_file_size_limit": fields.Integer,
"file_size_limit": fields.Integer,
}
parameters_fields = {
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
"suggested_questions_after_answer": fields.Raw,
"speech_to_text": fields.Raw,
"text_to_speech": fields.Raw,
"retriever_resource": fields.Raw,
"annotation_reply": fields.Raw,
"more_like_this": fields.Raw,
"user_input_form": fields.Raw,
"sensitive_word_avoidance": fields.Raw,
"file_upload": fields.Raw,
"system_parameters": fields.Nested(system_parameters_fields),
}
@validate_app_token
@marshal_with(parameters_fields)
@marshal_with(fields.parameters_fields)
def get(self, app_model: App):
"""Retrieve app parameters."""
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
@ -56,43 +25,16 @@ class AppParameterApi(Resource):
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
return {
"opening_statement": features_dict.get("opening_statement"),
"suggested_questions": features_dict.get("suggested_questions", []),
"suggested_questions_after_answer": features_dict.get(
"suggested_questions_after_answer", {"enabled": False}
),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
"user_input_form": user_input_form,
"sensitive_word_avoidance": features_dict.get(
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
),
"file_upload": features_dict.get(
"file_upload",
{
"image": {
"enabled": False,
"number_limits": 3,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
},
}
return controller_helpers.get_parameters_from_feature_dict(
features_dict=features_dict, user_input_form=user_input_form
)
class AppMetaApi(Resource):

View File

@ -1,6 +1,7 @@
from flask_restful import fields, marshal_with
from flask_restful import marshal_with
from configs import dify_config
from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.web import api
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource
@ -11,39 +12,7 @@ from services.app_service import AppService
class AppParameterApi(WebApiResource):
"""Resource for app variables."""
variable_fields = {
"key": fields.String,
"name": fields.String,
"description": fields.String,
"type": fields.String,
"default": fields.String,
"max_length": fields.Integer,
"options": fields.List(fields.String),
}
system_parameters_fields = {
"image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer,
"audio_file_size_limit": fields.Integer,
"file_size_limit": fields.Integer,
}
parameters_fields = {
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
"suggested_questions_after_answer": fields.Raw,
"speech_to_text": fields.Raw,
"text_to_speech": fields.Raw,
"retriever_resource": fields.Raw,
"annotation_reply": fields.Raw,
"more_like_this": fields.Raw,
"user_input_form": fields.Raw,
"sensitive_word_avoidance": fields.Raw,
"file_upload": fields.Raw,
"system_parameters": fields.Nested(system_parameters_fields),
}
@marshal_with(parameters_fields)
@marshal_with(fields.parameters_fields)
def get(self, app_model: App, end_user):
"""Retrieve app parameters."""
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
@ -55,43 +24,16 @@ class AppParameterApi(WebApiResource):
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
return {
"opening_statement": features_dict.get("opening_statement"),
"suggested_questions": features_dict.get("suggested_questions", []),
"suggested_questions_after_answer": features_dict.get(
"suggested_questions_after_answer", {"enabled": False}
),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
"user_input_form": user_input_form,
"sensitive_word_avoidance": features_dict.get(
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
),
"file_upload": features_dict.get(
"file_upload",
{
"image": {
"enabled": False,
"number_limits": 3,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
},
}
return controller_helpers.get_parameters_from_feature_dict(
features_dict=features_dict, user_input_form=user_input_form
)
class AppMeta(WebApiResource):

View File

@ -1,6 +1,5 @@
import urllib.parse
from flask_login import current_user
from flask_restful import marshal_with, reqparse
from controllers.common import helpers
@ -27,7 +26,7 @@ class RemoteFileInfoApi(WebApiResource):
class RemoteFileUploadApi(WebApiResource):
@marshal_with(file_fields_with_signed_url)
def post(self):
def post(self, app_model, end_user): # Add app_model and end_user parameters
parser = reqparse.RequestParser()
parser.add_argument("url", type=str, required=True, help="URL is required")
args = parser.parse_args()
@ -51,7 +50,7 @@ class RemoteFileUploadApi(WebApiResource):
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=current_user,
user=end_user, # Use end_user instead of current_user
source_url=url,
)
except Exception as e:

View File

@ -1,8 +1,7 @@
from collections.abc import Mapping
from typing import Any
from core.file.models import FileExtraConfig
from models import FileUploadConfig
from core.file import FileExtraConfig
class FileUploadConfigManager:
@ -43,6 +42,6 @@ class FileUploadConfigManager:
if not config.get("file_upload"):
config["file_upload"] = {}
else:
FileUploadConfig.model_validate(config["file_upload"])
FileExtraConfig.model_validate(config["file_upload"])
return config, ["file_upload"]

View File

@ -20,6 +20,7 @@ from core.app.entities.queue_entities import (
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@ -314,7 +315,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(

View File

@ -23,7 +23,10 @@ class BaseAppGenerator:
user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables
user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
user_inputs = {
var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var)
for var in variables
}
user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
# Convert files in inputs to File
entity_dictionary = {item.variable: item for item in app_config.variables}
@ -75,50 +78,66 @@ class BaseAppGenerator:
return user_inputs
def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"):
user_input_value = inputs.get(var.variable)
if not user_input_value:
if var.required:
raise ValueError(f"{var.variable} is required in input form")
else:
return None
def _validate_inputs(
self,
*,
variable_entity: "VariableEntity",
value: Any,
):
if value is None:
if variable_entity.required:
raise ValueError(f"{variable_entity.variable} is required in input form")
return value
if var.type in {
if variable_entity.type in {
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
} and not isinstance(user_input_value, str):
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
} and not isinstance(value, str):
raise ValueError(
f"(type '{variable_entity.type}') {variable_entity.variable} in input form must be a string"
)
if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str):
# may raise ValueError if user_input_value is not a valid number
try:
if "." in user_input_value:
return float(user_input_value)
if "." in value:
return float(value)
else:
return int(user_input_value)
return int(value)
except ValueError:
raise ValueError(f"{var.variable} in input form must be a valid number")
if var.type == VariableEntityType.SELECT:
options = var.options
if user_input_value not in options:
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
if var.max_length and len(user_input_value) > var.max_length:
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
elif var.type == VariableEntityType.FILE:
if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
raise ValueError(f"{var.variable} in input form must be a file")
elif var.type == VariableEntityType.FILE_LIST:
if not (
isinstance(user_input_value, list)
and (
all(isinstance(item, dict) for item in user_input_value)
or all(isinstance(item, File) for item in user_input_value)
)
):
raise ValueError(f"{var.variable} in input form must be a list of files")
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
return user_input_value
match variable_entity.type:
case VariableEntityType.SELECT:
if value not in variable_entity.options:
raise ValueError(
f"{variable_entity.variable} in input form must be one of the following: "
f"{variable_entity.options}"
)
case VariableEntityType.TEXT_INPUT | VariableEntityType.PARAGRAPH:
if variable_entity.max_length and len(value) > variable_entity.max_length:
raise ValueError(
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} "
"characters"
)
case VariableEntityType.FILE:
if not isinstance(value, dict) and not isinstance(value, File):
raise ValueError(f"{variable_entity.variable} in input form must be a file")
case VariableEntityType.FILE_LIST:
# if number of files exceeds the limit, raise ValueError
if not (
isinstance(value, list)
and (all(isinstance(item, dict) for item in value) or all(isinstance(item, File) for item in value))
):
raise ValueError(f"{variable_entity.variable} in input form must be a list of files")
if variable_entity.max_length and len(value) > variable_entity.max_length:
raise ValueError(
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files"
)
return value
def _sanitize_value(self, value: Any) -> Any:
if isinstance(value, str):

View File

@ -16,6 +16,7 @@ from core.app.entities.queue_entities import (
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@ -275,7 +276,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(

View File

@ -9,6 +9,7 @@ from core.app.entities.queue_entities import (
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@ -30,6 +31,7 @@ from core.workflow.graph_engine.entities.event import (
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
@ -193,6 +195,7 @@ class WorkflowBasedAppRunner(AppRunner):
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
parallel_mode_run_id=event.parallel_mode_run_id,
)
)
elif isinstance(event, NodeRunSucceededEvent):
@ -246,9 +249,40 @@ class WorkflowBasedAppRunner(AppRunner):
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
error=event.error,
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
self._publish_event(
QueueTextChunkEvent(
@ -326,6 +360,7 @@ class WorkflowBasedAppRunner(AppRunner):
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
parallel_mode_run_id=event.parallel_mode_run_id,
)
)
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):

View File

@ -107,7 +107,8 @@ class QueueIterationNextEvent(AppQueueEvent):
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None # output for the current iteration
@ -273,6 +274,8 @@ class QueueNodeStartedEvent(AppQueueEvent):
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
class QueueNodeSucceededEvent(AppQueueEvent):
@ -306,6 +309,37 @@ class QueueNodeSucceededEvent(AppQueueEvent):
error: Optional[str] = None
class QueueNodeInIterationFailedEvent(AppQueueEvent):
"""
QueueNodeInIterationFailedEvent entity
"""
event: QueueEvent = QueueEvent.NODE_FAILED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: str
class QueueNodeFailedEvent(AppQueueEvent):
"""
QueueNodeFailedEvent entity
@ -332,6 +366,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: str

View File

@ -244,6 +244,7 @@ class NodeStartStreamResponse(StreamResponse):
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
parallel_run_id: Optional[str] = None
event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str
@ -432,6 +433,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
extras: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str

View File

@ -12,6 +12,7 @@ from core.app.entities.queue_entities import (
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@ -35,6 +36,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
@ -251,6 +253,12 @@ class WorkflowCycleManage:
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.execution_metadata = json.dumps(
{
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
}
)
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
session.add(workflow_node_execution)
@ -305,7 +313,9 @@ class WorkflowCycleManage:
return workflow_node_execution
def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
def _handle_workflow_node_execution_failed(
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
@ -318,16 +328,19 @@ class WorkflowCycleManage:
outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.finished_at: finished_at,
WorkflowNodeExecution.elapsed_time: elapsed_time,
WorkflowNodeExecution.execution_metadata: execution_metadata,
}
)
@ -342,6 +355,7 @@ class WorkflowCycleManage:
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata
self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
@ -448,6 +462,7 @@ class WorkflowCycleManage:
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
parallel_run_id=event.parallel_mode_run_id,
),
)
@ -464,7 +479,7 @@ class WorkflowCycleManage:
def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
@ -608,6 +623,7 @@ class WorkflowCycleManage:
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
),
)
@ -633,7 +649,9 @@ class WorkflowCycleManage:
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
status=WorkflowNodeExecutionStatus.SUCCEEDED,
status=WorkflowNodeExecutionStatus.SUCCEEDED
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,

View File

@ -12,7 +12,8 @@ class CommonParameterType(Enum):
SYSTEM_FILES = "system-files"
BOOLEAN = "boolean"
APP_SELECTOR = "app-selector"
MODEL_CONFIG = "model-config"
TOOL_SELECTOR = "tool-selector"
MODEL_SELECTOR = "model-selector"
class AppSelectorScope(Enum):
@ -22,7 +23,7 @@ class AppSelectorScope(Enum):
COMPLETION = "completion"
class ModelConfigScope(Enum):
class ModelSelectorScope(Enum):
LLM = "llm"
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
@ -30,3 +31,10 @@ class ModelConfigScope(Enum):
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
VISION = "vision"
class ToolSelectorScope(Enum):
ALL = "all"
CUSTOM = "custom"
BUILTIN = "builtin"
WORKFLOW = "workflow"

View File

@ -3,7 +3,12 @@ from typing import Optional, Union
from pydantic import BaseModel, ConfigDict, Field
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
from core.entities.parameter_entities import (
AppSelectorScope,
CommonParameterType,
ModelSelectorScope,
ToolSelectorScope,
)
from core.model_runtime.entities.model_entities import ModelType
from core.tools.entities.common_entities import I18nObject
@ -140,7 +145,8 @@ class BasicProviderConfig(BaseModel):
SELECT = CommonParameterType.SELECT.value
BOOLEAN = CommonParameterType.BOOLEAN.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_CONFIG = CommonParameterType.MODEL_CONFIG.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOL_SELECTOR = CommonParameterType.TOOL_SELECTOR.value
@classmethod
def value_of(cls, value: str) -> "ProviderConfig.Type":
@ -168,7 +174,7 @@ class ProviderConfig(BasicProviderConfig):
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
scope: AppSelectorScope | ModelConfigScope | None = None
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
required: bool = False
default: Optional[Union[int, str]] = None
options: Optional[list[Option]] = None

View File

@ -598,7 +598,7 @@ class IndexingRunner:
rules = DatasetProcessRule.AUTOMATIC_RULES
else:
rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
document_text = CleanProcessor.clean(text, rules)
document_text = CleanProcessor.clean(text, {"rules": rules})
return document_text

View File

@ -10,8 +10,15 @@ from core.model_runtime.entities.model_entities import (
PriceInfo,
PriceType,
)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
from core.plugin.manager.model import PluginModelManager
@ -31,7 +38,7 @@ class AIModel(BaseModel):
model_config = ConfigDict(protected_namespaces=())
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
@ -40,9 +47,17 @@ class AIModel(BaseModel):
:return: Invoke error mapping
"""
raise NotImplementedError
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [InvokeBadRequestError],
PluginDaemonInnerError: [PluginDaemonInnerError],
ValueError: [ValueError],
}
def _transform_invoke_error(self, error: Exception) -> InvokeError:
def _transform_invoke_error(self, error: Exception) -> Exception:
"""
Transform invoke error to unified error
@ -52,13 +67,15 @@ class AIModel(BaseModel):
for invoke_error, model_errors in self._invoke_error_mapping.items():
if isinstance(error, tuple(model_errors)):
if invoke_error == InvokeAuthorizationError:
return invoke_error(
return InvokeAuthorizationError(
description=(
f"[{self.provider_name}] Incorrect model credentials provided, please check and try again."
)
)
elif isinstance(invoke_error, InvokeError):
return invoke_error(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}")
else:
return error
return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}")

View File

@ -0,0 +1,61 @@
model: anthropic.claude-3-5-haiku-20241022-v1:0
label:
en_US: Claude 3.5 Haiku
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 8192
min: 1
max: 8192
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
# docs: https://docs.anthropic.com/claude/docs/system-prompts
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.001'
output: '0.005'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,61 @@
model: us.anthropic.claude-3-5-haiku-20241022-v1:0
label:
en_US: Claude 3.5 Haiku(US.Cross Region Inference)
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
parameter_rules:
- name: max_tokens
use_template: max_tokens
required: true
type: int
default: 4096
min: 1
max: 4096
help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
# docs: https://docs.anthropic.com/claude/docs/system-prompts
- name: temperature
use_template: temperature
required: false
type: float
default: 1
min: 0.0
max: 1.0
help:
zh_Hans: 生成内容的随机性。
en_US: The amount of randomness injected into the response.
- name: top_p
required: false
type: float
default: 0.999
min: 0.000
max: 1.000
help:
zh_Hans: 在核采样中Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p但不能同时更改两者。
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
- name: top_k
required: false
type: int
default: 0
min: 0
# tip docs from aws has error, max value is 500
max: 500
help:
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
- name: response_format
use_template: response_format
pricing:
input: '0.001'
output: '0.005'
unit: '0.001'
currency: USD

View File

@ -1,6 +1,7 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
import requests
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
@ -16,8 +17,18 @@ class GiteeAIProvider(ModelProvider):
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(model="Qwen2-7B-Instruct", credentials=credentials)
api_key = credentials.get("api_key")
if not api_key:
raise CredentialsValidateFailedError("Credentials validation failed: api_key not given")
# send a get request to validate the credentials
headers = {"Authorization": f"Bearer {api_key}"}
response = requests.get("https://ai.gitee.com/api/base/account/me", headers=headers, timeout=(10, 300))
if response.status_code != 200:
raise CredentialsValidateFailedError(
f"Credentials validation failed with status code {response.status_code}"
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:

Binary file not shown.

After

Width:  |  Height:  |  Size: 277 KiB

View File

@ -0,0 +1,15 @@
<svg width="68" height="24" viewBox="0 0 68 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<g id="Gemini">
<path id="Union" fill-rule="evenodd" clip-rule="evenodd" d="M50.6875 4.37014C48.3498 4.59292 46.5349 6.41319 46.3337 8.72764C46.1446 6.44662 44.2677 4.56074 41.9805 4.3737C44.2762 4.1997 46.152 2.28299 46.3373 0C46.4882 2.28911 48.405 4.20047 50.6875 4.37014ZM15.4567 9.41141L13.9579 10.9076C9.92941 6.64892 2.69298 9.97287 3.17317 15.8112C3.22394 23.108 14.5012 24.4317 15.3628 16.8809H9.52096L9.50061 14.9149H17.3595C18.8163 23.1364 8.44367 27.0292 3.19453 21.238C0.847044 18.7556 0.363651 14.7682 1.83717 11.7212C4.1129 6.62089 11.6505 5.29845 15.4567 9.41141ZM45.5915 23.5989H47.6945C47.6944 22.9155 47.6945 22.2307 47.6946 21.5452V21.5325C47.6948 19.8907 47.695 18.2453 47.6924 16.6072C47.6914 15.9407 47.6161 15.2823 47.4024 14.647C46.4188 11.2828 41.4255 11.4067 39.8332 14.214C38.5637 11.4171 34.4009 11.5236 32.8538 14.0084L32.8082 13.9976V12.4806L32.4233 12.4804H32.4224C31.8687 12.4801 31.3324 12.4798 30.7949 12.4811V23.5848L32.8977 23.5672C32.8981 22.9411 32.8979 22.3122 32.8977 21.6822V21.6812V21.6802V21.6791V21.6781V21.6771V21.676V21.676V21.6759V21.6758V21.6757V21.6757V21.6756C32.8973 20.204 32.8969 18.7261 32.904 17.2614C32.8889 15.3646 34.5674 13.5687 36.5358 14.124C37.7794 14.3298 38.1851 15.6148 38.1761 16.7257C38.1821 17.7019 38.18 18.6824 38.178 19.6633V19.6638C38.1752 20.9756 38.1724 22.2881 38.1891 23.5919L40.2846 23.5731C40.2929 22.7511 40.2881 21.9245 40.2832 21.0966C40.2753 19.7402 40.2674 18.3805 40.317 17.0328C40.4418 15.2122 42.0141 13.6186 43.9064 14.1168C45.2685 14.3231 45.6136 15.7748 45.5882 16.9545C45.5938 18.4959 45.5929 20.0492 45.5921 21.5968V21.5991V21.6014V21.6037V21.606V21.6083V21.6106C45.5917 22.2749 45.5913 22.9382 45.5915 23.5989ZM20.6167 18.4408C20.5625 21.9486 25.2121 23.6996 27.2993 20.0558L29.1566 20.9592C27.8157 23.7067 24.2337 24.7424 21.5381 23.4213C18.0052 21.7253 17.41 16.5007 20.0334 13.7517C21.4609 12.1752 23.7291 11.7901 25.7206 12.3653C28.3408 13.1257 29.4974 15.8937 29.326 18.4399C27.5547 18.4415 25.7971 18.4412 24.0364 18.4409C22.8993 18.4407 21.7609 18.4405 20.6167 18.4408ZM27.1041 16.6957C26.7048 13.1033 21.2867 13.2256 20.7494 16.6957H27.1041ZM53.543 23.5999H55.6206L55.6206 22.4361C55.6205 20.7877 55.6205 19.1443 55.6207 17.4939C55.6208 16.8853 55.7234 16.297 56.0063 15.7531C56.6115 14.3862 58.1745 13.7002 59.5927 14.1774C60.7512 14.4455 61.2852 15.6069 61.2762 16.7154C61.2774 18.3497 61.2771 19.9826 61.2769 21.6162V21.6166V21.617V21.6174V21.6179L61.2766 23.6007H63.3698C63.3913 22.0924 63.3869 20.584 63.3826 19.0755V19.0754V19.0753V19.0753V19.0752C63.3799 18.1682 63.3773 17.2612 63.3803 16.3541C63.3796 15.8622 63.3103 15.3765 63.1698 14.9052C62.3248 11.5142 57.3558 11.2385 55.5828 14.0038L55.5336 13.9905V12.4917H53.539C53.4898 12.7313 53.4934 23.4113 53.543 23.5999ZM49.6211 12.4944H51.7065V23.5994H49.6211V12.4944ZM65.1035 23.5991H67.1831C67.2367 23.2198 67.2133 12.6566 67.1634 12.4983H65.1035V23.5991ZM52.1504 8.67829C52.1709 10.4847 49.2418 10.7058 49.1816 8.65714C49.2189 6.5948 52.2437 6.81331 52.1504 8.67829ZM66.1387 10.1324C64.2712 10.1609 64.1316 7.19881 66.1559 7.17114C68.1709 7.19817 68.0215 10.2087 66.1387 10.1324Z" fill="url(#paint0_linear_14286_118464)"/>
</g>
<defs>
<linearGradient id="paint0_linear_14286_118464" x1="-2" y1="0.999998" x2="67.9999" y2="27.5002" gradientUnits="userSpaceOnUse">
<stop stop-color="#7798E0"/>
<stop offset="0.210002" stop-color="#086FFF"/>
<stop offset="0.345945" stop-color="#086FFF"/>
<stop offset="0.591777" stop-color="#479AFF"/>
<stop offset="0.895892" stop-color="#B7C4FA"/>
<stop offset="1" stop-color="#B5C5F9"/>
</linearGradient>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 3.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 57 KiB

View File

@ -0,0 +1,11 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="24" height="24" rx="6" fill="url(#paint0_linear_7301_16076)"/>
<path d="M20 12.0116C15.7043 12.42 12.3692 15.757 11.9995 20C11.652 15.8183 8.20301 12.361 4 12.0181C8.21855 11.6991 11.6656 8.1853 12.006 4C12.2833 8.19653 15.8057 11.7005 20 12.0116Z" fill="white" fill-opacity="0.88"/>
<defs>
<linearGradient id="paint0_linear_7301_16076" x1="-9" y1="29.5" x2="19.4387" y2="1.43791" gradientUnits="userSpaceOnUse">
<stop offset="0.192878" stop-color="#1C7DFF"/>
<stop offset="0.520213" stop-color="#1C69FF"/>
<stop offset="1" stop-color="#F0DCD6"/>
</linearGradient>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 689 B

View File

@ -0,0 +1,10 @@
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class GPUStackProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass

View File

@ -0,0 +1,120 @@
provider: gpustack
label:
en_US: GPUStack
icon_small:
en_US: icon_s_en.png
icon_large:
en_US: icon_l_en.png
supported_model_types:
- llm
- text-embedding
- rerank
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: endpoint_url
label:
zh_Hans: 服务器地址
en_US: Server URL
type: text-input
required: true
placeholder:
zh_Hans: 输入 GPUStack 的服务器地址,如 http://192.168.1.100
en_US: Enter the GPUStack server URL, e.g. http://192.168.1.100
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 输入您的 API Key
en_US: Enter your API Key
- variable: mode
show_on:
- variable: __model_type
value: llm
label:
en_US: Completion mode
type: select
required: false
default: chat
placeholder:
zh_Hans: 选择补全类型
en_US: Select completion type
options:
- value: completion
label:
en_US: Completion
zh_Hans: 补全
- value: chat
label:
en_US: Chat
zh_Hans: 对话
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
type: text-input
default: "8192"
placeholder:
zh_Hans: 输入您的模型上下文长度
en_US: Enter your Model context size
- variable: max_tokens_to_sample
label:
zh_Hans: 最大 token 上限
en_US: Upper bound for max tokens
show_on:
- variable: __model_type
value: llm
default: "8192"
type: text-input
- variable: function_calling_type
show_on:
- variable: __model_type
value: llm
label:
en_US: Function calling
type: select
required: false
default: no_call
options:
- value: function_call
label:
en_US: Function Call
zh_Hans: Function Call
- value: tool_call
label:
en_US: Tool Call
zh_Hans: Tool Call
- value: no_call
label:
en_US: Not Support
zh_Hans: 不支持
- variable: vision_support
show_on:
- variable: __model_type
value: llm
label:
zh_Hans: Vision 支持
en_US: Vision Support
type: select
required: false
default: no_support
options:
- value: support
label:
en_US: Support
zh_Hans: 支持
- value: no_support
label:
en_US: Not Support
zh_Hans: 不支持

View File

@ -0,0 +1,45 @@
from collections.abc import Generator
from yarl import URL
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import (
OAIAPICompatLargeLanguageModel,
)
class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
user: str | None = None,
) -> LLMResult | Generator:
return super()._invoke(
model,
credentials,
prompt_messages,
model_parameters,
tools,
stop,
stream,
user,
)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
credentials["mode"] = "chat"

View File

@ -0,0 +1,146 @@
from json import dumps
from typing import Optional
import httpx
from requests import post
from yarl import URL
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelPropertyKey,
ModelType,
)
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
class GPUStackRerankModel(RerankModel):
"""
Model class for GPUStack rerank model.
"""
def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n documents to return
:param user: unique user id
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=[])
endpoint_url = credentials["endpoint_url"]
headers = {
"Authorization": f"Bearer {credentials.get('api_key')}",
"Content-Type": "application/json",
}
data = {"model": model, "query": query, "documents": docs, "top_n": top_n}
try:
response = post(
str(URL(endpoint_url) / "v1" / "rerank"),
headers=headers,
data=dumps(data),
timeout=10,
)
response.raise_for_status()
results = response.json()
rerank_documents = []
for result in results["results"]:
index = result["index"]
if "document" in result:
text = result["document"]["text"]
else:
text = docs[index]
rerank_document = RerankDocument(
index=index,
text=text,
score=result["relevance_score"],
)
if score_threshold is None or result["relevance_score"] >= score_threshold:
rerank_documents.append(rerank_document)
return RerankResult(model=model, docs=rerank_documents)
except httpx.HTTPStatusError as e:
raise InvokeServerUnavailableError(str(e))
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self._invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
"""
return {
InvokeConnectionError: [httpx.ConnectError],
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
InvokeRateLimitError: [],
InvokeAuthorizationError: [httpx.HTTPStatusError],
InvokeBadRequestError: [httpx.RequestError],
}
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.RERANK,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
)
return entity

View File

@ -0,0 +1,35 @@
from typing import Optional
from yarl import URL
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.text_embedding_entities import (
TextEmbeddingResult,
)
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
OAICompatEmbeddingModel,
)
class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel):
"""
Model class for GPUStack text embedding model.
"""
def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
return super()._invoke(model, credentials, texts, user, input_type)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials: dict) -> None:
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" fill="currentColor" viewBox="0 0 24 24" aria-hidden="true" class="" focusable="false" style="fill:currentColor;height:28px;width:28px"><path d="m3.005 8.858 8.783 12.544h3.904L6.908 8.858zM6.905 15.825 3 21.402h3.907l1.951-2.788zM16.585 2l-6.75 9.64 1.953 2.79L20.492 2zM17.292 7.965v13.437h3.2V3.395z"></path></svg>

After

Width:  |  Height:  |  Size: 356 B

View File

@ -0,0 +1,63 @@
model: grok-beta
label:
en_US: Grok beta
model_type: llm
features:
- multi-tool-call
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 2.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: 0
max: 2.0
precision: 1
required: false
help:
en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."
zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"
- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

View File

@ -0,0 +1,37 @@
from collections.abc import Generator
from typing import Optional, Union
from yarl import URL
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
@staticmethod
def _add_custom_parameters(credentials) -> None:
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1"
credentials["mode"] = LLMMode.CHAT.value
credentials["function_calling_type"] = "tool_call"

View File

@ -0,0 +1,25 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class XAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(model="grok-beta", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -0,0 +1,38 @@
provider: x
label:
en_US: xAI
description:
en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe.
icon_small:
en_US: x-ai-logo.svg
icon_large:
en_US: x-ai-logo.svg
help:
title:
en_US: Get your token from xAI
zh_Hans: 从 xAI 获取 token
url:
en_US: https://x.ai/api
supported_model_types:
- llm
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: endpoint_url
label:
en_US: API Base
type: text-input
required: false
default: https://api.x.ai/v1
placeholder:
zh_Hans: 在此输入您的 API Base
en_US: Enter your API Base

View File

@ -53,7 +53,7 @@ class BasePluginManager:
)
except requests.exceptions.ConnectionError as e:
logger.exception(f"Request to Plugin Daemon Service failed: {e}")
raise ValueError("Request to Plugin Daemon Service failed")
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
return response
@ -157,8 +157,17 @@ class BasePluginManager:
Make a stream request to the plugin daemon inner API and yield the response as a model.
"""
for line in self._stream_request(method, path, params, headers, data, files):
line_data = None
try:
line_data = json.loads(line)
rep = PluginDaemonBasicResponse[type](**line_data)
except Exception as e:
# TODO modify this when line_data has code and message
if line_data and "error" in line_data:
raise ValueError(line_data["error"])
else:
raise ValueError(line)
if rep.code != 0:
if rep.code == -500:
try:

View File

@ -437,6 +437,7 @@ class PluginModelManager(BasePluginManager):
voices = []
for voice in resp.voices:
voices.append({"name": voice.name, "value": voice.value})
return voices
return []

View File

@ -103,7 +103,7 @@ class RetrievalService:
if exceptions:
exception_message = ";\n".join(exceptions)
raise Exception(exception_message)
raise ValueError(exception_message)
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
data_post_processor = DataPostProcessor(

View File

@ -0,0 +1,498 @@
import copy
import json
import logging
from collections.abc import Iterable
from typing import Any, Optional
from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk
from pydantic import BaseModel, model_validator
from tenacity import retry, stop_after_attempt, wait_fixed
from configs import dify_config
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger("lindorm").setLevel(logging.WARN)
class LindormVectorStoreConfig(BaseModel):
hosts: str
username: Optional[str] = None
password: Optional[str] = None
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["hosts"]:
raise ValueError("config URL is required")
if not values["username"]:
raise ValueError("config USERNAME is required")
if not values["password"]:
raise ValueError("config PASSWORD is required")
return values
def to_opensearch_params(self) -> dict[str, Any]:
params = {
"hosts": self.hosts,
}
if self.username and self.password:
params["http_auth"] = (self.username, self.password)
return params
class LindormVectorStore(BaseVector):
def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
super().__init__(collection_name.lower())
self._client_config = config
self._client = OpenSearch(**config.to_opensearch_params())
self.kwargs = kwargs
def get_type(self) -> str:
return VectorType.LINDORM
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
self.create_collection(len(embeddings[0]), **kwargs)
self.add_texts(texts, embeddings)
def refresh(self):
self._client.indices.refresh(index=self._collection_name)
def __filter_existed_ids(
self,
texts: list[str],
metadatas: list[dict],
ids: list[str],
bulk_size: int = 1024,
) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]:
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
def __fetch_existing_ids(batch_ids: list[str]) -> set[str]:
try:
existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False)
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
except Exception as e:
logger.error(f"Error fetching batch {batch_ids}: {e}")
return set()
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]:
try:
existing_docs = self._client.mget(
body={
"docs": [
{"_index": self._collection_name, "_id": id, "routing": routing}
for id, routing in zip(batch_ids, route_ids)
]
},
_source=False,
)
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
except Exception as e:
logger.error(f"Error fetching batch {batch_ids}: {e}")
return set()
if ids is None:
return texts, metadatas, ids
if len(texts) != len(ids):
raise RuntimeError(f"texts {len(texts)} != {ids}")
filtered_texts = []
filtered_metadatas = []
filtered_ids = []
def batch(iterable, n):
length = len(iterable)
for idx in range(0, length, n):
yield iterable[idx : min(idx + n, length)]
for ids_batch, texts_batch, metadatas_batch in zip(
batch(ids, bulk_size),
batch(texts, bulk_size),
batch(metadatas, bulk_size) if metadatas is not None else batch([None] * len(ids), bulk_size),
):
existing_ids_set = __fetch_existing_ids(ids_batch)
for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch):
if doc_id not in existing_ids_set:
filtered_texts.append(text)
filtered_ids.append(doc_id)
if metadatas is not None:
filtered_metadatas.append(metadata)
return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
actions = []
uuids = self._get_uuids(documents)
for i in range(len(documents)):
action = {
"_op_type": "index",
"_index": self._collection_name.lower(),
"_id": uuids[i],
"_source": {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
},
}
actions.append(action)
bulk(self._client, actions)
self.refresh()
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}}
response = self._client.search(index=self._collection_name, body=query)
if response["hits"]["hits"]:
return [hit["_id"] for hit in response["hits"]["hits"]]
else:
return None
def delete_by_metadata_field(self, key: str, value: str):
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
results = self._client.search(index=self._collection_name, body=query_str)
ids = [hit["_id"] for hit in results["hits"]["hits"]]
if ids:
self.delete_by_ids(ids)
def delete_by_ids(self, ids: list[str]) -> None:
for id in ids:
if self._client.exists(index=self._collection_name, id=id):
self._client.delete(index=self._collection_name, id=id)
else:
logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
def delete(self) -> None:
try:
if self._client.indices.exists(index=self._collection_name):
self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
logger.info("Delete index success")
else:
logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
except Exception as e:
logger.error(f"Error occurred while deleting the index: {e}")
raise e
def text_exists(self, id: str) -> bool:
try:
self._client.get(index=self._collection_name, id=id)
return True
except:
return False
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Make sure query_vector is a list
if not isinstance(query_vector, list):
raise ValueError("query_vector should be a list of floats")
# Check whether query_vector is a floating-point number list
if not all(isinstance(x, float) for x in query_vector):
raise ValueError("All elements in query_vector should be floats")
top_k = kwargs.get("top_k", 10)
query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
try:
response = self._client.search(index=self._collection_name, body=query)
except Exception as e:
logger.error(f"Error executing search: {e}")
raise
docs_and_scores = []
for hit in response["hits"]["hits"]:
docs_and_scores.append(
(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
),
hit["_score"],
)
)
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
if score > score_threshold:
doc.metadata["score"] = score
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
must = kwargs.get("must")
must_not = kwargs.get("must_not")
should = kwargs.get("should")
minimum_should_match = kwargs.get("minimum_should_match", 0)
top_k = kwargs.get("top_k", 10)
filters = kwargs.get("filter")
routing = kwargs.get("routing")
full_text_query = default_text_search_query(
query_text=query,
k=top_k,
text_field=Field.CONTENT_KEY.value,
must=must,
must_not=must_not,
should=should,
minimum_should_match=minimum_should_match,
filters=filters,
routing=routing,
)
response = self._client.search(index=self._collection_name, body=full_text_query)
docs = []
for hit in response["hits"]["hits"]:
docs.append(
Document(
page_content=hit["_source"][Field.CONTENT_KEY.value],
vector=hit["_source"][Field.VECTOR.value],
metadata=hit["_source"][Field.METADATA_KEY.value],
)
)
return docs
def create_collection(self, dimension: int, **kwargs):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
logger.info(f"Collection {self._collection_name} already exists.")
return
if self._client.indices.exists(index=self._collection_name):
logger.info("{self._collection_name.lower()} already exists.")
return
if len(self.kwargs) == 0 and len(kwargs) != 0:
self.kwargs = copy.deepcopy(kwargs)
vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
shards = kwargs.pop("shards", 2)
engine = kwargs.pop("engine", "lvector")
method_name = kwargs.pop("method_name", "hnsw")
data_type = kwargs.pop("data_type", "float")
space_type = kwargs.pop("space_type", "cosinesimil")
hnsw_m = kwargs.pop("hnsw_m", 24)
hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
ivfpq_m = kwargs.pop("ivfpq_m", dimension)
nlist = kwargs.pop("nlist", 1000)
centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if nlist >= 5000 else False)
centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24)
centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500)
centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100)
mapping = default_text_mapping(
dimension,
method_name,
shards=shards,
engine=engine,
data_type=data_type,
space_type=space_type,
vector_field=vector_field,
hnsw_m=hnsw_m,
hnsw_ef_construction=hnsw_ef_construction,
nlist=nlist,
ivfpq_m=ivfpq_m,
centroids_use_hnsw=centroids_use_hnsw,
centroids_hnsw_m=centroids_hnsw_m,
centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
centroids_hnsw_ef_search=centroids_hnsw_ef_search,
**kwargs,
)
self._client.indices.create(index=self._collection_name.lower(), body=mapping)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
# logger.info(f"create index success: {self._collection_name}")
def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
routing_field = kwargs.get("routing_field")
excludes_from_source = kwargs.get("excludes_from_source")
analyzer = kwargs.get("analyzer", "ik_max_word")
text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
engine = kwargs["engine"]
shard = kwargs["shards"]
space_type = kwargs["space_type"]
data_type = kwargs["data_type"]
vector_field = kwargs.get("vector_field", Field.VECTOR.value)
if method_name == "ivfpq":
ivfpq_m = kwargs["ivfpq_m"]
nlist = kwargs["nlist"]
centroids_use_hnsw = True if nlist > 10000 else False
centroids_hnsw_m = 24
centroids_hnsw_ef_construct = 500
centroids_hnsw_ef_search = 100
parameters = {
"m": ivfpq_m,
"nlist": nlist,
"centroids_use_hnsw": centroids_use_hnsw,
"centroids_hnsw_m": centroids_hnsw_m,
"centroids_hnsw_ef_construct": centroids_hnsw_ef_construct,
"centroids_hnsw_ef_search": centroids_hnsw_ef_search,
}
elif method_name == "hnsw":
neighbor = kwargs["hnsw_m"]
ef_construction = kwargs["hnsw_ef_construction"]
parameters = {"m": neighbor, "ef_construction": ef_construction}
elif method_name == "flat":
parameters = {}
else:
raise RuntimeError(f"unexpected method_name: {method_name}")
mapping = {
"settings": {"index": {"number_of_shards": shard, "knn": True}},
"mappings": {
"properties": {
vector_field: {
"type": "knn_vector",
"dimension": dimension,
"data_type": data_type,
"method": {
"engine": engine,
"name": method_name,
"space_type": space_type,
"parameters": parameters,
},
},
text_field: {"type": "text", "analyzer": analyzer},
}
},
}
if excludes_from_source:
mapping["mappings"]["_source"] = {"excludes": excludes_from_source} # e.g. {"excludes": ["vector_field"]}
if method_name == "ivfpq" and routing_field is not None:
mapping["settings"]["index"]["knn_routing"] = True
mapping["settings"]["index"]["knn.offline.construction"] = True
if method_name == "flat" and routing_field is not None:
mapping["settings"]["index"]["knn_routing"] = True
return mapping
def default_text_search_query(
query_text: str,
k: int = 4,
text_field: str = Field.CONTENT_KEY.value,
must: Optional[list[dict]] = None,
must_not: Optional[list[dict]] = None,
should: Optional[list[dict]] = None,
minimum_should_match: int = 0,
filters: Optional[list[dict]] = None,
routing: Optional[str] = None,
**kwargs,
) -> dict:
if routing is not None:
routing_field = kwargs.get("routing_field", "routing_field")
query_clause = {
"bool": {
"must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}]
}
}
else:
query_clause = {"match": {text_field: query_text}}
# build the simplest search_query when only query_text is specified
if not must and not must_not and not should and not filters:
search_query = {"size": k, "query": query_clause}
return search_query
# build complex search_query when either of must/must_not/should/filter is specified
if must:
if not isinstance(must, list):
raise RuntimeError(f"unexpected [must] clause with {type(filters)}")
if query_clause not in must:
must.append(query_clause)
else:
must = [query_clause]
boolean_query = {"must": must}
if must_not:
if not isinstance(must_not, list):
raise RuntimeError(f"unexpected [must_not] clause with {type(filters)}")
boolean_query["must_not"] = must_not
if should:
if not isinstance(should, list):
raise RuntimeError(f"unexpected [should] clause with {type(filters)}")
boolean_query["should"] = should
if minimum_should_match != 0:
boolean_query["minimum_should_match"] = minimum_should_match
if filters:
if not isinstance(filters, list):
raise RuntimeError(f"unexpected [filter] clause with {type(filters)}")
boolean_query["filter"] = filters
search_query = {"size": k, "query": {"bool": boolean_query}}
return search_query
def default_vector_search_query(
query_vector: list[float],
k: int = 4,
min_score: str = "0.0",
ef_search: Optional[str] = None, # only for hnsw
nprobe: Optional[str] = None, # "2000"
reorder_factor: Optional[str] = None, # "20"
client_refactor: Optional[str] = None, # "true"
vector_field: str = Field.VECTOR.value,
filters: Optional[list[dict]] = None,
filter_type: Optional[str] = None,
**kwargs,
) -> dict:
if filters is not None:
filter_type = "post_filter" if filter_type is None else filter_type
if not isinstance(filter, list):
raise RuntimeError(f"unexpected filter with {type(filters)}")
final_ext = {"lvector": {}}
if min_score != "0.0":
final_ext["lvector"]["min_score"] = min_score
if ef_search:
final_ext["lvector"]["ef_search"] = ef_search
if nprobe:
final_ext["lvector"]["nprobe"] = nprobe
if reorder_factor:
final_ext["lvector"]["reorder_factor"] = reorder_factor
if client_refactor:
final_ext["lvector"]["client_refactor"] = client_refactor
search_query = {
"size": k,
"_source": True, # force return '_source'
"query": {"knn": {vector_field: {"vector": query_vector, "k": k}}},
}
if filters is not None:
# when using filter, transform filter from List[Dict] to Dict as valid format
filters = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
search_query["query"]["knn"][vector_field]["filter"] = filters # filter should be Dict
if filter_type:
final_ext["lvector"]["filter_type"] = filter_type
if final_ext != {"lvector": {}}:
search_query["ext"] = final_ext
return search_query
class LindormVectorStoreFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.LINDORM, collection_name))
lindorm_config = LindormVectorStoreConfig(
hosts=dify_config.LINDORM_URL,
username=dify_config.LINDORM_USERNAME,
password=dify_config.LINDORM_PASSWORD,
)
return LindormVectorStore(collection_name, lindorm_config)

View File

@ -134,6 +134,10 @@ class Vector:
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory
return TidbOnQdrantVectorFactory
case VectorType.LINDORM:
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory
return LindormVectorStoreFactory
case VectorType.OCEANBASE:
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory

View File

@ -16,6 +16,7 @@ class VectorType(str, Enum):
TENCENT = "tencent"
ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch"
LINDORM = "lindorm"
COUCHBASE = "couchbase"
BAIDU = "baidu"
VIKINGDB = "vikingdb"

View File

@ -14,6 +14,7 @@ import requests
from docx import Document as DocxDocument
from configs import dify_config
from core.helper import ssrf_proxy
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
from extensions.ext_database import db
@ -86,7 +87,7 @@ class WordExtractor(BaseExtractor):
image_count += 1
if rel.is_external:
url = rel.reltype
response = requests.get(url, stream=True)
response = ssrf_proxy.get(url, stream=True)
if response.status_code == 200:
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
file_uuid = str(uuid.uuid4())

View File

@ -5,7 +5,12 @@ from typing import Any, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
from core.entities.parameter_entities import (
AppSelectorScope,
CommonParameterType,
ModelSelectorScope,
ToolSelectorScope,
)
from core.entities.provider_entities import ProviderConfig
from core.tools.entities.common_entities import I18nObject
@ -209,6 +214,9 @@ class ToolParameter(BaseModel):
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
FILE = CommonParameterType.FILE.value
FILES = CommonParameterType.FILES.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
TOOL_SELECTOR = CommonParameterType.TOOL_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
# deprecated, should not use.
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
@ -258,11 +266,26 @@ class ToolParameter(BaseModel):
return float(value)
else:
return int(value)
case ToolParameter.ToolParameterType.SYSTEM_FILES | ToolParameter.ToolParameterType.FILES:
if not isinstance(value, list):
return [value]
return value
case ToolParameter.ToolParameterType.FILE:
if isinstance(value, list):
if len(value) != 1:
raise ValueError(
"This parameter only accepts one file but got multiple files while invoking."
)
else:
return value[0]
return value
case (
ToolParameter.ToolParameterType.SYSTEM_FILES
| ToolParameter.ToolParameterType.FILE
| ToolParameter.ToolParameterType.FILES
ToolParameter.ToolParameterType.TOOL_SELECTOR
| ToolParameter.ToolParameterType.MODEL_SELECTOR
| ToolParameter.ToolParameterType.APP_SELECTOR
):
if not isinstance(value, dict):
raise ValueError("The selector must be a dictionary.")
return value
case _:
return str(value)
@ -280,7 +303,7 @@ class ToolParameter(BaseModel):
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user")
type: ToolParameterType = Field(..., description="The type of the parameter")
scope: AppSelectorScope | ModelConfigScope | None = None
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
llm_description: Optional[str] = None
required: Optional[bool] = False

View File

@ -23,6 +23,7 @@ class NodeRunMetadataKey(str, Enum):
PARALLEL_START_NODE_ID = "parallel_start_node_id"
PARENT_PARALLEL_ID = "parent_parallel_id"
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
class NodeRunResult(BaseModel):

View File

@ -59,6 +59,7 @@ class BaseNodeEvent(GraphEngineEvent):
class NodeRunStartedEvent(BaseNodeEvent):
predecessor_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None
"""predecessor node id"""
@ -81,6 +82,10 @@ class NodeRunFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeInIterationFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
###########################################
# Parallel Branch Events
###########################################
@ -129,6 +134,8 @@ class BaseIterationEvent(GraphEngineEvent):
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
class IterationRunStartedEvent(BaseIterationEvent):

View File

@ -4,6 +4,7 @@ import time
import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
from typing import Any, Optional
from flask import Flask, current_app
@ -724,6 +725,16 @@ class GraphEngine:
"""
return time.perf_counter() - start_at > max_execution_time
def create_copy(self):
"""
create a graph engine copy
:return: with a new variable pool instance of graph engine
"""
new_instance = copy(self)
new_instance.graph_runtime_state = copy(self.graph_runtime_state)
new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool)
return new_instance
class GraphRunFailedError(Exception):
def __init__(self, error: str):

View File

@ -12,6 +12,12 @@ from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
from .exc import (
CodeNodeError,
DepthLimitError,
OutputValidationError,
)
class CodeNode(BaseNode[CodeNodeData]):
_node_data_cls = CodeNodeData
@ -60,7 +66,7 @@ class CodeNode(BaseNode[CodeNodeData]):
# Transform result
result = self._transform_result(result, self.node_data.outputs)
except (CodeExecutionError, ValueError) as e:
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
@ -76,10 +82,10 @@ class CodeNode(BaseNode[CodeNodeData]):
if value is None:
return None
else:
raise ValueError(f"Output variable `{variable}` must be a string")
raise OutputValidationError(f"Output variable `{variable}` must be a string")
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
raise ValueError(
raise OutputValidationError(
f"The length of output variable `{variable}` must be"
f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
)
@ -97,10 +103,10 @@ class CodeNode(BaseNode[CodeNodeData]):
if value is None:
return None
else:
raise ValueError(f"Output variable `{variable}` must be a number")
raise OutputValidationError(f"Output variable `{variable}` must be a number")
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
raise ValueError(
raise OutputValidationError(
f"Output variable `{variable}` is out of range,"
f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
)
@ -108,7 +114,7 @@ class CodeNode(BaseNode[CodeNodeData]):
if isinstance(value, float):
# raise error if precision is too high
if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION:
raise ValueError(
raise OutputValidationError(
f"Output variable `{variable}` has too high precision,"
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
)
@ -125,7 +131,7 @@ class CodeNode(BaseNode[CodeNodeData]):
:return:
"""
if depth > dify_config.CODE_MAX_DEPTH:
raise ValueError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
transformed_result = {}
if output_schema is None:
@ -177,14 +183,14 @@ class CodeNode(BaseNode[CodeNodeData]):
depth=depth + 1,
)
else:
raise ValueError(
raise OutputValidationError(
f"Output {prefix}.{output_name} is not a valid array."
f" make sure all elements are of the same type."
)
elif output_value is None:
pass
else:
raise ValueError(f"Output {prefix}.{output_name} is not a valid type.")
raise OutputValidationError(f"Output {prefix}.{output_name} is not a valid type.")
return result
@ -192,7 +198,7 @@ class CodeNode(BaseNode[CodeNodeData]):
for output_name, output_config in output_schema.items():
dot = "." if prefix else ""
if output_name not in result:
raise ValueError(f"Output {prefix}{dot}{output_name} is missing.")
raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.")
if output_config.type == "object":
# check if output is object
@ -200,7 +206,7 @@ class CodeNode(BaseNode[CodeNodeData]):
if isinstance(result.get(output_name), type(None)):
transformed_result[output_name] = None
else:
raise ValueError(
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an object,"
f" got {type(result.get(output_name))} instead."
)
@ -228,13 +234,13 @@ class CodeNode(BaseNode[CodeNodeData]):
if isinstance(result[output_name], type(None)):
transformed_result[output_name] = None
else:
raise ValueError(
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an array,"
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
raise ValueError(
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
)
@ -249,13 +255,13 @@ class CodeNode(BaseNode[CodeNodeData]):
if isinstance(result[output_name], type(None)):
transformed_result[output_name] = None
else:
raise ValueError(
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an array,"
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
raise ValueError(
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
)
@ -270,13 +276,13 @@ class CodeNode(BaseNode[CodeNodeData]):
if isinstance(result[output_name], type(None)):
transformed_result[output_name] = None
else:
raise ValueError(
raise OutputValidationError(
f"Output {prefix}{dot}{output_name} is not an array,"
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
raise ValueError(
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
)
@ -286,7 +292,7 @@ class CodeNode(BaseNode[CodeNodeData]):
if value is None:
pass
else:
raise ValueError(
raise OutputValidationError(
f"Output {prefix}{dot}{output_name}[{i}] is not an object,"
f" got {type(value)} instead at index {i}."
)
@ -303,13 +309,13 @@ class CodeNode(BaseNode[CodeNodeData]):
for i, value in enumerate(result[output_name])
]
else:
raise ValueError(f"Output type {output_config.type} is not supported.")
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
parameters_validated[output_name] = True
# check if all output parameters are validated
if len(parameters_validated) != len(result):
raise ValueError("Not all output parameters are validated.")
raise CodeNodeError("Not all output parameters are validated.")
return transformed_result

View File

@ -0,0 +1,16 @@
class CodeNodeError(ValueError):
"""Base class for code node errors."""
pass
class OutputValidationError(CodeNodeError):
"""Raised when there is an output validation error."""
pass
class DepthLimitError(CodeNodeError):
"""Raised when the depth limit is reached."""
pass

View File

@ -1,4 +1,4 @@
class DocumentExtractorError(Exception):
class DocumentExtractorError(ValueError):
"""Base exception for errors related to the DocumentExtractorNode."""

View File

@ -6,12 +6,14 @@ import docx
import pandas as pd
import pypdfium2
import yaml
from unstructured.partition.api import partition_via_api
from unstructured.partition.email import partition_email
from unstructured.partition.epub import partition_epub
from unstructured.partition.msg import partition_msg
from unstructured.partition.ppt import partition_ppt
from unstructured.partition.pptx import partition_pptx
from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
@ -196,10 +198,8 @@ def _download_file_content(file: File) -> bytes:
response = ssrf_proxy.get(file.remote_url)
response.raise_for_status()
return response.content
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
return file_manager.download(file)
else:
raise ValueError(f"Unsupported transfer method: {file.transfer_method}")
return file_manager.download(file)
except Exception as e:
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
@ -263,6 +263,13 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
def _extract_text_from_pptx(file_content: bytes) -> str:
try:
with io.BytesIO(file_content) as file:
if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
elements = partition_via_api(
file=file,
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY,
)
else:
elements = partition_pptx(file=file)
return "\n".join([getattr(element, "text", "") for element in elements])
except Exception as e:

View File

@ -0,0 +1,18 @@
class HttpRequestNodeError(ValueError):
"""Custom error for HTTP request node."""
class AuthorizationConfigError(HttpRequestNodeError):
"""Raised when authorization config is missing or invalid."""
class FileFetchError(HttpRequestNodeError):
"""Raised when a file cannot be fetched."""
class InvalidHttpMethodError(HttpRequestNodeError):
"""Raised when an invalid HTTP method is used."""
class ResponseSizeError(HttpRequestNodeError):
"""Raised when the response size exceeds the allowed threshold."""

View File

@ -18,6 +18,12 @@ from .entities import (
HttpRequestNodeTimeout,
Response,
)
from .exc import (
AuthorizationConfigError,
FileFetchError,
InvalidHttpMethodError,
ResponseSizeError,
)
BODY_TYPE_TO_CONTENT_TYPE = {
"json": "application/json",
@ -51,7 +57,7 @@ class Executor:
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
if node_data.authorization.config is None:
raise ValueError("authorization config is required")
raise AuthorizationConfigError("authorization config is required")
node_data.authorization.config.api_key = variable_pool.convert_template(
node_data.authorization.config.api_key
).text
@ -82,8 +88,10 @@ class Executor:
self.url = self.variable_pool.convert_template(self.node_data.url).text
def _init_params(self):
params = self.variable_pool.convert_template(self.node_data.params).text
self.params = _plain_text_to_dict(params)
params = _plain_text_to_dict(self.node_data.params)
for key in params:
params[key] = self.variable_pool.convert_template(params[key]).text
self.params = params
def _init_headers(self):
headers = self.variable_pool.convert_template(self.node_data.headers).text
@ -116,7 +124,7 @@ class Executor:
file_selector = data[0].file
file_variable = self.variable_pool.get_file(file_selector)
if file_variable is None:
raise ValueError(f"cannot fetch file with selector {file_selector}")
raise FileFetchError(f"cannot fetch file with selector {file_selector}")
file = file_variable.value
self.content = file_manager.download(file)
case "x-www-form-urlencoded":
@ -155,12 +163,12 @@ class Executor:
headers = deepcopy(self.headers) or {}
if self.auth.type == "api-key":
if self.auth.config is None:
raise ValueError("self.authorization config is required")
raise AuthorizationConfigError("self.authorization config is required")
if authorization.config is None:
raise ValueError("authorization config is required")
raise AuthorizationConfigError("authorization config is required")
if self.auth.config.api_key is None:
raise ValueError("api_key is required")
raise AuthorizationConfigError("api_key is required")
if not authorization.config.header:
authorization.config.header = "Authorization"
@ -183,7 +191,7 @@ class Executor:
else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE
)
if executor_response.size > threshold_size:
raise ValueError(
raise ResponseSizeError(
f'{"File" if executor_response.is_file else "Text"} size is too large,'
f' max size is {threshold_size / 1024 / 1024:.2f} MB,'
f' but current size is {executor_response.readable_size}.'
@ -196,7 +204,7 @@ class Executor:
do http request depending on api bundle
"""
if self.method not in {"get", "head", "post", "put", "delete", "patch"}:
raise ValueError(f"Invalid http method {self.method}")
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = {
"url": self.url,

View File

@ -20,6 +20,7 @@ from .entities import (
HttpRequestNodeTimeout,
Response,
)
from .exc import HttpRequestNodeError
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
@ -77,7 +78,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
"request": http_executor.to_log(),
},
)
except Exception as e:
except HttpRequestNodeError as e:
logger.warning(f"http request node {self.node_id} failed to run: {e}")
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,

View File

@ -1,3 +1,4 @@
from enum import Enum
from typing import Any, Optional
from pydantic import Field
@ -5,6 +6,12 @@ from pydantic import Field
from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
class ErrorHandleMode(str, Enum):
TERMINATED = "terminated"
CONTINUE_ON_ERROR = "continue-on-error"
REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output"
class IterationNodeData(BaseIterationNodeData):
"""
Iteration Node Data.
@ -13,6 +20,9 @@ class IterationNodeData(BaseIterationNodeData):
parent_loop_id: Optional[str] = None # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
is_parallel: bool = False # open the parallel mode or not
parallel_nums: int = 10 # the numbers of parallel
error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error
class IterationStartNodeData(BaseNodeData):

View File

@ -1,12 +1,20 @@
import logging
import uuid
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, wait
from datetime import datetime, timezone
from typing import Any, cast
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Optional, cast
from flask import Flask, current_app
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.variables import IntegerSegment
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
NodeRunResult,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
BaseGraphEvent,
BaseNodeEvent,
@ -17,6 +25,9 @@ from core.workflow.graph_engine.entities.event import (
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
@ -24,9 +35,11 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from models.workflow import WorkflowNodeExecutionStatus
if TYPE_CHECKING:
from core.workflow.graph_engine.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
@ -38,6 +51,17 @@ class IterationNode(BaseNode[IterationNodeData]):
_node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
return {
"type": "iteration",
"config": {
"is_parallel": False,
"parallel_nums": 10,
"error_handle_mode": ErrorHandleMode.TERMINATED.value,
},
}
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
Run the node.
@ -83,7 +107,7 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
# init graph engine
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
@ -123,108 +147,64 @@ class IterationNode(BaseNode[IterationNodeData]):
index=0,
pre_iteration_output=None,
)
outputs: list[Any] = []
try:
for _ in range(len(iterator_list_value)):
# run workflow
rst = graph_engine.run()
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
if (
isinstance(event, BaseNodeEvent)
and event.node_type == NodeType.ITERATION_START
and not isinstance(event, NodeRunStreamChunkEvent)
):
if self.node_data.is_parallel:
futures: list[Future] = []
q = Queue()
thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
self._run_single_iter_parallel,
current_app._get_current_object(),
q,
iterator_list_value,
inputs,
outputs,
start_at,
graph_engine,
iteration_graph,
index,
item,
)
future.add_done_callback(thread_pool.task_done_callback)
futures.append(future)
succeeded_count = 0
while True:
try:
event = q.get(timeout=1)
if event is None:
break
if isinstance(event, IterationRunNextEvent):
succeeded_count += 1
if succeeded_count == len(futures):
q.put(None)
yield event
if isinstance(event, RunCompletedEvent):
q.put(None)
for f in futures:
if not f.done():
f.cancel()
yield event
if isinstance(event, IterationRunFailedEvent):
q.put(None)
yield event
except Empty:
continue
if isinstance(event, NodeRunSucceededEvent):
if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(index_variable, IntegerSegment):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Invalid index variable type: {type(index_variable)}",
)
)
return
metadata[NodeRunMetadataKey.ITERATION_INDEX] = index_variable.value
event.route_node_state.node_run_result.metadata = metadata
yield event
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
)
return
# wait all threads
wait(futures)
else:
event = cast(InNodeEvent, event)
yield event
# append to iteration output variable list
current_iteration_output_variable = variable_pool.get(self.node_data.output_selector)
if current_iteration_output_variable is None:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Iteration output variable {self.node_data.output_selector} not found",
for _ in range(len(iterator_list_value)):
yield from self._run_single_iter(
iterator_list_value,
variable_pool,
inputs,
outputs,
start_at,
graph_engine,
iteration_graph,
)
)
return
current_iteration_output = current_iteration_output_variable.to_object()
outputs.append(current_iteration_output)
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
variable_pool.remove([node_id])
# move to next iteration
current_index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(current_index_variable, IntegerSegment):
raise ValueError(f"iteration {self.node_id} current index not found")
next_index = current_index_variable.value + 1
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
pre_iteration_output=jsonable_encoder(current_iteration_output),
)
yield IterationRunSucceededEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
@ -330,3 +310,231 @@ class IterationNode(BaseNode[IterationNodeData]):
}
return variable_mapping
def _handle_event_metadata(
self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str
) -> NodeRunStartedEvent | BaseNodeEvent:
"""
add iteration metadata to event.
"""
if not isinstance(event, BaseNodeEvent):
return event
if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
event.parallel_mode_run_id = parallel_mode_run_id
return event
if event.route_node_state.node_run_result:
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
if self.node_data.is_parallel:
metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
else:
metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index
event.route_node_state.node_run_result.metadata = metadata
return event
def _run_single_iter(
self,
iterator_list_value: list[str],
variable_pool: VariablePool,
inputs: dict[str, list],
outputs: list,
start_at: datetime,
graph_engine: "GraphEngine",
iteration_graph: Graph,
parallel_mode_run_id: Optional[str] = None,
) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
run single iteration
"""
try:
rst = graph_engine.run()
# get current iteration index
current_index = variable_pool.get([self.node_id, "index"]).value
next_index = int(current_index) + 1
if current_index is None:
raise ValueError(f"iteration {self.node_id} current index not found")
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
if (
isinstance(event, BaseNodeEvent)
and event.node_type == NodeType.ITERATION_START
and not isinstance(event, NodeRunStreamChunkEvent)
):
continue
if isinstance(event, NodeRunSucceededEvent):
yield self._handle_event_metadata(event, current_index, parallel_mode_run_id)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
if self.node_data.is_parallel:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
else:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=event.error,
)
)
return
else:
event = cast(InNodeEvent, event)
metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id)
if isinstance(event, NodeRunFailedEvent):
if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
outputs.insert(current_index, None)
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
)
return
elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
)
return
elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": None},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
)
yield metadata_event
current_iteration_output = variable_pool.get(self.node_data.output_selector).value
outputs.insert(current_index, current_iteration_output)
# remove all nodes outputs from variable pool
for node_id in iteration_graph.node_ids:
variable_pool.remove([node_id])
# move to next iteration
variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
)
except Exception as e:
logger.exception(f"Iteration run failed:{str(e)}")
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": None},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=str(e),
)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
)
)
def _run_single_iter_parallel(
self,
flask_app: Flask,
q: Queue,
iterator_list_value: list[str],
inputs: dict[str, list],
outputs: list,
start_at: datetime,
graph_engine: "GraphEngine",
iteration_graph: Graph,
index: int,
item: Any,
) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
run single iteration in parallel mode
"""
with flask_app.app_context():
parallel_mode_run_id = uuid.uuid4().hex
graph_engine_copy = graph_engine.create_copy()
variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool
variable_pool_copy.add([self.node_id, "index"], index)
variable_pool_copy.add([self.node_id, "item"], item)
for event in self._run_single_iter(
iterator_list_value=iterator_list_value,
variable_pool=variable_pool_copy,
inputs=inputs,
outputs=outputs,
start_at=start_at,
graph_engine=graph_engine_copy,
iteration_graph=iteration_graph,
parallel_mode_run_id=parallel_mode_run_id,
):
q.put(event)

View File

@ -0,0 +1,16 @@
class ListOperatorError(ValueError):
"""Base class for all ListOperator errors."""
pass
class InvalidFilterValueError(ListOperatorError):
pass
class InvalidKeyError(ListOperatorError):
pass
class InvalidConditionError(ListOperatorError):
pass

View File

@ -1,5 +1,5 @@
from collections.abc import Callable, Sequence
from typing import Literal
from typing import Literal, Union
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
@ -9,6 +9,7 @@ from core.workflow.nodes.enums import NodeType
from models.workflow import WorkflowNodeExecutionStatus
from .entities import ListOperatorNodeData
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
class ListOperatorNode(BaseNode[ListOperatorNodeData]):
@ -26,7 +27,17 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
if variable.value and not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
if not variable.value:
inputs = {"variable": []}
process_data = {"variable": []}
outputs = {"result": [], "first_record": None, "last_record": None}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = (
f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment"
@ -36,23 +47,59 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
)
if isinstance(variable, ArrayFileSegment):
inputs = {"variable": [item.to_dict() for item in variable.value]}
process_data["variable"] = [item.to_dict() for item in variable.value]
else:
inputs = {"variable": variable.value}
process_data["variable"] = variable.value
try:
# Filter
if self.node_data.filter_by.enabled:
variable = self._apply_filter(variable)
# Order
if self.node_data.order_by.enabled:
variable = self._apply_order(variable)
# Slice
if self.node_data.limit.enabled:
variable = self._apply_slice(variable)
outputs = {
"result": variable.value,
"first_record": variable.value[0] if variable.value else None,
"last_record": variable.value[-1] if variable.value else None,
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
except ListOperatorError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
def _apply_filter(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
for condition in self.node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str):
raise ValueError(f"Invalid filter value: {condition.value}")
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
if not isinstance(condition.value, str):
raise ValueError(f"Invalid filter value: {condition.value}")
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value))
result = list(filter(filter_func, variable.value))
@ -69,9 +116,11 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
return variable
# Order
if self.node_data.order_by.enabled:
def _apply_order(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
if isinstance(variable, ArrayStringSegment):
result = _order_string(order=self.node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
@ -83,23 +132,13 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
return variable
# Slice
if self.node_data.limit.enabled:
def _apply_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
result = variable.value[: self.node_data.limit.size]
variable = variable.model_copy(update={"value": result})
outputs = {
"result": variable.value,
"first_record": variable.value[0] if variable.value else None,
"last_record": variable.value[-1] if variable.value else None,
}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
outputs=outputs,
)
return variable.model_copy(update={"value": result})
def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
@ -107,7 +146,7 @@ def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
case "size":
return lambda x: x.size
case _:
raise ValueError(f"Invalid key: {key}")
raise InvalidKeyError(f"Invalid key: {key}")
def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
@ -118,14 +157,14 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
return lambda x: x.type
case "extension":
return lambda x: x.extension or ""
case "mimetype":
case "mime_type":
return lambda x: x.mime_type or ""
case "transfer_method":
return lambda x: x.transfer_method
case "url":
return lambda x: x.remote_url or ""
case _:
raise ValueError(f"Invalid key: {key}")
raise InvalidKeyError(f"Invalid key: {key}")
def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]:
@ -151,7 +190,7 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
case "not empty":
return lambda x: x != ""
case _:
raise ValueError(f"Invalid condition: {condition}")
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]:
@ -161,7 +200,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
case "not in":
return lambda x: not _in(value)(x)
case _:
raise ValueError(f"Invalid condition: {condition}")
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]:
@ -179,7 +218,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
case "":
return _ge(value)
case _:
raise ValueError(f"Invalid condition: {condition}")
raise InvalidConditionError(f"Invalid condition: {condition}")
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
@ -193,7 +232,7 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
extract_func = _get_file_extract_number_func(key=key)
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
else:
raise ValueError(f"Invalid key: {key}")
raise InvalidKeyError(f"Invalid key: {key}")
def _contains(value: str):
@ -256,4 +295,4 @@ def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Seq
extract_func = _get_file_extract_number_func(key=order_by)
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
else:
raise ValueError(f"Invalid order key: {order_by}")
raise InvalidKeyError(f"Invalid order key: {order_by}")

View File

@ -0,0 +1,26 @@
class LLMNodeError(ValueError):
"""Base class for LLM Node errors."""
class VariableNotFoundError(LLMNodeError):
"""Raised when a required variable is not found."""
class InvalidContextStructureError(LLMNodeError):
"""Raised when the context structure is invalid."""
class InvalidVariableTypeError(LLMNodeError):
"""Raised when the variable type is invalid."""
class ModelNotExistError(LLMNodeError):
"""Raised when the specified model does not exist."""
class LLMModeRequiredError(LLMNodeError):
"""Raised when LLM mode is required but not provided."""
class NoPromptFoundError(LLMNodeError):
"""Raised when no prompt is found in the LLM configuration."""

View File

@ -56,6 +56,15 @@ from .entities import (
LLMNodeData,
ModelConfig,
)
from .exc import (
InvalidContextStructureError,
InvalidVariableTypeError,
LLMModeRequiredError,
LLMNodeError,
ModelNotExistError,
NoPromptFoundError,
VariableNotFoundError,
)
if TYPE_CHECKING:
from core.file.models import File
@ -103,7 +112,7 @@ class LLMNode(BaseNode[LLMNodeData]):
yield event
if context:
node_inputs["#context#"] = context # type: ignore
node_inputs["#context#"] = context
# fetch model config
model_instance, model_config = self._fetch_model_config(self.node_data.model)
@ -115,7 +124,7 @@ class LLMNode(BaseNode[LLMNodeData]):
if self.node_data.memory:
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
if not query:
raise ValueError("Query not found")
raise VariableNotFoundError("Query not found")
query = query.text
else:
query = None
@ -161,7 +170,7 @@ class LLMNode(BaseNode[LLMNodeData]):
usage = event.usage
finish_reason = event.finish_reason
break
except Exception as e:
except LLMNodeError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -275,7 +284,7 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None:
raise ValueError(f"Variable {variable_selector.variable} not found")
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
def parse_dict(input_dict: Mapping[str, Any]) -> str:
"""
@ -325,7 +334,7 @@ class LLMNode(BaseNode[LLMNodeData]):
for variable_selector in variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None:
raise ValueError(f"Variable {variable_selector.variable} not found")
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
if isinstance(variable, NoneSegment):
inputs[variable_selector.variable] = ""
inputs[variable_selector.variable] = variable.to_object()
@ -338,7 +347,7 @@ class LLMNode(BaseNode[LLMNodeData]):
for variable_selector in query_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None:
raise ValueError(f"Variable {variable_selector.variable} not found")
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
if isinstance(variable, NoneSegment):
continue
inputs[variable_selector.variable] = variable.to_object()
@ -355,7 +364,7 @@ class LLMNode(BaseNode[LLMNodeData]):
return variable.value
elif isinstance(variable, NoneSegment | ArrayAnySegment):
return []
raise ValueError(f"Invalid variable type: {type(variable)}")
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
def _fetch_context(self, node_data: LLMNodeData):
if not node_data.context.enabled:
@ -376,7 +385,7 @@ class LLMNode(BaseNode[LLMNodeData]):
context_str += item + "\n"
else:
if "content" not in item:
raise ValueError(f"Invalid context structure: {item}")
raise InvalidContextStructureError(f"Invalid context structure: {item}")
context_str += item["content"] + "\n"
@ -441,7 +450,7 @@ class LLMNode(BaseNode[LLMNodeData]):
)
if provider_model is None:
raise ValueError(f"Model {model_name} not exist.")
raise ModelNotExistError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
@ -460,12 +469,12 @@ class LLMNode(BaseNode[LLMNodeData]):
# get model mode
model_mode = node_data_model.mode
if not model_mode:
raise ValueError("LLM mode is required.")
raise LLMModeRequiredError("LLM mode is required.")
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
if not model_schema:
raise ValueError(f"Model {model_name} not exist.")
raise ModelNotExistError(f"Model {model_name} not exist.")
return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,
@ -564,7 +573,7 @@ class LLMNode(BaseNode[LLMNodeData]):
filtered_prompt_messages.append(prompt_message)
if not filtered_prompt_messages:
raise ValueError(
raise NoPromptFoundError(
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
@ -636,7 +645,7 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
variable_selectors = variable_template_parser.extract_variable_selectors()
else:
raise ValueError(f"Invalid prompt template type: {type(prompt_template)}")
raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
variable_mapping = {}
for variable_selector in variable_selectors:

View File

@ -0,0 +1,50 @@
class ParameterExtractorNodeError(ValueError):
"""Base error for ParameterExtractorNode."""
class InvalidModelTypeError(ParameterExtractorNodeError):
"""Raised when the model is not a Large Language Model."""
class ModelSchemaNotFoundError(ParameterExtractorNodeError):
"""Raised when the model schema is not found."""
class InvalidInvokeResultError(ParameterExtractorNodeError):
"""Raised when the invoke result is invalid."""
class InvalidTextContentTypeError(ParameterExtractorNodeError):
"""Raised when the text content type is invalid."""
class InvalidNumberOfParametersError(ParameterExtractorNodeError):
"""Raised when the number of parameters is invalid."""
class RequiredParameterMissingError(ParameterExtractorNodeError):
"""Raised when a required parameter is missing."""
class InvalidSelectValueError(ParameterExtractorNodeError):
"""Raised when a select value is invalid."""
class InvalidNumberValueError(ParameterExtractorNodeError):
"""Raised when a number value is invalid."""
class InvalidBoolValueError(ParameterExtractorNodeError):
"""Raised when a bool value is invalid."""
class InvalidStringValueError(ParameterExtractorNodeError):
"""Raised when a string value is invalid."""
class InvalidArrayValueError(ParameterExtractorNodeError):
"""Raised when an array value is invalid."""
class InvalidModelModeError(ParameterExtractorNodeError):
"""Raised when the model mode is invalid."""

View File

@ -32,6 +32,21 @@ from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus
from .entities import ParameterExtractorNodeData
from .exc import (
InvalidArrayValueError,
InvalidBoolValueError,
InvalidInvokeResultError,
InvalidModelModeError,
InvalidModelTypeError,
InvalidNumberOfParametersError,
InvalidNumberValueError,
InvalidSelectValueError,
InvalidStringValueError,
InvalidTextContentTypeError,
ModelSchemaNotFoundError,
ParameterExtractorNodeError,
RequiredParameterMissingError,
)
from .prompts import (
CHAT_EXAMPLE,
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
@ -85,7 +100,7 @@ class ParameterExtractorNode(LLMNode):
model_instance, model_config = self._fetch_model_config(node_data.model)
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise ValueError("Model is not a Large Language Model")
raise InvalidModelTypeError("Model is not a Large Language Model")
llm_model = model_instance.model_type_instance
model_schema = llm_model.get_model_schema(
@ -93,7 +108,7 @@ class ParameterExtractorNode(LLMNode):
credentials=model_config.credentials,
)
if not model_schema:
raise ValueError("Model schema not found")
raise ModelSchemaNotFoundError("Model schema not found")
# fetch memory
memory = self._fetch_memory(
@ -155,7 +170,7 @@ class ParameterExtractorNode(LLMNode):
process_data["usage"] = jsonable_encoder(usage)
process_data["tool_call"] = jsonable_encoder(tool_call)
process_data["llm_text"] = text
except Exception as e:
except ParameterExtractorNodeError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs,
@ -177,7 +192,7 @@ class ParameterExtractorNode(LLMNode):
try:
result = self._validate_result(data=node_data, result=result or {})
except Exception as e:
except ParameterExtractorNodeError as e:
error = str(e)
# transform result into standard format
@ -217,11 +232,11 @@ class ParameterExtractorNode(LLMNode):
# handle invoke result
if not isinstance(invoke_result, LLMResult):
raise ValueError(f"Invalid invoke result: {invoke_result}")
raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}")
text = invoke_result.message.content
if not isinstance(text, str):
raise ValueError(f"Invalid text content type: {type(text)}. Expected str.")
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
usage = invoke_result.usage
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
@ -344,7 +359,7 @@ class ParameterExtractorNode(LLMNode):
files=files,
)
else:
raise ValueError(f"Invalid model mode: {model_mode}")
raise InvalidModelModeError(f"Invalid model mode: {model_mode}")
def _generate_prompt_engineering_completion_prompt(
self,
@ -449,36 +464,36 @@ class ParameterExtractorNode(LLMNode):
Validate result.
"""
if len(data.parameters) != len(result):
raise ValueError("Invalid number of parameters")
raise InvalidNumberOfParametersError("Invalid number of parameters")
for parameter in data.parameters:
if parameter.required and parameter.name not in result:
raise ValueError(f"Parameter {parameter.name} is required")
raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
raise ValueError(f"Invalid `select` value for parameter {parameter.name}")
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
raise ValueError(f"Invalid `number` value for parameter {parameter.name}")
raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}")
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
raise ValueError(f"Invalid `bool` value for parameter {parameter.name}")
raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}")
if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
raise ValueError(f"Invalid `string` value for parameter {parameter.name}")
raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}")
if parameter.type.startswith("array"):
parameters = result.get(parameter.name)
if not isinstance(parameters, list):
raise ValueError(f"Invalid `array` value for parameter {parameter.name}")
raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}")
nested_type = parameter.type[6:-1]
for item in parameters:
if nested_type == "number" and not isinstance(item, int | float):
raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
if nested_type == "string" and not isinstance(item, str):
raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
if nested_type == "object" and not isinstance(item, dict):
raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
return result
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
@ -634,7 +649,7 @@ class ParameterExtractorNode(LLMNode):
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message]
else:
raise ValueError(f"Model mode {model_mode} not support.")
raise InvalidModelModeError(f"Model mode {model_mode} not support.")
def _get_prompt_engineering_prompt_template(
self,
@ -669,7 +684,7 @@ class ParameterExtractorNode(LLMNode):
.replace("}γγγ", "")
)
else:
raise ValueError(f"Model mode {model_mode} not support.")
raise InvalidModelModeError(f"Model mode {model_mode} not support.")
def _calculate_rest_token(
self,
@ -683,12 +698,12 @@ class ParameterExtractorNode(LLMNode):
model_instance, model_config = self._fetch_model_config(node_data.model)
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise ValueError("Model is not a Large Language Model")
raise InvalidModelTypeError("Model is not a Large Language Model")
llm_model = model_instance.model_type_instance
model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
if not model_schema:
raise ValueError("Model schema not found")
raise ModelSchemaNotFoundError("Model schema not found")
if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)

View File

@ -91,6 +91,8 @@ def build_segment(value: Any, /) -> Segment:
return ArrayObjectSegment(value=value)
case SegmentType.FILE:
return ArrayFileSegment(value=value)
case SegmentType.NONE:
return ArrayAnySegment(value=value)
case _:
raise ValueError(f"not supported value {value}")
raise ValueError(f"not supported value {value}")

View File

@ -8,6 +8,7 @@ upload_config_fields = {
"image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer,
"audio_file_size_limit": fields.Integer,
"workflow_file_upload_limit": fields.Integer,
}
file_fields = {

View File

@ -6,7 +6,6 @@ from .model import (
AppMode,
Conversation,
EndUser,
FileUploadConfig,
InstalledApp,
Message,
MessageAnnotation,
@ -50,6 +49,5 @@ __all__ = [
"Tenant",
"Conversation",
"MessageAnnotation",
"FileUploadConfig",
"ToolFile",
]

View File

@ -121,7 +121,7 @@ class App(Base):
return site
@property
def app_model_config(self) -> Optional["AppModelConfig"]:
def app_model_config(self):
if self.app_model_config_id:
return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first()
@ -1320,7 +1320,7 @@ class Site(Base):
privacy_policy = db.Column(db.String(255))
show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
_custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="")
customize_domain = db.Column(db.String(255))
customize_token_strategy = db.Column(db.String(255), nullable=False)
prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
@ -1331,6 +1331,16 @@ class Site(Base):
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
code = db.Column(db.String(255))
@property
def custom_disclaimer(self):
return self._custom_disclaimer
@custom_disclaimer.setter
def custom_disclaimer(self, value: str):
if len(value) > 512:
raise ValueError("Custom disclaimer cannot exceed 512 characters.")
self._custom_disclaimer = value
@staticmethod
def generate_code(n):
while True:

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
from datetime import datetime
from datetime import datetime, timezone
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, Union
@ -111,7 +111,9 @@ class Workflow(Base):
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
updated_by: Mapped[Optional[str]] = mapped_column(StringUUID)
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, default=datetime.now(tz=timezone.utc), server_onupdate=func.current_timestamp()
)
_environment_variables: Mapped[str] = mapped_column(
"environment_variables", db.Text, nullable=False, server_default="{}"
)

70
api/poetry.lock generated
View File

@ -2532,6 +2532,19 @@ files = [
{file = "filetype-1.2.0.tar.gz", hash = "sha256:66b56cd6474bf41d8c54660347d37afcc3f7d1970648de365c102ef77548aadb"},
]
[[package]]
name = "fire"
version = "0.7.0"
description = "A library for automatically generating command line interfaces."
optional = false
python-versions = "*"
files = [
{file = "fire-0.7.0.tar.gz", hash = "sha256:961550f07936eaf65ad1dc8360f2b2bf8408fad46abbfa4d2a3794f8d2a95cdf"},
]
[package.dependencies]
termcolor = "*"
[[package]]
name = "flasgger"
version = "0.9.7.1"
@ -2697,6 +2710,19 @@ files = [
{file = "flatbuffers-24.3.25.tar.gz", hash = "sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4"},
]
[[package]]
name = "fontmeta"
version = "1.6.1"
description = "An Utility to get ttf/otf font metadata"
optional = false
python-versions = "*"
files = [
{file = "fontmeta-1.6.1.tar.gz", hash = "sha256:837e5bc4da879394b41bda1428a8a480eb7c4e993799a93cfb582bab771a9c24"},
]
[package.dependencies]
fonttools = "*"
[[package]]
name = "fonttools"
version = "4.54.1"
@ -5279,6 +5305,22 @@ files = [
{file = "monotonic-1.6.tar.gz", hash = "sha256:3a55207bcfed53ddd5c5bae174524062935efed17792e9de2ad0205ce9ad63f7"},
]
[[package]]
name = "mplfonts"
version = "0.0.8"
description = "Fonts manager for matplotlib"
optional = false
python-versions = ">=3.7"
files = [
{file = "mplfonts-0.0.8-py3-none-any.whl", hash = "sha256:b2182e5b0baa216cf016dec19942740e5b48956415708ad2d465e03952112ec1"},
{file = "mplfonts-0.0.8.tar.gz", hash = "sha256:0abcb2fc0605645e1e7561c6923014d856f11676899b33b4d89757843f5e7c22"},
]
[package.dependencies]
fire = ">=0.4.0"
fontmeta = ">=1.6.1"
matplotlib = ">=3.4"
[[package]]
name = "mpmath"
version = "1.3.0"
@ -9300,6 +9342,20 @@ files = [
[package.dependencies]
tencentcloud-sdk-python-common = "3.0.1257"
[[package]]
name = "termcolor"
version = "2.5.0"
description = "ANSI color formatting for output in terminal"
optional = false
python-versions = ">=3.9"
files = [
{file = "termcolor-2.5.0-py3-none-any.whl", hash = "sha256:37b17b5fc1e604945c2642c872a3764b5d547a48009871aea3edd3afa180afb8"},
{file = "termcolor-2.5.0.tar.gz", hash = "sha256:998d8d27da6d48442e8e1f016119076b690d962507531df4890fcd2db2ef8a6f"},
]
[package.extras]
tests = ["pytest", "pytest-cov"]
[[package]]
name = "threadpoolctl"
version = "3.5.0"
@ -10046,13 +10102,13 @@ files = [
[[package]]
name = "vanna"
version = "0.7.3"
version = "0.7.5"
description = "Generate SQL queries from natural language"
optional = false
python-versions = ">=3.9"
files = [
{file = "vanna-0.7.3-py3-none-any.whl", hash = "sha256:82ba39e5d6c503d1c8cca60835ed401d20ec3a3da98d487f529901dcb30061d6"},
{file = "vanna-0.7.3.tar.gz", hash = "sha256:4590dd94d2fe180b4efc7a83c867b73144ef58794018910dc226857cfb703077"},
{file = "vanna-0.7.5-py3-none-any.whl", hash = "sha256:07458c7befa49de517a8760c2d80a13147278b484c515d49a906acc88edcb835"},
{file = "vanna-0.7.5.tar.gz", hash = "sha256:2fdffc58832898e4fc8e93c45b173424db59a22773b22ca348640161d391eacf"},
]
[package.dependencies]
@ -10073,7 +10129,7 @@ sqlparse = "*"
tabulate = "*"
[package.extras]
all = ["PyMySQL", "anthropic", "azure-common", "azure-identity", "azure-search-documents", "chromadb", "db-dtypes", "duckdb", "fastembed", "google-cloud-aiplatform", "google-cloud-bigquery", "google-generativeai", "httpx", "marqo", "mistralai (>=1.0.0)", "ollama", "openai", "opensearch-dsl", "opensearch-py", "pinecone-client", "psycopg2-binary", "pymilvus[model]", "qdrant-client", "qianfan", "snowflake-connector-python", "transformers", "weaviate-client", "zhipuai"]
all = ["PyMySQL", "anthropic", "azure-common", "azure-identity", "azure-search-documents", "boto", "boto3", "botocore", "chromadb", "db-dtypes", "duckdb", "faiss-cpu", "fastembed", "google-cloud-aiplatform", "google-cloud-bigquery", "google-generativeai", "httpx", "langchain_core", "langchain_postgres", "marqo", "mistralai (>=1.0.0)", "ollama", "openai", "opensearch-dsl", "opensearch-py", "pinecone-client", "psycopg2-binary", "pymilvus[model]", "qdrant-client", "qianfan", "snowflake-connector-python", "transformers", "weaviate-client", "xinference-client", "zhipuai"]
anthropic = ["anthropic"]
azuresearch = ["azure-common", "azure-identity", "azure-search-documents", "fastembed"]
bedrock = ["boto3", "botocore"]
@ -10081,6 +10137,8 @@ bigquery = ["google-cloud-bigquery"]
chromadb = ["chromadb"]
clickhouse = ["clickhouse_connect"]
duckdb = ["duckdb"]
faiss-cpu = ["faiss-cpu"]
faiss-gpu = ["faiss-gpu"]
gemini = ["google-generativeai"]
google = ["google-cloud-aiplatform", "google-generativeai"]
hf = ["transformers"]
@ -10091,6 +10149,7 @@ mysql = ["PyMySQL"]
ollama = ["httpx", "ollama"]
openai = ["openai"]
opensearch = ["opensearch-dsl", "opensearch-py"]
pgvector = ["langchain-postgres (>=0.0.12)"]
pinecone = ["fastembed", "pinecone-client"]
postgres = ["db-dtypes", "psycopg2-binary"]
qdrant = ["fastembed", "qdrant-client"]
@ -10099,6 +10158,7 @@ snowflake = ["snowflake-connector-python"]
test = ["tox"]
vllm = ["vllm"]
weaviate = ["weaviate-client"]
xinference-client = ["xinference-client"]
zhipuai = ["zhipuai"]
[[package]]
@ -10940,4 +11000,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "ef927b98c33d704d680e08db0e5c7d9a4e05454c66fcd6a5f656a65eb08e886b"
content-hash = "e4794898403da4ad7b51f248a6c07632a949114c1b569406d3aa6a94c62510a5"

View File

@ -206,13 +206,14 @@ cloudscraper = "1.2.71"
duckduckgo-search = "~6.3.0"
jsonpath-ng = "1.6.1"
matplotlib = "~3.8.2"
mplfonts = "~0.0.8"
newspaper3k = "0.2.8"
nltk = "3.9.1"
numexpr = "~2.9.0"
pydub = "~0.25.1"
qrcode = "~7.4.2"
twilio = "~9.0.4"
vanna = { version = "0.7.3", extras = ["postgres", "mysql", "clickhouse", "duckdb"] }
vanna = { version = "0.7.5", extras = ["postgres", "mysql", "clickhouse", "duckdb", "oracle"] }
wikipedia = "1.4.0"
yfinance = "~0.2.40"

View File

@ -14,7 +14,7 @@ from models.dataset import Embedding
@app.celery.task(queue="dataset")
def clean_embedding_cache_task():
click.echo(click.style("Start clean embedding cache.", fg="green"))
clean_days = int(dify_config.CLEAN_DAY_SETTING)
clean_days = int(dify_config.PLAN_SANDBOX_CLEAN_DAY_SETTING)
start_at = time.perf_counter()
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
while True:

View File

@ -986,9 +986,6 @@ class DocumentService:
raise NotFound("Document not found")
if document.display_status != "available":
raise ValueError("Document is not available")
# update document name
if document_data.get("name"):
document.name = document_data["name"]
# save process rule
if document_data.get("process_rule"):
process_rule = document_data["process_rule"]
@ -1065,6 +1062,10 @@ class DocumentService:
document.data_source_type = document_data["data_source"]["type"]
document.data_source_info = json.dumps(data_source_info)
document.name = file_name
# update document name
if document_data.get("name"):
document.name = document_data["name"]
# update document to be waiting
document.indexing_status = "waiting"
document.completed_at = None

View File

@ -62,6 +62,37 @@ class BuiltinToolManageService:
return result
@staticmethod
def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
"""
get builtin tool provider info
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tool_provider_configurations = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
credentials = {}
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt(credentials)
entity = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=builtin_provider,
decrypt_credentials=True,
)
entity.original_credentials = {}
return entity
@staticmethod
def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
"""
@ -255,6 +286,7 @@ class BuiltinToolManageService:
@staticmethod
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
try:
full_provider_name = provider_name
provider_id_entity = ToolProviderID(provider_name)
provider_name = provider_id_entity.provider_name
if provider_id_entity.organization != "langgenius":
@ -264,7 +296,8 @@ class BuiltinToolManageService:
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name) | (BuiltinToolProvider.provider == provider_name),
(BuiltinToolProvider.provider == provider_name)
| (BuiltinToolProvider.provider == full_provider_name),
)
.first()
)

View File

@ -96,5 +96,13 @@ VESSL_AI_MODEL_NAME=
VESSL_AI_API_KEY=
VESSL_AI_ENDPOINT_URL=
# GPUStack Credentials
GPUSTACK_SERVER_URL=
GPUSTACK_API_KEY=
# Gitee AI Credentials
GITEE_AI_API_KEY=
# xAI Credentials
XAI_API_KEY=
XAI_API_BASE=

View File

@ -0,0 +1,49 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gpustack.text_embedding.text_embedding import (
GPUStackTextEmbeddingModel,
)
def test_validate_credentials():
model = GPUStackTextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="bge-m3",
credentials={
"endpoint_url": "invalid_url",
"api_key": "invalid_api_key",
},
)
model.validate_credentials(
model="bge-m3",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
)
def test_invoke_model():
model = GPUStackTextEmbeddingModel()
result = model.invoke(
model="bge-m3",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
"context_size": 8192,
},
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens == 7

View File

@ -0,0 +1,162 @@
import os
from collections.abc import Generator
import pytest
from core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
)
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageTool,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gpustack.llm.llm import GPUStackLanguageModel
def test_validate_credentials_for_chat_model():
model = GPUStackLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="llama-3.2-1b-instruct",
credentials={
"endpoint_url": "invalid_url",
"api_key": "invalid_api_key",
"mode": "chat",
},
)
model.validate_credentials(
model="llama-3.2-1b-instruct",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
"mode": "chat",
},
)
def test_invoke_completion_model():
model = GPUStackLanguageModel()
response = model.invoke(
model="llama-3.2-1b-instruct",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
"mode": "completion",
},
prompt_messages=[UserPromptMessage(content="ping")],
model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
stop=[],
user="abc-123",
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
def test_invoke_chat_model():
model = GPUStackLanguageModel()
response = model.invoke(
model="llama-3.2-1b-instruct",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
"mode": "chat",
},
prompt_messages=[UserPromptMessage(content="ping")],
model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
stop=[],
user="abc-123",
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
def test_invoke_stream_chat_model():
model = GPUStackLanguageModel()
response = model.invoke(
model="llama-3.2-1b-instruct",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
"mode": "chat",
},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
stop=["you"],
stream=True,
user="abc-123",
)
assert isinstance(response, Generator)
for chunk in response:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_get_num_tokens():
model = GPUStackLanguageModel()
num_tokens = model.get_num_tokens(
model="????",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
"mode": "chat",
},
prompt_messages=[
SystemPromptMessage(
content="You are a helpful AI assistant.",
),
UserPromptMessage(content="Hello World!"),
],
tools=[
PromptMessageTool(
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": ["location"],
},
)
],
)
assert isinstance(num_tokens, int)
assert num_tokens == 80
num_tokens = model.get_num_tokens(
model="????",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
"mode": "chat",
},
prompt_messages=[UserPromptMessage(content="Hello World!")],
)
assert isinstance(num_tokens, int)
assert num_tokens == 10

View File

@ -0,0 +1,107 @@
import os
import pytest
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.gpustack.rerank.rerank import (
GPUStackRerankModel,
)
def test_validate_credentials_for_rerank_model():
model = GPUStackRerankModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="bge-reranker-v2-m3",
credentials={
"endpoint_url": "invalid_url",
"api_key": "invalid_api_key",
},
)
model.validate_credentials(
model="bge-reranker-v2-m3",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
)
def test_invoke_rerank_model():
model = GPUStackRerankModel()
response = model.invoke(
model="bge-reranker-v2-m3",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
query="Organic skincare products for sensitive skin",
docs=[
"Eco-friendly kitchenware for modern homes",
"Biodegradable cleaning supplies for eco-conscious consumers",
"Organic cotton baby clothes for sensitive skin",
"Natural organic skincare range for sensitive skin",
"Tech gadgets for smart homes: 2024 edition",
"Sustainable gardening tools and compost solutions",
"Sensitive skin-friendly facial cleansers and toners",
"Organic food wraps and storage solutions",
"Yoga mats made from recycled materials",
],
top_n=3,
score_threshold=-0.75,
user="abc-123",
)
assert isinstance(response, RerankResult)
assert len(response.docs) == 3
def test__invoke():
model = GPUStackRerankModel()
# Test case 1: Empty docs
result = model._invoke(
model="bge-reranker-v2-m3",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
query="Organic skincare products for sensitive skin",
docs=[],
top_n=3,
score_threshold=0.75,
user="abc-123",
)
assert isinstance(result, RerankResult)
assert len(result.docs) == 0
# Test case 2: Expected docs
result = model._invoke(
model="bge-reranker-v2-m3",
credentials={
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
"api_key": os.environ.get("GPUSTACK_API_KEY"),
},
query="Organic skincare products for sensitive skin",
docs=[
"Eco-friendly kitchenware for modern homes",
"Biodegradable cleaning supplies for eco-conscious consumers",
"Organic cotton baby clothes for sensitive skin",
"Natural organic skincare range for sensitive skin",
"Tech gadgets for smart homes: 2024 edition",
"Sustainable gardening tools and compost solutions",
"Sensitive skin-friendly facial cleansers and toners",
"Organic food wraps and storage solutions",
"Yoga mats made from recycled materials",
],
top_n=3,
score_threshold=-0.75,
user="abc-123",
)
assert isinstance(result, RerankResult)
assert len(result.docs) == 3
assert all(isinstance(doc, RerankDocument) for doc in result.docs)

View File

@ -0,0 +1,204 @@
import os
from collections.abc import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageTool,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.x.llm.llm import XAILargeLanguageModel
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
def test_predefined_models():
model = XAILargeLanguageModel()
model_schemas = model.predefined_models()
assert len(model_schemas) >= 1
assert isinstance(model_schemas[0], AIModelEntity)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock):
model = XAILargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
# model name to gpt-3.5-turbo because of mocking
model.validate_credentials(
model="gpt-3.5-turbo",
credentials={"api_key": "invalid_key", "endpoint_url": os.environ.get("XAI_API_BASE"), "mode": "chat"},
)
model.validate_credentials(
model="grok-beta",
credentials={
"api_key": os.environ.get("XAI_API_KEY"),
"endpoint_url": os.environ.get("XAI_API_BASE"),
"mode": "chat",
},
)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_chat_model(setup_openai_mock):
model = XAILargeLanguageModel()
result = model.invoke(
model="grok-beta",
credentials={
"api_key": os.environ.get("XAI_API_KEY"),
"endpoint_url": os.environ.get("XAI_API_BASE"),
"mode": "chat",
},
prompt_messages=[
SystemPromptMessage(
content="You are a helpful AI assistant.",
),
UserPromptMessage(content="Hello World!"),
],
model_parameters={
"temperature": 0.0,
"top_p": 1.0,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"max_tokens": 10,
},
stop=["How"],
stream=False,
user="foo",
)
assert isinstance(result, LLMResult)
assert len(result.message.content) > 0
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_chat_model_with_tools(setup_openai_mock):
model = XAILargeLanguageModel()
result = model.invoke(
model="grok-beta",
credentials={
"api_key": os.environ.get("XAI_API_KEY"),
"endpoint_url": os.environ.get("XAI_API_BASE"),
"mode": "chat",
},
prompt_messages=[
SystemPromptMessage(
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content="what's the weather today in London?",
),
],
model_parameters={"temperature": 0.0, "max_tokens": 100},
tools=[
PromptMessageTool(
name="get_weather",
description="Determine weather in my location",
parameters={
"type": "object",
"properties": {
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": ["location"],
},
),
PromptMessageTool(
name="get_stock_price",
description="Get the current stock price",
parameters={
"type": "object",
"properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
"required": ["symbol"],
},
),
],
stream=False,
user="foo",
)
assert isinstance(result, LLMResult)
assert isinstance(result.message, AssistantPromptMessage)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_stream_chat_model(setup_openai_mock):
model = XAILargeLanguageModel()
result = model.invoke(
model="grok-beta",
credentials={
"api_key": os.environ.get("XAI_API_KEY"),
"endpoint_url": os.environ.get("XAI_API_BASE"),
"mode": "chat",
},
prompt_messages=[
SystemPromptMessage(
content="You are a helpful AI assistant.",
),
UserPromptMessage(content="Hello World!"),
],
model_parameters={"temperature": 0.0, "max_tokens": 100},
stream=True,
user="foo",
)
assert isinstance(result, Generator)
for chunk in result:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
if chunk.delta.finish_reason is not None:
assert chunk.delta.usage is not None
assert chunk.delta.usage.completion_tokens > 0
def test_get_num_tokens():
model = XAILargeLanguageModel()
num_tokens = model.get_num_tokens(
model="grok-beta",
credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
)
assert num_tokens == 10
num_tokens = model.get_num_tokens(
model="grok-beta",
credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
prompt_messages=[
SystemPromptMessage(
content="You are a helpful AI assistant.",
),
UserPromptMessage(content="Hello World!"),
],
tools=[
PromptMessageTool(
name="get_weather",
description="Determine weather in my location",
parameters={
"type": "object",
"properties": {
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": ["location"],
},
),
],
)
assert num_tokens == 77

View File

@ -0,0 +1,35 @@
import environs
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis
env = environs.Env()
class Config:
SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-*************-proxy-search-pub.lindorm.aliyuncs.com:30070")
SEARCH_USERNAME = env.str("SEARCH_USERNAME", "ADMIN")
SEARCH_PWD = env.str("SEARCH_PWD", "PWD")
class TestLindormVectorStore(AbstractVectorTest):
def __init__(self):
super().__init__()
self.vector = LindormVectorStore(
collection_name=self.collection_name,
config=LindormVectorStoreConfig(
hosts=Config.SEARCH_ENDPOINT,
username=Config.SEARCH_USERNAME,
password=Config.SEARCH_PWD,
),
)
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
assert ids is not None
assert len(ids) == 1
assert ids[0] == self.example_doc_id
def test_lindorm_vector(setup_mock_redis):
TestLindormVectorStore().run_all_tests()

View File

@ -0,0 +1,52 @@
import pytest
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.base_app_generator import BaseAppGenerator
def test_validate_inputs_with_zero():
base_app_generator = BaseAppGenerator()
var = VariableEntity(
variable="test_var",
label="test_var",
type=VariableEntityType.NUMBER,
required=True,
)
# Test with input 0
result = base_app_generator._validate_inputs(
variable_entity=var,
value=0,
)
assert result == 0
# Test with input "0" (string)
result = base_app_generator._validate_inputs(
variable_entity=var,
value="0",
)
assert result == 0
def test_validate_input_with_none_for_required_variable():
base_app_generator = BaseAppGenerator()
for var_type in VariableEntityType:
var = VariableEntity(
variable="test_var",
label="test_var",
type=var_type,
required=True,
)
# Test with input None
with pytest.raises(ValueError) as exc_info:
base_app_generator._validate_inputs(
variable_entity=var,
value=None,
)
assert str(exc_info.value) == "test_var is required in input form"

View File

@ -13,6 +13,7 @@ from core.variables import (
StringVariable,
)
from core.variables.exc import VariableError
from core.variables.segments import ArrayAnySegment
from factories import variable_factory
@ -156,3 +157,9 @@ def test_variable_cannot_large_than_200_kb():
"value": "a" * 1024 * 201,
}
)
def test_array_none_variable():
var = variable_factory.build_segment([None, None, None, None])
assert isinstance(var, ArrayAnySegment)
assert var.value == [None, None, None, None]

View File

@ -0,0 +1,198 @@
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.http_request import (
BodyData,
HttpRequestNodeAuthorization,
HttpRequestNodeBody,
HttpRequestNodeData,
)
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
from core.workflow.nodes.http_request.executor import Executor
def test_executor_with_json_body_and_number_variable():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
variable_pool.add(["pre_node_id", "number"], 42)
# Prepare the node data
node_data = HttpRequestNodeData(
title="Test JSON Body with Number Variable",
method="post",
url="https://api.example.com/data",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="Content-Type: application/json",
params="",
body=HttpRequestNodeBody(
type="json",
data=[
BodyData(
key="",
type="text",
value='{"number": {{#pre_node_id.number#}}}',
)
],
),
)
# Initialize the Executor
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
variable_pool=variable_pool,
)
# Check the executor's data
assert executor.method == "post"
assert executor.url == "https://api.example.com/data"
assert executor.headers == {"Content-Type": "application/json"}
assert executor.params == {}
assert executor.json == {"number": 42}
assert executor.data is None
assert executor.files is None
assert executor.content is None
# Check the raw request (to_log method)
raw_request = executor.to_log()
assert "POST /data HTTP/1.1" in raw_request
assert "Host: api.example.com" in raw_request
assert "Content-Type: application/json" in raw_request
assert '{"number": 42}' in raw_request
def test_executor_with_json_body_and_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
# Prepare the node data
node_data = HttpRequestNodeData(
title="Test JSON Body with Object Variable",
method="post",
url="https://api.example.com/data",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="Content-Type: application/json",
params="",
body=HttpRequestNodeBody(
type="json",
data=[
BodyData(
key="",
type="text",
value="{{#pre_node_id.object#}}",
)
],
),
)
# Initialize the Executor
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
variable_pool=variable_pool,
)
# Check the executor's data
assert executor.method == "post"
assert executor.url == "https://api.example.com/data"
assert executor.headers == {"Content-Type": "application/json"}
assert executor.params == {}
assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"}
assert executor.data is None
assert executor.files is None
assert executor.content is None
# Check the raw request (to_log method)
raw_request = executor.to_log()
assert "POST /data HTTP/1.1" in raw_request
assert "Host: api.example.com" in raw_request
assert "Content-Type: application/json" in raw_request
assert '"name": "John Doe"' in raw_request
assert '"age": 30' in raw_request
assert '"email": "john@example.com"' in raw_request
def test_executor_with_json_body_and_nested_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
# Prepare the node data
node_data = HttpRequestNodeData(
title="Test JSON Body with Nested Object Variable",
method="post",
url="https://api.example.com/data",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="Content-Type: application/json",
params="",
body=HttpRequestNodeBody(
type="json",
data=[
BodyData(
key="",
type="text",
value='{"object": {{#pre_node_id.object#}}}',
)
],
),
)
# Initialize the Executor
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
variable_pool=variable_pool,
)
# Check the executor's data
assert executor.method == "post"
assert executor.url == "https://api.example.com/data"
assert executor.headers == {"Content-Type": "application/json"}
assert executor.params == {}
assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}}
assert executor.data is None
assert executor.files is None
assert executor.content is None
# Check the raw request (to_log method)
raw_request = executor.to_log()
assert "POST /data HTTP/1.1" in raw_request
assert "Host: api.example.com" in raw_request
assert "Content-Type: application/json" in raw_request
assert '"object": {' in raw_request
assert '"name": "John Doe"' in raw_request
assert '"age": 30' in raw_request
assert '"email": "john@example.com"' in raw_request
def test_extract_selectors_from_template_with_newline():
variable_pool = VariablePool()
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
node_data = HttpRequestNodeData(
title="Test JSON Body with Nested Object Variable",
method="post",
url="https://api.example.com/data",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="Content-Type: application/json",
params="test: {{#node_id.custom_query#}}",
body=HttpRequestNodeBody(
type="none",
data=[],
),
)
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
variable_pool=variable_pool,
)
assert executor.params == {"test": "line1\nline2"}

View File

@ -1,5 +1,3 @@
import json
import httpx
from core.app.entities.app_invoke_entities import InvokeFrom
@ -16,8 +14,7 @@ from core.workflow.nodes.http_request import (
HttpRequestNodeBody,
HttpRequestNodeData,
)
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
from core.workflow.nodes.http_request.executor import Executor, _plain_text_to_dict
from core.workflow.nodes.http_request.executor import _plain_text_to_dict
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
@ -203,167 +200,3 @@ def test_http_request_node_form_with_file(monkeypatch):
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["body"] == ""
def test_executor_with_json_body_and_number_variable():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
variable_pool.add(["pre_node_id", "number"], 42)
# Prepare the node data
node_data = HttpRequestNodeData(
title="Test JSON Body with Number Variable",
method="post",
url="https://api.example.com/data",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="Content-Type: application/json",
params="",
body=HttpRequestNodeBody(
type="json",
data=[
BodyData(
key="",
type="text",
value='{"number": {{#pre_node_id.number#}}}',
)
],
),
)
# Initialize the Executor
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
variable_pool=variable_pool,
)
# Check the executor's data
assert executor.method == "post"
assert executor.url == "https://api.example.com/data"
assert executor.headers == {"Content-Type": "application/json"}
assert executor.params == {}
assert executor.json == {"number": 42}
assert executor.data is None
assert executor.files is None
assert executor.content is None
# Check the raw request (to_log method)
raw_request = executor.to_log()
assert "POST /data HTTP/1.1" in raw_request
assert "Host: api.example.com" in raw_request
assert "Content-Type: application/json" in raw_request
assert '{"number": 42}' in raw_request
def test_executor_with_json_body_and_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
# Prepare the node data
node_data = HttpRequestNodeData(
title="Test JSON Body with Object Variable",
method="post",
url="https://api.example.com/data",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="Content-Type: application/json",
params="",
body=HttpRequestNodeBody(
type="json",
data=[
BodyData(
key="",
type="text",
value="{{#pre_node_id.object#}}",
)
],
),
)
# Initialize the Executor
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
variable_pool=variable_pool,
)
# Check the executor's data
assert executor.method == "post"
assert executor.url == "https://api.example.com/data"
assert executor.headers == {"Content-Type": "application/json"}
assert executor.params == {}
assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"}
assert executor.data is None
assert executor.files is None
assert executor.content is None
# Check the raw request (to_log method)
raw_request = executor.to_log()
assert "POST /data HTTP/1.1" in raw_request
assert "Host: api.example.com" in raw_request
assert "Content-Type: application/json" in raw_request
assert '"name": "John Doe"' in raw_request
assert '"age": 30' in raw_request
assert '"email": "john@example.com"' in raw_request
def test_executor_with_json_body_and_nested_object_variable():
# Prepare the variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
# Prepare the node data
node_data = HttpRequestNodeData(
title="Test JSON Body with Nested Object Variable",
method="post",
url="https://api.example.com/data",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="Content-Type: application/json",
params="",
body=HttpRequestNodeBody(
type="json",
data=[
BodyData(
key="",
type="text",
value='{"object": {{#pre_node_id.object#}}}',
)
],
),
)
# Initialize the Executor
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
variable_pool=variable_pool,
)
# Check the executor's data
assert executor.method == "post"
assert executor.url == "https://api.example.com/data"
assert executor.headers == {"Content-Type": "application/json"}
assert executor.params == {}
assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}}
assert executor.data is None
assert executor.files is None
assert executor.content is None
# Check the raw request (to_log method)
raw_request = executor.to_log()
assert "POST /data HTTP/1.1" in raw_request
assert "Host: api.example.com" in raw_request
assert "Content-Type: application/json" in raw_request
assert '"object": {' in raw_request
assert '"name": "John Doe"' in raw_request
assert '"age": 30' in raw_request
assert '"email": "john@example.com"' in raw_request

Some files were not shown because too many files have changed in this diff Show More