mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
Merge branch 'fix/chore-fix' into dev/plugin-deploy
Some checks failed
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Some checks failed
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
This commit is contained in:
commit
56f2464a4f
|
@ -81,7 +81,7 @@ Dify requires the following dependencies to build, make sure they're installed o
|
|||
|
||||
Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install.
|
||||
|
||||
Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/self-host-faq) for a list of common issues and steps to troubleshoot.
|
||||
Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/install-faq) for a list of common issues and steps to troubleshoot.
|
||||
|
||||
### 5. Visit dify in your browser
|
||||
|
||||
|
|
|
@ -79,7 +79,7 @@ Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đ
|
|||
|
||||
Dify bao gồm một backend và một frontend. Đi đến thư mục backend bằng lệnh `cd api/`, sau đó làm theo hướng dẫn trong [README của Backend](api/README.md) để cài đặt. Trong một terminal khác, đi đến thư mục frontend bằng lệnh `cd web/`, sau đó làm theo hướng dẫn trong [README của Frontend](web/README.md) để cài đặt.
|
||||
|
||||
Kiểm tra [FAQ về cài đặt](https://docs.dify.ai/learn-more/faq/self-host-faq) để xem danh sách các vấn đề thường gặp và các bước khắc phục.
|
||||
Kiểm tra [FAQ về cài đặt](https://docs.dify.ai/learn-more/faq/install-faq) để xem danh sách các vấn đề thường gặp và các bước khắc phục.
|
||||
|
||||
### 5. Truy cập Dify trong trình duyệt của bạn
|
||||
|
||||
|
|
|
@ -120,7 +120,8 @@ SUPABASE_URL=your-server-url
|
|||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
|
||||
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash
|
||||
|
||||
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm
|
||||
VECTOR_STORE=weaviate
|
||||
|
||||
# Weaviate configuration
|
||||
|
@ -263,6 +264,11 @@ VIKINGDB_SCHEMA=http
|
|||
VIKINGDB_CONNECTION_TIMEOUT=30
|
||||
VIKINGDB_SOCKET_TIMEOUT=30
|
||||
|
||||
# Lindorm configuration
|
||||
LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
|
||||
LINDORM_USERNAME=admin
|
||||
LINDORM_PASSWORD=admin
|
||||
|
||||
# OceanBase Vector configuration
|
||||
OCEANBASE_VECTOR_HOST=127.0.0.1
|
||||
OCEANBASE_VECTOR_PORT=2881
|
||||
|
@ -271,6 +277,7 @@ OCEANBASE_VECTOR_PASSWORD=
|
|||
OCEANBASE_VECTOR_DATABASE=test
|
||||
OCEANBASE_MEMORY_LIMIT=6G
|
||||
|
||||
|
||||
# Upload configuration
|
||||
UPLOAD_FILE_SIZE_LIMIT=15
|
||||
UPLOAD_FILE_BATCH_LIMIT=5
|
||||
|
@ -320,6 +327,9 @@ SSRF_DEFAULT_MAX_RETRIES=3
|
|||
BATCH_UPLOAD_LIMIT=10
|
||||
KEYWORD_DATA_SOURCE_TYPE=database
|
||||
|
||||
# Workflow file upload limit
|
||||
WORKFLOW_FILE_UPLOAD_LIMIT=10
|
||||
|
||||
# CODE EXECUTION CONFIGURATION
|
||||
CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194
|
||||
CODE_EXECUTION_API_KEY=dify-sandbox
|
||||
|
|
|
@ -55,12 +55,7 @@ RUN apt-get update \
|
|||
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
|
||||
&& apt-get update \
|
||||
# For Security
|
||||
&& apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 \
|
||||
&& if [ "$(dpkg --print-architecture)" = "amd64" ]; then \
|
||||
apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1+b1; \
|
||||
else \
|
||||
apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1; \
|
||||
fi \
|
||||
&& apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \
|
||||
# install a chinese font to support the use of tools like matplotlib
|
||||
&& apt-get install -y fonts-noto-cjk \
|
||||
&& apt-get autoremove -y \
|
||||
|
|
|
@ -269,6 +269,11 @@ class FileUploadConfig(BaseSettings):
|
|||
default=20,
|
||||
)
|
||||
|
||||
WORKFLOW_FILE_UPLOAD_LIMIT: PositiveInt = Field(
|
||||
description="Maximum number of files allowed in a workflow upload operation",
|
||||
default=10,
|
||||
)
|
||||
|
||||
|
||||
class HttpConfig(BaseSettings):
|
||||
"""
|
||||
|
|
|
@ -20,6 +20,7 @@ from configs.middleware.vdb.baidu_vector_config import BaiduVectorDBConfig
|
|||
from configs.middleware.vdb.chroma_config import ChromaConfig
|
||||
from configs.middleware.vdb.couchbase_config import CouchbaseConfig
|
||||
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
|
||||
from configs.middleware.vdb.lindorm_config import LindormConfig
|
||||
from configs.middleware.vdb.milvus_config import MilvusConfig
|
||||
from configs.middleware.vdb.myscale_config import MyScaleConfig
|
||||
from configs.middleware.vdb.oceanbase_config import OceanBaseVectorConfig
|
||||
|
@ -259,6 +260,7 @@ class MiddlewareConfig(
|
|||
VikingDBConfig,
|
||||
UpstashConfig,
|
||||
TidbOnQdrantConfig,
|
||||
LindormConfig,
|
||||
OceanBaseVectorConfig,
|
||||
BaiduVectorDBConfig,
|
||||
):
|
||||
|
|
23
api/configs/middleware/vdb/lindorm_config.py
Normal file
23
api/configs/middleware/vdb/lindorm_config.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class LindormConfig(BaseSettings):
|
||||
"""
|
||||
Lindorm configs
|
||||
"""
|
||||
|
||||
LINDORM_URL: Optional[str] = Field(
|
||||
description="Lindorm url",
|
||||
default=None,
|
||||
)
|
||||
LINDORM_USERNAME: Optional[str] = Field(
|
||||
description="Lindorm user",
|
||||
default=None,
|
||||
)
|
||||
LINDORM_PASSWORD: Optional[str] = Field(
|
||||
description="Lindorm password",
|
||||
default=None,
|
||||
)
|
24
api/controllers/common/fields.py
Normal file
24
api/controllers/common/fields.py
Normal file
|
@ -0,0 +1,24 @@
|
|||
from flask_restful import fields
|
||||
|
||||
parameters__system_parameters = {
|
||||
"image_file_size_limit": fields.Integer,
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"file_size_limit": fields.Integer,
|
||||
"workflow_file_upload_limit": fields.Integer,
|
||||
}
|
||||
|
||||
parameters_fields = {
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"suggested_questions_after_answer": fields.Raw,
|
||||
"speech_to_text": fields.Raw,
|
||||
"text_to_speech": fields.Raw,
|
||||
"retriever_resource": fields.Raw,
|
||||
"annotation_reply": fields.Raw,
|
||||
"more_like_this": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"sensitive_word_avoidance": fields.Raw,
|
||||
"file_upload": fields.Raw,
|
||||
"system_parameters": fields.Nested(parameters__system_parameters),
|
||||
}
|
|
@ -2,11 +2,15 @@ import mimetypes
|
|||
import os
|
||||
import re
|
||||
import urllib.parse
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
class FileInfo(BaseModel):
|
||||
filename: str
|
||||
|
@ -56,3 +60,38 @@ def guess_file_info_from_response(response: httpx.Response):
|
|||
mimetype=mimetype,
|
||||
size=int(response.headers.get("Content-Length", -1)),
|
||||
)
|
||||
|
||||
|
||||
def get_parameters_from_feature_dict(*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]):
|
||||
return {
|
||||
"opening_statement": features_dict.get("opening_statement"),
|
||||
"suggested_questions": features_dict.get("suggested_questions", []),
|
||||
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
|
||||
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
|
||||
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
|
||||
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
|
||||
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
|
||||
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
|
||||
"user_input_form": user_input_form,
|
||||
"sensitive_word_avoidance": features_dict.get(
|
||||
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
|
||||
),
|
||||
"file_upload": features_dict.get(
|
||||
"file_upload",
|
||||
{
|
||||
"image": {
|
||||
"enabled": False,
|
||||
"number_limits": 3,
|
||||
"detail": "high",
|
||||
"transfer_methods": ["remote_url", "local_file"],
|
||||
}
|
||||
},
|
||||
),
|
||||
"system_parameters": {
|
||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -456,7 +456,7 @@ class DatasetIndexingEstimateApi(Resource):
|
|||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
|
||||
"No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
@ -620,6 +620,7 @@ class DatasetRetrievalSettingApi(Resource):
|
|||
case (
|
||||
VectorType.MILVUS
|
||||
| VectorType.RELYT
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.TIDB_VECTOR
|
||||
| VectorType.CHROMA
|
||||
| VectorType.TENCENT
|
||||
|
@ -640,6 +641,7 @@ class DatasetRetrievalSettingApi(Resource):
|
|||
| VectorType.ELASTICSEARCH
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.TIDB_ON_QDRANT
|
||||
| VectorType.LINDORM
|
||||
| VectorType.COUCHBASE
|
||||
):
|
||||
return {
|
||||
|
@ -682,6 +684,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
|||
| VectorType.ELASTICSEARCH
|
||||
| VectorType.COUCHBASE
|
||||
| VectorType.PGVECTOR
|
||||
| VectorType.LINDORM
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from flask_restful import fields, marshal_with
|
||||
from flask_restful import marshal_with
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common import fields
|
||||
from controllers.common import helpers as controller_helpers
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
|
@ -11,43 +12,14 @@ from services.app_service import AppService
|
|||
class AppParameterApi(InstalledAppResource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
variable_fields = {
|
||||
"key": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"type": fields.String,
|
||||
"default": fields.String,
|
||||
"max_length": fields.Integer,
|
||||
"options": fields.List(fields.String),
|
||||
}
|
||||
|
||||
system_parameters_fields = {
|
||||
"image_file_size_limit": fields.Integer,
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"file_size_limit": fields.Integer,
|
||||
}
|
||||
|
||||
parameters_fields = {
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"suggested_questions_after_answer": fields.Raw,
|
||||
"speech_to_text": fields.Raw,
|
||||
"text_to_speech": fields.Raw,
|
||||
"retriever_resource": fields.Raw,
|
||||
"annotation_reply": fields.Raw,
|
||||
"more_like_this": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"sensitive_word_avoidance": fields.Raw,
|
||||
"file_upload": fields.Raw,
|
||||
"system_parameters": fields.Nested(system_parameters_fields),
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, installed_app: InstalledApp):
|
||||
"""Retrieve app parameters."""
|
||||
app_model = installed_app.app
|
||||
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
|
@ -57,43 +29,16 @@ class AppParameterApi(InstalledAppResource):
|
|||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return {
|
||||
"opening_statement": features_dict.get("opening_statement"),
|
||||
"suggested_questions": features_dict.get("suggested_questions", []),
|
||||
"suggested_questions_after_answer": features_dict.get(
|
||||
"suggested_questions_after_answer", {"enabled": False}
|
||||
),
|
||||
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
|
||||
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
|
||||
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
|
||||
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
|
||||
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
|
||||
"user_input_form": user_input_form,
|
||||
"sensitive_word_avoidance": features_dict.get(
|
||||
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
|
||||
),
|
||||
"file_upload": features_dict.get(
|
||||
"file_upload",
|
||||
{
|
||||
"image": {
|
||||
"enabled": False,
|
||||
"number_limits": 3,
|
||||
"detail": "high",
|
||||
"transfer_methods": ["remote_url", "local_file"],
|
||||
}
|
||||
},
|
||||
),
|
||||
"system_parameters": {
|
||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||
},
|
||||
}
|
||||
return controller_helpers.get_parameters_from_feature_dict(
|
||||
features_dict=features_dict, user_input_form=user_input_form
|
||||
)
|
||||
|
||||
|
||||
class ExploreAppMetaApi(InstalledAppResource):
|
||||
|
|
|
@ -37,6 +37,7 @@ class FileApi(Resource):
|
|||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
|
||||
}, 200
|
||||
|
||||
@setup_required
|
||||
|
|
|
@ -61,6 +61,19 @@ class ToolBuiltinProviderListToolsApi(Resource):
|
|||
)
|
||||
|
||||
|
||||
class ToolBuiltinProviderInfoApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider):
|
||||
user = current_user
|
||||
|
||||
user_id = user.id
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider))
|
||||
|
||||
|
||||
class ToolBuiltinProviderDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
@ -604,6 +617,7 @@ api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
|
|||
|
||||
# builtin tool provider
|
||||
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
|
||||
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
|
||||
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
|
||||
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
|
||||
api.add_resource(
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from flask_restful import Resource, fields, marshal_with
|
||||
from flask_restful import Resource, marshal_with
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common import fields
|
||||
from controllers.common import helpers as controller_helpers
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
|
@ -11,40 +12,8 @@ from services.app_service import AppService
|
|||
class AppParameterApi(Resource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
variable_fields = {
|
||||
"key": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"type": fields.String,
|
||||
"default": fields.String,
|
||||
"max_length": fields.Integer,
|
||||
"options": fields.List(fields.String),
|
||||
}
|
||||
|
||||
system_parameters_fields = {
|
||||
"image_file_size_limit": fields.Integer,
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"file_size_limit": fields.Integer,
|
||||
}
|
||||
|
||||
parameters_fields = {
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"suggested_questions_after_answer": fields.Raw,
|
||||
"speech_to_text": fields.Raw,
|
||||
"text_to_speech": fields.Raw,
|
||||
"retriever_resource": fields.Raw,
|
||||
"annotation_reply": fields.Raw,
|
||||
"more_like_this": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"sensitive_word_avoidance": fields.Raw,
|
||||
"file_upload": fields.Raw,
|
||||
"system_parameters": fields.Nested(system_parameters_fields),
|
||||
}
|
||||
|
||||
@validate_app_token
|
||||
@marshal_with(parameters_fields)
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, app_model: App):
|
||||
"""Retrieve app parameters."""
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
|
@ -56,43 +25,16 @@ class AppParameterApi(Resource):
|
|||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return {
|
||||
"opening_statement": features_dict.get("opening_statement"),
|
||||
"suggested_questions": features_dict.get("suggested_questions", []),
|
||||
"suggested_questions_after_answer": features_dict.get(
|
||||
"suggested_questions_after_answer", {"enabled": False}
|
||||
),
|
||||
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
|
||||
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
|
||||
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
|
||||
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
|
||||
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
|
||||
"user_input_form": user_input_form,
|
||||
"sensitive_word_avoidance": features_dict.get(
|
||||
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
|
||||
),
|
||||
"file_upload": features_dict.get(
|
||||
"file_upload",
|
||||
{
|
||||
"image": {
|
||||
"enabled": False,
|
||||
"number_limits": 3,
|
||||
"detail": "high",
|
||||
"transfer_methods": ["remote_url", "local_file"],
|
||||
}
|
||||
},
|
||||
),
|
||||
"system_parameters": {
|
||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||
},
|
||||
}
|
||||
return controller_helpers.get_parameters_from_feature_dict(
|
||||
features_dict=features_dict, user_input_form=user_input_form
|
||||
)
|
||||
|
||||
|
||||
class AppMetaApi(Resource):
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from flask_restful import fields, marshal_with
|
||||
from flask_restful import marshal_with
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common import fields
|
||||
from controllers.common import helpers as controller_helpers
|
||||
from controllers.web import api
|
||||
from controllers.web.error import AppUnavailableError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
|
@ -11,39 +12,7 @@ from services.app_service import AppService
|
|||
class AppParameterApi(WebApiResource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
variable_fields = {
|
||||
"key": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"type": fields.String,
|
||||
"default": fields.String,
|
||||
"max_length": fields.Integer,
|
||||
"options": fields.List(fields.String),
|
||||
}
|
||||
|
||||
system_parameters_fields = {
|
||||
"image_file_size_limit": fields.Integer,
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"file_size_limit": fields.Integer,
|
||||
}
|
||||
|
||||
parameters_fields = {
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"suggested_questions_after_answer": fields.Raw,
|
||||
"speech_to_text": fields.Raw,
|
||||
"text_to_speech": fields.Raw,
|
||||
"retriever_resource": fields.Raw,
|
||||
"annotation_reply": fields.Raw,
|
||||
"more_like_this": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"sensitive_word_avoidance": fields.Raw,
|
||||
"file_upload": fields.Raw,
|
||||
"system_parameters": fields.Nested(system_parameters_fields),
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, app_model: App, end_user):
|
||||
"""Retrieve app parameters."""
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
|
@ -55,43 +24,16 @@ class AppParameterApi(WebApiResource):
|
|||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return {
|
||||
"opening_statement": features_dict.get("opening_statement"),
|
||||
"suggested_questions": features_dict.get("suggested_questions", []),
|
||||
"suggested_questions_after_answer": features_dict.get(
|
||||
"suggested_questions_after_answer", {"enabled": False}
|
||||
),
|
||||
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
|
||||
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
|
||||
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
|
||||
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
|
||||
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
|
||||
"user_input_form": user_input_form,
|
||||
"sensitive_word_avoidance": features_dict.get(
|
||||
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
|
||||
),
|
||||
"file_upload": features_dict.get(
|
||||
"file_upload",
|
||||
{
|
||||
"image": {
|
||||
"enabled": False,
|
||||
"number_limits": 3,
|
||||
"detail": "high",
|
||||
"transfer_methods": ["remote_url", "local_file"],
|
||||
}
|
||||
},
|
||||
),
|
||||
"system_parameters": {
|
||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||
},
|
||||
}
|
||||
return controller_helpers.get_parameters_from_feature_dict(
|
||||
features_dict=features_dict, user_input_form=user_input_form
|
||||
)
|
||||
|
||||
|
||||
class AppMeta(WebApiResource):
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import urllib.parse
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import marshal_with, reqparse
|
||||
|
||||
from controllers.common import helpers
|
||||
|
@ -27,7 +26,7 @@ class RemoteFileInfoApi(WebApiResource):
|
|||
|
||||
class RemoteFileUploadApi(WebApiResource):
|
||||
@marshal_with(file_fields_with_signed_url)
|
||||
def post(self):
|
||||
def post(self, app_model, end_user): # Add app_model and end_user parameters
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("url", type=str, required=True, help="URL is required")
|
||||
args = parser.parse_args()
|
||||
|
@ -51,7 +50,7 @@ class RemoteFileUploadApi(WebApiResource):
|
|||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
user=current_user,
|
||||
user=end_user, # Use end_user instead of current_user
|
||||
source_url=url,
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.file.models import FileExtraConfig
|
||||
from models import FileUploadConfig
|
||||
from core.file import FileExtraConfig
|
||||
|
||||
|
||||
class FileUploadConfigManager:
|
||||
|
@ -43,6 +42,6 @@ class FileUploadConfigManager:
|
|||
if not config.get("file_upload"):
|
||||
config["file_upload"] = {}
|
||||
else:
|
||||
FileUploadConfig.model_validate(config["file_upload"])
|
||||
FileExtraConfig.model_validate(config["file_upload"])
|
||||
|
||||
return config, ["file_upload"]
|
||||
|
|
|
@ -20,6 +20,7 @@ from core.app.entities.queue_entities import (
|
|||
QueueIterationStartEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
|
@ -314,7 +315,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeFailedEvent):
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
|
||||
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
|
|
|
@ -23,7 +23,10 @@ class BaseAppGenerator:
|
|||
user_inputs = user_inputs or {}
|
||||
# Filter input variables from form configuration, handle required fields, default values, and option values
|
||||
variables = app_config.variables
|
||||
user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
|
||||
user_inputs = {
|
||||
var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var)
|
||||
for var in variables
|
||||
}
|
||||
user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
|
||||
# Convert files in inputs to File
|
||||
entity_dictionary = {item.variable: item for item in app_config.variables}
|
||||
|
@ -75,50 +78,66 @@ class BaseAppGenerator:
|
|||
|
||||
return user_inputs
|
||||
|
||||
def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"):
|
||||
user_input_value = inputs.get(var.variable)
|
||||
if not user_input_value:
|
||||
if var.required:
|
||||
raise ValueError(f"{var.variable} is required in input form")
|
||||
else:
|
||||
return None
|
||||
def _validate_inputs(
|
||||
self,
|
||||
*,
|
||||
variable_entity: "VariableEntity",
|
||||
value: Any,
|
||||
):
|
||||
if value is None:
|
||||
if variable_entity.required:
|
||||
raise ValueError(f"{variable_entity.variable} is required in input form")
|
||||
return value
|
||||
|
||||
if var.type in {
|
||||
if variable_entity.type in {
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.SELECT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
} and not isinstance(user_input_value, str):
|
||||
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
|
||||
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
|
||||
} and not isinstance(value, str):
|
||||
raise ValueError(
|
||||
f"(type '{variable_entity.type}') {variable_entity.variable} in input form must be a string"
|
||||
)
|
||||
|
||||
if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str):
|
||||
# may raise ValueError if user_input_value is not a valid number
|
||||
try:
|
||||
if "." in user_input_value:
|
||||
return float(user_input_value)
|
||||
if "." in value:
|
||||
return float(value)
|
||||
else:
|
||||
return int(user_input_value)
|
||||
return int(value)
|
||||
except ValueError:
|
||||
raise ValueError(f"{var.variable} in input form must be a valid number")
|
||||
if var.type == VariableEntityType.SELECT:
|
||||
options = var.options
|
||||
if user_input_value not in options:
|
||||
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
|
||||
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
|
||||
if var.max_length and len(user_input_value) > var.max_length:
|
||||
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
|
||||
elif var.type == VariableEntityType.FILE:
|
||||
if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
|
||||
raise ValueError(f"{var.variable} in input form must be a file")
|
||||
elif var.type == VariableEntityType.FILE_LIST:
|
||||
if not (
|
||||
isinstance(user_input_value, list)
|
||||
and (
|
||||
all(isinstance(item, dict) for item in user_input_value)
|
||||
or all(isinstance(item, File) for item in user_input_value)
|
||||
)
|
||||
):
|
||||
raise ValueError(f"{var.variable} in input form must be a list of files")
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
|
||||
|
||||
return user_input_value
|
||||
match variable_entity.type:
|
||||
case VariableEntityType.SELECT:
|
||||
if value not in variable_entity.options:
|
||||
raise ValueError(
|
||||
f"{variable_entity.variable} in input form must be one of the following: "
|
||||
f"{variable_entity.options}"
|
||||
)
|
||||
case VariableEntityType.TEXT_INPUT | VariableEntityType.PARAGRAPH:
|
||||
if variable_entity.max_length and len(value) > variable_entity.max_length:
|
||||
raise ValueError(
|
||||
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} "
|
||||
"characters"
|
||||
)
|
||||
case VariableEntityType.FILE:
|
||||
if not isinstance(value, dict) and not isinstance(value, File):
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a file")
|
||||
case VariableEntityType.FILE_LIST:
|
||||
# if number of files exceeds the limit, raise ValueError
|
||||
if not (
|
||||
isinstance(value, list)
|
||||
and (all(isinstance(item, dict) for item in value) or all(isinstance(item, File) for item in value))
|
||||
):
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a list of files")
|
||||
|
||||
if variable_entity.max_length and len(value) > variable_entity.max_length:
|
||||
raise ValueError(
|
||||
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files"
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
def _sanitize_value(self, value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
|
|
|
@ -16,6 +16,7 @@ from core.app.entities.queue_entities import (
|
|||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
|
@ -275,7 +276,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeFailedEvent):
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
|
||||
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
|
|
|
@ -9,6 +9,7 @@ from core.app.entities.queue_entities import (
|
|||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
|
@ -30,6 +31,7 @@ from core.workflow.graph_engine.entities.event import (
|
|||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
NodeInIterationFailedEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunStartedEvent,
|
||||
|
@ -193,6 +195,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
node_run_index=event.route_node_state.index,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
|
@ -246,9 +249,40 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
error=event.route_node_state.node_run_result.error
|
||||
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
|
||||
else "Unknown error",
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeInIterationFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeInIterationFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
self._publish_event(
|
||||
QueueTextChunkEvent(
|
||||
|
@ -326,6 +360,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
index=event.index,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
output=event.pre_iteration_output,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
|
||||
|
|
|
@ -107,7 +107,8 @@ class QueueIterationNextEvent(AppQueueEvent):
|
|||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
node_run_index: int
|
||||
output: Optional[Any] = None # output for the current iteration
|
||||
|
||||
|
@ -273,6 +274,8 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
|||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
start_at: datetime
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
|
||||
|
||||
class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
|
@ -306,6 +309,37 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
|||
error: Optional[str] = None
|
||||
|
||||
|
||||
class QueueNodeInIterationFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeInIterationFailedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.NODE_FAILED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
process_data: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
||||
class QueueNodeFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeFailedEvent entity
|
||||
|
@ -332,6 +366,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
|||
inputs: Optional[dict[str, Any]] = None
|
||||
process_data: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
|
|
@ -244,6 +244,7 @@ class NodeStartStreamResponse(StreamResponse):
|
|||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
parallel_run_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_STARTED
|
||||
workflow_run_id: str
|
||||
|
@ -432,6 +433,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
|
|||
extras: dict = {}
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_NEXT
|
||||
workflow_run_id: str
|
||||
|
|
|
@ -12,6 +12,7 @@ from core.app.entities.queue_entities import (
|
|||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
|
@ -35,6 +36,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
|
@ -251,6 +253,12 @@ class WorkflowCycleManage:
|
|||
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
|
||||
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
||||
workflow_node_execution.created_by = workflow_run.created_by
|
||||
workflow_node_execution.execution_metadata = json.dumps(
|
||||
{
|
||||
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
||||
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
}
|
||||
)
|
||||
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
session.add(workflow_node_execution)
|
||||
|
@ -305,7 +313,9 @@ class WorkflowCycleManage:
|
|||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
|
||||
def _handle_workflow_node_execution_failed(
|
||||
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
:param event: queue node failed event
|
||||
|
@ -318,16 +328,19 @@ class WorkflowCycleManage:
|
|||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
|
||||
execution_metadata = (
|
||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
||||
)
|
||||
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
|
||||
{
|
||||
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
|
||||
WorkflowNodeExecution.error: event.error,
|
||||
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
|
||||
WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
|
||||
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
|
||||
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
|
||||
WorkflowNodeExecution.finished_at: finished_at,
|
||||
WorkflowNodeExecution.elapsed_time: elapsed_time,
|
||||
WorkflowNodeExecution.execution_metadata: execution_metadata,
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -342,6 +355,7 @@ class WorkflowCycleManage:
|
|||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
workflow_node_execution.execution_metadata = execution_metadata
|
||||
|
||||
self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
|
||||
|
||||
|
@ -448,6 +462,7 @@ class WorkflowCycleManage:
|
|||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
parallel_run_id=event.parallel_mode_run_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -464,7 +479,7 @@ class WorkflowCycleManage:
|
|||
|
||||
def _workflow_node_finish_to_stream_response(
|
||||
self,
|
||||
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
|
||||
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeFinishStreamResponse]:
|
||||
|
@ -608,6 +623,7 @@ class WorkflowCycleManage:
|
|||
extras={},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -633,7 +649,9 @@ class WorkflowCycleManage:
|
|||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
if event.error is None
|
||||
else WorkflowNodeExecutionStatus.FAILED,
|
||||
error=None,
|
||||
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
|
||||
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
|
||||
|
|
|
@ -12,7 +12,8 @@ class CommonParameterType(Enum):
|
|||
SYSTEM_FILES = "system-files"
|
||||
BOOLEAN = "boolean"
|
||||
APP_SELECTOR = "app-selector"
|
||||
MODEL_CONFIG = "model-config"
|
||||
TOOL_SELECTOR = "tool-selector"
|
||||
MODEL_SELECTOR = "model-selector"
|
||||
|
||||
|
||||
class AppSelectorScope(Enum):
|
||||
|
@ -22,7 +23,7 @@ class AppSelectorScope(Enum):
|
|||
COMPLETION = "completion"
|
||||
|
||||
|
||||
class ModelConfigScope(Enum):
|
||||
class ModelSelectorScope(Enum):
|
||||
LLM = "llm"
|
||||
TEXT_EMBEDDING = "text-embedding"
|
||||
RERANK = "rerank"
|
||||
|
@ -30,3 +31,10 @@ class ModelConfigScope(Enum):
|
|||
SPEECH2TEXT = "speech2text"
|
||||
MODERATION = "moderation"
|
||||
VISION = "vision"
|
||||
|
||||
|
||||
class ToolSelectorScope(Enum):
|
||||
ALL = "all"
|
||||
CUSTOM = "custom"
|
||||
BUILTIN = "builtin"
|
||||
WORKFLOW = "workflow"
|
||||
|
|
|
@ -3,7 +3,12 @@ from typing import Optional, Union
|
|||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
|
||||
from core.entities.parameter_entities import (
|
||||
AppSelectorScope,
|
||||
CommonParameterType,
|
||||
ModelSelectorScope,
|
||||
ToolSelectorScope,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
@ -140,7 +145,8 @@ class BasicProviderConfig(BaseModel):
|
|||
SELECT = CommonParameterType.SELECT.value
|
||||
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
MODEL_CONFIG = CommonParameterType.MODEL_CONFIG.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
TOOL_SELECTOR = CommonParameterType.TOOL_SELECTOR.value
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "ProviderConfig.Type":
|
||||
|
@ -168,7 +174,7 @@ class ProviderConfig(BasicProviderConfig):
|
|||
value: str = Field(..., description="The value of the option")
|
||||
label: I18nObject = Field(..., description="The label of the option")
|
||||
|
||||
scope: AppSelectorScope | ModelConfigScope | None = None
|
||||
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
|
||||
required: bool = False
|
||||
default: Optional[Union[int, str]] = None
|
||||
options: Optional[list[Option]] = None
|
||||
|
|
|
@ -598,7 +598,7 @@ class IndexingRunner:
|
|||
rules = DatasetProcessRule.AUTOMATIC_RULES
|
||||
else:
|
||||
rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
|
||||
document_text = CleanProcessor.clean(text, rules)
|
||||
document_text = CleanProcessor.clean(text, {"rules": rules})
|
||||
|
||||
return document_text
|
||||
|
||||
|
|
|
@ -10,8 +10,15 @@ from core.model_runtime.entities.model_entities import (
|
|||
PriceInfo,
|
||||
PriceType,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
|
||||
from core.plugin.manager.model import PluginModelManager
|
||||
|
||||
|
||||
|
@ -31,7 +38,7 @@ class AIModel(BaseModel):
|
|||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
|
@ -40,9 +47,17 @@ class AIModel(BaseModel):
|
|||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
raise NotImplementedError
|
||||
return {
|
||||
InvokeConnectionError: [InvokeConnectionError],
|
||||
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||
InvokeRateLimitError: [InvokeRateLimitError],
|
||||
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||
InvokeBadRequestError: [InvokeBadRequestError],
|
||||
PluginDaemonInnerError: [PluginDaemonInnerError],
|
||||
ValueError: [ValueError],
|
||||
}
|
||||
|
||||
def _transform_invoke_error(self, error: Exception) -> InvokeError:
|
||||
def _transform_invoke_error(self, error: Exception) -> Exception:
|
||||
"""
|
||||
Transform invoke error to unified error
|
||||
|
||||
|
@ -52,13 +67,15 @@ class AIModel(BaseModel):
|
|||
for invoke_error, model_errors in self._invoke_error_mapping.items():
|
||||
if isinstance(error, tuple(model_errors)):
|
||||
if invoke_error == InvokeAuthorizationError:
|
||||
return invoke_error(
|
||||
return InvokeAuthorizationError(
|
||||
description=(
|
||||
f"[{self.provider_name}] Incorrect model credentials provided, please check and try again."
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(invoke_error, InvokeError):
|
||||
return invoke_error(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}")
|
||||
else:
|
||||
return error
|
||||
|
||||
return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}")
|
||||
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
model: anthropic.claude-3-5-haiku-20241022-v1:0
|
||||
label:
|
||||
en_US: Claude 3.5 Haiku
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
type: int
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
help:
|
||||
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
|
||||
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
|
||||
# docs: https://docs.anthropic.com/claude/docs/system-prompts
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.001'
|
||||
output: '0.005'
|
||||
unit: '0.001'
|
||||
currency: USD
|
|
@ -0,0 +1,61 @@
|
|||
model: us.anthropic.claude-3-5-haiku-20241022-v1:0
|
||||
label:
|
||||
en_US: Claude 3.5 Haiku(US.Cross Region Inference)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
type: int
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
|
||||
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
|
||||
# docs: https://docs.anthropic.com/claude/docs/system-prompts
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.001'
|
||||
output: '0.005'
|
||||
unit: '0.001'
|
||||
currency: USD
|
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
import requests
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
|
@ -16,8 +17,18 @@ class GiteeAIProvider(ModelProvider):
|
|||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
model_instance.validate_credentials(model="Qwen2-7B-Instruct", credentials=credentials)
|
||||
api_key = credentials.get("api_key")
|
||||
if not api_key:
|
||||
raise CredentialsValidateFailedError("Credentials validation failed: api_key not given")
|
||||
|
||||
# send a get request to validate the credentials
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
response = requests.get("https://ai.gitee.com/api/base/account/me", headers=headers, timeout=(10, 300))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise CredentialsValidateFailedError(
|
||||
f"Credentials validation failed with status code {response.status_code}"
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 277 KiB |
|
@ -0,0 +1,15 @@
|
|||
<svg width="68" height="24" viewBox="0 0 68 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g id="Gemini">
|
||||
<path id="Union" fill-rule="evenodd" clip-rule="evenodd" d="M50.6875 4.37014C48.3498 4.59292 46.5349 6.41319 46.3337 8.72764C46.1446 6.44662 44.2677 4.56074 41.9805 4.3737C44.2762 4.1997 46.152 2.28299 46.3373 0C46.4882 2.28911 48.405 4.20047 50.6875 4.37014ZM15.4567 9.41141L13.9579 10.9076C9.92941 6.64892 2.69298 9.97287 3.17317 15.8112C3.22394 23.108 14.5012 24.4317 15.3628 16.8809H9.52096L9.50061 14.9149H17.3595C18.8163 23.1364 8.44367 27.0292 3.19453 21.238C0.847044 18.7556 0.363651 14.7682 1.83717 11.7212C4.1129 6.62089 11.6505 5.29845 15.4567 9.41141ZM45.5915 23.5989H47.6945C47.6944 22.9155 47.6945 22.2307 47.6946 21.5452V21.5325C47.6948 19.8907 47.695 18.2453 47.6924 16.6072C47.6914 15.9407 47.6161 15.2823 47.4024 14.647C46.4188 11.2828 41.4255 11.4067 39.8332 14.214C38.5637 11.4171 34.4009 11.5236 32.8538 14.0084L32.8082 13.9976V12.4806L32.4233 12.4804H32.4224C31.8687 12.4801 31.3324 12.4798 30.7949 12.4811V23.5848L32.8977 23.5672C32.8981 22.9411 32.8979 22.3122 32.8977 21.6822V21.6812V21.6802V21.6791V21.6781V21.6771V21.676V21.676V21.6759V21.6758V21.6757V21.6757V21.6756C32.8973 20.204 32.8969 18.7261 32.904 17.2614C32.8889 15.3646 34.5674 13.5687 36.5358 14.124C37.7794 14.3298 38.1851 15.6148 38.1761 16.7257C38.1821 17.7019 38.18 18.6824 38.178 19.6633V19.6638C38.1752 20.9756 38.1724 22.2881 38.1891 23.5919L40.2846 23.5731C40.2929 22.7511 40.2881 21.9245 40.2832 21.0966C40.2753 19.7402 40.2674 18.3805 40.317 17.0328C40.4418 15.2122 42.0141 13.6186 43.9064 14.1168C45.2685 14.3231 45.6136 15.7748 45.5882 16.9545C45.5938 18.4959 45.5929 20.0492 45.5921 21.5968V21.5991V21.6014V21.6037V21.606V21.6083V21.6106C45.5917 22.2749 45.5913 22.9382 45.5915 23.5989ZM20.6167 18.4408C20.5625 21.9486 25.2121 23.6996 27.2993 20.0558L29.1566 20.9592C27.8157 23.7067 24.2337 24.7424 21.5381 23.4213C18.0052 21.7253 17.41 16.5007 20.0334 13.7517C21.4609 12.1752 23.7291 11.7901 25.7206 12.3653C28.3408 13.1257 29.4974 15.8937 29.326 18.4399C27.5547 18.4415 25.7971 18.4412 24.0364 18.4409C22.8993 18.4407 21.7609 18.4405 20.6167 18.4408ZM27.1041 16.6957C26.7048 13.1033 21.2867 13.2256 20.7494 16.6957H27.1041ZM53.543 23.5999H55.6206L55.6206 22.4361C55.6205 20.7877 55.6205 19.1443 55.6207 17.4939C55.6208 16.8853 55.7234 16.297 56.0063 15.7531C56.6115 14.3862 58.1745 13.7002 59.5927 14.1774C60.7512 14.4455 61.2852 15.6069 61.2762 16.7154C61.2774 18.3497 61.2771 19.9826 61.2769 21.6162V21.6166V21.617V21.6174V21.6179L61.2766 23.6007H63.3698C63.3913 22.0924 63.3869 20.584 63.3826 19.0755V19.0754V19.0753V19.0753V19.0752C63.3799 18.1682 63.3773 17.2612 63.3803 16.3541C63.3796 15.8622 63.3103 15.3765 63.1698 14.9052C62.3248 11.5142 57.3558 11.2385 55.5828 14.0038L55.5336 13.9905V12.4917H53.539C53.4898 12.7313 53.4934 23.4113 53.543 23.5999ZM49.6211 12.4944H51.7065V23.5994H49.6211V12.4944ZM65.1035 23.5991H67.1831C67.2367 23.2198 67.2133 12.6566 67.1634 12.4983H65.1035V23.5991ZM52.1504 8.67829C52.1709 10.4847 49.2418 10.7058 49.1816 8.65714C49.2189 6.5948 52.2437 6.81331 52.1504 8.67829ZM66.1387 10.1324C64.2712 10.1609 64.1316 7.19881 66.1559 7.17114C68.1709 7.19817 68.0215 10.2087 66.1387 10.1324Z" fill="url(#paint0_linear_14286_118464)"/>
|
||||
</g>
|
||||
<defs>
|
||||
<linearGradient id="paint0_linear_14286_118464" x1="-2" y1="0.999998" x2="67.9999" y2="27.5002" gradientUnits="userSpaceOnUse">
|
||||
<stop stop-color="#7798E0"/>
|
||||
<stop offset="0.210002" stop-color="#086FFF"/>
|
||||
<stop offset="0.345945" stop-color="#086FFF"/>
|
||||
<stop offset="0.591777" stop-color="#479AFF"/>
|
||||
<stop offset="0.895892" stop-color="#B7C4FA"/>
|
||||
<stop offset="1" stop-color="#B5C5F9"/>
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
After Width: | Height: | Size: 3.6 KiB |
Binary file not shown.
After Width: | Height: | Size: 57 KiB |
|
@ -0,0 +1,11 @@
|
|||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect width="24" height="24" rx="6" fill="url(#paint0_linear_7301_16076)"/>
|
||||
<path d="M20 12.0116C15.7043 12.42 12.3692 15.757 11.9995 20C11.652 15.8183 8.20301 12.361 4 12.0181C8.21855 11.6991 11.6656 8.1853 12.006 4C12.2833 8.19653 15.8057 11.7005 20 12.0116Z" fill="white" fill-opacity="0.88"/>
|
||||
<defs>
|
||||
<linearGradient id="paint0_linear_7301_16076" x1="-9" y1="29.5" x2="19.4387" y2="1.43791" gradientUnits="userSpaceOnUse">
|
||||
<stop offset="0.192878" stop-color="#1C7DFF"/>
|
||||
<stop offset="0.520213" stop-color="#1C69FF"/>
|
||||
<stop offset="1" stop-color="#F0DCD6"/>
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
After Width: | Height: | Size: 689 B |
10
api/core/model_runtime/model_providers/gpustack/gpustack.py
Normal file
10
api/core/model_runtime/model_providers/gpustack/gpustack.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
import logging
|
||||
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GPUStackProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
pass
|
120
api/core/model_runtime/model_providers/gpustack/gpustack.yaml
Normal file
120
api/core/model_runtime/model_providers/gpustack/gpustack.yaml
Normal file
|
@ -0,0 +1,120 @@
|
|||
provider: gpustack
|
||||
label:
|
||||
en_US: GPUStack
|
||||
icon_small:
|
||||
en_US: icon_s_en.png
|
||||
icon_large:
|
||||
en_US: icon_l_en.png
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
- rerank
|
||||
configurate_methods:
|
||||
- customizable-model
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Model Name
|
||||
zh_Hans: 模型名称
|
||||
placeholder:
|
||||
en_US: Enter your model name
|
||||
zh_Hans: 输入模型名称
|
||||
credential_form_schemas:
|
||||
- variable: endpoint_url
|
||||
label:
|
||||
zh_Hans: 服务器地址
|
||||
en_US: Server URL
|
||||
type: text-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 输入 GPUStack 的服务器地址,如 http://192.168.1.100
|
||||
en_US: Enter the GPUStack server URL, e.g. http://192.168.1.100
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: mode
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
label:
|
||||
en_US: Completion mode
|
||||
type: select
|
||||
required: false
|
||||
default: chat
|
||||
placeholder:
|
||||
zh_Hans: 选择补全类型
|
||||
en_US: Select completion type
|
||||
options:
|
||||
- value: completion
|
||||
label:
|
||||
en_US: Completion
|
||||
zh_Hans: 补全
|
||||
- value: chat
|
||||
label:
|
||||
en_US: Chat
|
||||
zh_Hans: 对话
|
||||
- variable: context_size
|
||||
label:
|
||||
zh_Hans: 模型上下文长度
|
||||
en_US: Model context size
|
||||
required: true
|
||||
type: text-input
|
||||
default: "8192"
|
||||
placeholder:
|
||||
zh_Hans: 输入您的模型上下文长度
|
||||
en_US: Enter your Model context size
|
||||
- variable: max_tokens_to_sample
|
||||
label:
|
||||
zh_Hans: 最大 token 上限
|
||||
en_US: Upper bound for max tokens
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
default: "8192"
|
||||
type: text-input
|
||||
- variable: function_calling_type
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
label:
|
||||
en_US: Function calling
|
||||
type: select
|
||||
required: false
|
||||
default: no_call
|
||||
options:
|
||||
- value: function_call
|
||||
label:
|
||||
en_US: Function Call
|
||||
zh_Hans: Function Call
|
||||
- value: tool_call
|
||||
label:
|
||||
en_US: Tool Call
|
||||
zh_Hans: Tool Call
|
||||
- value: no_call
|
||||
label:
|
||||
en_US: Not Support
|
||||
zh_Hans: 不支持
|
||||
- variable: vision_support
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
label:
|
||||
zh_Hans: Vision 支持
|
||||
en_US: Vision Support
|
||||
type: select
|
||||
required: false
|
||||
default: no_support
|
||||
options:
|
||||
- value: support
|
||||
label:
|
||||
en_US: Support
|
||||
zh_Hans: 支持
|
||||
- value: no_support
|
||||
label:
|
||||
en_US: Not Support
|
||||
zh_Hans: 不支持
|
45
api/core/model_runtime/model_providers/gpustack/llm/llm.py
Normal file
45
api/core/model_runtime/model_providers/gpustack/llm/llm.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
from collections.abc import Generator
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import (
|
||||
OAIAPICompatLargeLanguageModel,
|
||||
)
|
||||
|
||||
|
||||
class GPUStackLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
return super()._invoke(
|
||||
model,
|
||||
credentials,
|
||||
prompt_messages,
|
||||
model_parameters,
|
||||
tools,
|
||||
stop,
|
||||
stream,
|
||||
user,
|
||||
)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials: dict) -> None:
|
||||
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
|
||||
credentials["mode"] = "chat"
|
146
api/core/model_runtime/model_providers/gpustack/rerank/rerank.py
Normal file
146
api/core/model_runtime/model_providers/gpustack/rerank/rerank.py
Normal file
|
@ -0,0 +1,146 @@
|
|||
from json import dumps
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from requests import post
|
||||
from yarl import URL
|
||||
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
)
|
||||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||
|
||||
|
||||
class GPUStackRerankModel(RerankModel):
|
||||
"""
|
||||
Model class for GPUStack rerank model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param query: search query
|
||||
:param docs: docs for reranking
|
||||
:param score_threshold: score threshold
|
||||
:param top_n: top n documents to return
|
||||
:param user: unique user id
|
||||
:return: rerank result
|
||||
"""
|
||||
if len(docs) == 0:
|
||||
return RerankResult(model=model, docs=[])
|
||||
|
||||
endpoint_url = credentials["endpoint_url"]
|
||||
headers = {
|
||||
"Authorization": f"Bearer {credentials.get('api_key')}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
data = {"model": model, "query": query, "documents": docs, "top_n": top_n}
|
||||
|
||||
try:
|
||||
response = post(
|
||||
str(URL(endpoint_url) / "v1" / "rerank"),
|
||||
headers=headers,
|
||||
data=dumps(data),
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
|
||||
rerank_documents = []
|
||||
for result in results["results"]:
|
||||
index = result["index"]
|
||||
if "document" in result:
|
||||
text = result["document"]["text"]
|
||||
else:
|
||||
text = docs[index]
|
||||
|
||||
rerank_document = RerankDocument(
|
||||
index=index,
|
||||
text=text,
|
||||
score=result["relevance_score"],
|
||||
)
|
||||
|
||||
if score_threshold is None or result["relevance_score"] >= score_threshold:
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
return RerankResult(model=model, docs=rerank_documents)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise InvokeServerUnavailableError(str(e))
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
self._invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
query="What is the capital of the United States?",
|
||||
docs=[
|
||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||
"Census, Carson City had a population of 55,274.",
|
||||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||
],
|
||||
score_threshold=0.8,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [httpx.ConnectError],
|
||||
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [httpx.HTTPStatusError],
|
||||
InvokeBadRequestError: [httpx.RequestError],
|
||||
}
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
generate custom model entities from credentials
|
||||
"""
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(en_US=model),
|
||||
model_type=ModelType.RERANK,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
|
||||
)
|
||||
|
||||
return entity
|
|
@ -0,0 +1,35 @@
|
|||
from typing import Optional
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
from core.model_runtime.entities.text_embedding_entities import (
|
||||
TextEmbeddingResult,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
|
||||
OAICompatEmbeddingModel,
|
||||
)
|
||||
|
||||
|
||||
class GPUStackTextEmbeddingModel(OAICompatEmbeddingModel):
|
||||
"""
|
||||
Model class for GPUStack text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
return super()._invoke(model, credentials, texts, user, input_type)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials: dict) -> None:
|
||||
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"]) / "v1-openai")
|
|
@ -0,0 +1 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" fill="currentColor" viewBox="0 0 24 24" aria-hidden="true" class="" focusable="false" style="fill:currentColor;height:28px;width:28px"><path d="m3.005 8.858 8.783 12.544h3.904L6.908 8.858zM6.905 15.825 3 21.402h3.907l1.951-2.788zM16.585 2l-6.75 9.64 1.953 2.79L20.492 2zM17.292 7.965v13.437h3.2V3.395z"></path></svg>
|
After Width: | Height: | Size: 356 B |
63
api/core/model_runtime/model_providers/x/llm/grok-beta.yaml
Normal file
63
api/core/model_runtime/model_providers/x/llm/grok-beta.yaml
Normal file
|
@ -0,0 +1,63 @@
|
|||
model: grok-beta
|
||||
label:
|
||||
en_US: Grok beta
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
label:
|
||||
en_US: "Temperature"
|
||||
zh_Hans: "采样温度"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_p
|
||||
label:
|
||||
en_US: "Top P"
|
||||
zh_Hans: "Top P"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
label:
|
||||
en_US: "Frequency Penalty"
|
||||
zh_Hans: "频率惩罚"
|
||||
type: float
|
||||
default: 0
|
||||
min: 0
|
||||
max: 2.0
|
||||
precision: 1
|
||||
required: false
|
||||
help:
|
||||
en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."
|
||||
zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"
|
||||
|
||||
- name: user
|
||||
use_template: text
|
||||
label:
|
||||
en_US: "User"
|
||||
zh_Hans: "用户"
|
||||
type: string
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to track and differentiate conversation requests from different users."
|
||||
zh_Hans: "用于追踪和区分不同用户的对话请求。"
|
37
api/core/model_runtime/model_providers/x/llm/llm.py
Normal file
37
api/core/model_runtime/model_providers/x/llm/llm.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
||||
|
||||
|
||||
class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials)
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials) -> None:
|
||||
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1"
|
||||
credentials["mode"] = LLMMode.CHAT.value
|
||||
credentials["function_calling_type"] = "tool_call"
|
25
api/core/model_runtime/model_providers/x/x.py
Normal file
25
api/core/model_runtime/model_providers/x/x.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class XAIProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
model_instance.validate_credentials(model="grok-beta", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
38
api/core/model_runtime/model_providers/x/x.yaml
Normal file
38
api/core/model_runtime/model_providers/x/x.yaml
Normal file
|
@ -0,0 +1,38 @@
|
|||
provider: x
|
||||
label:
|
||||
en_US: xAI
|
||||
description:
|
||||
en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe.
|
||||
icon_small:
|
||||
en_US: x-ai-logo.svg
|
||||
icon_large:
|
||||
en_US: x-ai-logo.svg
|
||||
help:
|
||||
title:
|
||||
en_US: Get your token from xAI
|
||||
zh_Hans: 从 xAI 获取 token
|
||||
url:
|
||||
en_US: https://x.ai/api
|
||||
supported_model_types:
|
||||
- llm
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: endpoint_url
|
||||
label:
|
||||
en_US: API Base
|
||||
type: text-input
|
||||
required: false
|
||||
default: https://api.x.ai/v1
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Base
|
||||
en_US: Enter your API Base
|
|
@ -53,7 +53,7 @@ class BasePluginManager:
|
|||
)
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
logger.exception(f"Request to Plugin Daemon Service failed: {e}")
|
||||
raise ValueError("Request to Plugin Daemon Service failed")
|
||||
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
|
||||
|
||||
return response
|
||||
|
||||
|
@ -157,8 +157,17 @@ class BasePluginManager:
|
|||
Make a stream request to the plugin daemon inner API and yield the response as a model.
|
||||
"""
|
||||
for line in self._stream_request(method, path, params, headers, data, files):
|
||||
line_data = None
|
||||
try:
|
||||
line_data = json.loads(line)
|
||||
rep = PluginDaemonBasicResponse[type](**line_data)
|
||||
except Exception as e:
|
||||
# TODO modify this when line_data has code and message
|
||||
if line_data and "error" in line_data:
|
||||
raise ValueError(line_data["error"])
|
||||
else:
|
||||
raise ValueError(line)
|
||||
|
||||
if rep.code != 0:
|
||||
if rep.code == -500:
|
||||
try:
|
||||
|
|
|
@ -437,6 +437,7 @@ class PluginModelManager(BasePluginManager):
|
|||
voices = []
|
||||
for voice in resp.voices:
|
||||
voices.append({"name": voice.name, "value": voice.value})
|
||||
|
||||
return voices
|
||||
|
||||
return []
|
||||
|
|
|
@ -103,7 +103,7 @@ class RetrievalService:
|
|||
|
||||
if exceptions:
|
||||
exception_message = ";\n".join(exceptions)
|
||||
raise Exception(exception_message)
|
||||
raise ValueError(exception_message)
|
||||
|
||||
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
|
||||
data_post_processor = DataPostProcessor(
|
||||
|
|
0
api/core/rag/datasource/vdb/lindorm/__init__.py
Normal file
0
api/core/rag/datasource/vdb/lindorm/__init__.py
Normal file
498
api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
Normal file
498
api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
Normal file
|
@ -0,0 +1,498 @@
|
|||
import copy
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Optional
|
||||
|
||||
from opensearchpy import OpenSearch
|
||||
from opensearchpy.helpers import bulk
|
||||
from pydantic import BaseModel, model_validator
|
||||
from tenacity import retry, stop_after_attempt, wait_fixed
|
||||
|
||||
from configs import dify_config
|
||||
from core.rag.datasource.vdb.field import Field
|
||||
from core.rag.datasource.vdb.vector_base import BaseVector
|
||||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.embedding.embedding_base import Embeddings
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
logging.getLogger("lindorm").setLevel(logging.WARN)
|
||||
|
||||
|
||||
class LindormVectorStoreConfig(BaseModel):
|
||||
hosts: str
|
||||
username: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["hosts"]:
|
||||
raise ValueError("config URL is required")
|
||||
if not values["username"]:
|
||||
raise ValueError("config USERNAME is required")
|
||||
if not values["password"]:
|
||||
raise ValueError("config PASSWORD is required")
|
||||
return values
|
||||
|
||||
def to_opensearch_params(self) -> dict[str, Any]:
|
||||
params = {
|
||||
"hosts": self.hosts,
|
||||
}
|
||||
if self.username and self.password:
|
||||
params["http_auth"] = (self.username, self.password)
|
||||
return params
|
||||
|
||||
|
||||
class LindormVectorStore(BaseVector):
|
||||
def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
|
||||
super().__init__(collection_name.lower())
|
||||
self._client_config = config
|
||||
self._client = OpenSearch(**config.to_opensearch_params())
|
||||
self.kwargs = kwargs
|
||||
|
||||
def get_type(self) -> str:
|
||||
return VectorType.LINDORM
|
||||
|
||||
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
self.create_collection(len(embeddings[0]), **kwargs)
|
||||
self.add_texts(texts, embeddings)
|
||||
|
||||
def refresh(self):
|
||||
self._client.indices.refresh(index=self._collection_name)
|
||||
|
||||
def __filter_existed_ids(
|
||||
self,
|
||||
texts: list[str],
|
||||
metadatas: list[dict],
|
||||
ids: list[str],
|
||||
bulk_size: int = 1024,
|
||||
) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]:
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
|
||||
def __fetch_existing_ids(batch_ids: list[str]) -> set[str]:
|
||||
try:
|
||||
existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False)
|
||||
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching batch {batch_ids}: {e}")
|
||||
return set()
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
|
||||
def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]:
|
||||
try:
|
||||
existing_docs = self._client.mget(
|
||||
body={
|
||||
"docs": [
|
||||
{"_index": self._collection_name, "_id": id, "routing": routing}
|
||||
for id, routing in zip(batch_ids, route_ids)
|
||||
]
|
||||
},
|
||||
_source=False,
|
||||
)
|
||||
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching batch {batch_ids}: {e}")
|
||||
return set()
|
||||
|
||||
if ids is None:
|
||||
return texts, metadatas, ids
|
||||
|
||||
if len(texts) != len(ids):
|
||||
raise RuntimeError(f"texts {len(texts)} != {ids}")
|
||||
|
||||
filtered_texts = []
|
||||
filtered_metadatas = []
|
||||
filtered_ids = []
|
||||
|
||||
def batch(iterable, n):
|
||||
length = len(iterable)
|
||||
for idx in range(0, length, n):
|
||||
yield iterable[idx : min(idx + n, length)]
|
||||
|
||||
for ids_batch, texts_batch, metadatas_batch in zip(
|
||||
batch(ids, bulk_size),
|
||||
batch(texts, bulk_size),
|
||||
batch(metadatas, bulk_size) if metadatas is not None else batch([None] * len(ids), bulk_size),
|
||||
):
|
||||
existing_ids_set = __fetch_existing_ids(ids_batch)
|
||||
for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch):
|
||||
if doc_id not in existing_ids_set:
|
||||
filtered_texts.append(text)
|
||||
filtered_ids.append(doc_id)
|
||||
if metadatas is not None:
|
||||
filtered_metadatas.append(metadata)
|
||||
|
||||
return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
actions = []
|
||||
uuids = self._get_uuids(documents)
|
||||
for i in range(len(documents)):
|
||||
action = {
|
||||
"_op_type": "index",
|
||||
"_index": self._collection_name.lower(),
|
||||
"_id": uuids[i],
|
||||
"_source": {
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
},
|
||||
}
|
||||
actions.append(action)
|
||||
bulk(self._client, actions)
|
||||
self.refresh()
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}}
|
||||
response = self._client.search(index=self._collection_name, body=query)
|
||||
if response["hits"]["hits"]:
|
||||
return [hit["_id"] for hit in response["hits"]["hits"]]
|
||||
else:
|
||||
return None
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
|
||||
results = self._client.search(index=self._collection_name, body=query_str)
|
||||
ids = [hit["_id"] for hit in results["hits"]["hits"]]
|
||||
if ids:
|
||||
self.delete_by_ids(ids)
|
||||
|
||||
def delete_by_ids(self, ids: list[str]) -> None:
|
||||
for id in ids:
|
||||
if self._client.exists(index=self._collection_name, id=id):
|
||||
self._client.delete(index=self._collection_name, id=id)
|
||||
else:
|
||||
logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
|
||||
|
||||
def delete(self) -> None:
|
||||
try:
|
||||
if self._client.indices.exists(index=self._collection_name):
|
||||
self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
|
||||
logger.info("Delete index success")
|
||||
else:
|
||||
logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while deleting the index: {e}")
|
||||
raise e
|
||||
|
||||
def text_exists(self, id: str) -> bool:
|
||||
try:
|
||||
self._client.get(index=self._collection_name, id=id)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
|
||||
# Make sure query_vector is a list
|
||||
if not isinstance(query_vector, list):
|
||||
raise ValueError("query_vector should be a list of floats")
|
||||
|
||||
# Check whether query_vector is a floating-point number list
|
||||
if not all(isinstance(x, float) for x in query_vector):
|
||||
raise ValueError("All elements in query_vector should be floats")
|
||||
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
|
||||
try:
|
||||
response = self._client.search(index=self._collection_name, body=query)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing search: {e}")
|
||||
raise
|
||||
|
||||
docs_and_scores = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
docs_and_scores.append(
|
||||
(
|
||||
Document(
|
||||
page_content=hit["_source"][Field.CONTENT_KEY.value],
|
||||
vector=hit["_source"][Field.VECTOR.value],
|
||||
metadata=hit["_source"][Field.METADATA_KEY.value],
|
||||
),
|
||||
hit["_score"],
|
||||
)
|
||||
)
|
||||
docs = []
|
||||
for doc, score in docs_and_scores:
|
||||
score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
|
||||
if score > score_threshold:
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
must = kwargs.get("must")
|
||||
must_not = kwargs.get("must_not")
|
||||
should = kwargs.get("should")
|
||||
minimum_should_match = kwargs.get("minimum_should_match", 0)
|
||||
top_k = kwargs.get("top_k", 10)
|
||||
filters = kwargs.get("filter")
|
||||
routing = kwargs.get("routing")
|
||||
full_text_query = default_text_search_query(
|
||||
query_text=query,
|
||||
k=top_k,
|
||||
text_field=Field.CONTENT_KEY.value,
|
||||
must=must,
|
||||
must_not=must_not,
|
||||
should=should,
|
||||
minimum_should_match=minimum_should_match,
|
||||
filters=filters,
|
||||
routing=routing,
|
||||
)
|
||||
response = self._client.search(index=self._collection_name, body=full_text_query)
|
||||
docs = []
|
||||
for hit in response["hits"]["hits"]:
|
||||
docs.append(
|
||||
Document(
|
||||
page_content=hit["_source"][Field.CONTENT_KEY.value],
|
||||
vector=hit["_source"][Field.VECTOR.value],
|
||||
metadata=hit["_source"][Field.METADATA_KEY.value],
|
||||
)
|
||||
)
|
||||
|
||||
return docs
|
||||
|
||||
def create_collection(self, dimension: int, **kwargs):
|
||||
lock_name = f"vector_indexing_lock_{self._collection_name}"
|
||||
with redis_client.lock(lock_name, timeout=20):
|
||||
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
|
||||
if redis_client.get(collection_exist_cache_key):
|
||||
logger.info(f"Collection {self._collection_name} already exists.")
|
||||
return
|
||||
if self._client.indices.exists(index=self._collection_name):
|
||||
logger.info("{self._collection_name.lower()} already exists.")
|
||||
return
|
||||
if len(self.kwargs) == 0 and len(kwargs) != 0:
|
||||
self.kwargs = copy.deepcopy(kwargs)
|
||||
vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
|
||||
shards = kwargs.pop("shards", 2)
|
||||
|
||||
engine = kwargs.pop("engine", "lvector")
|
||||
method_name = kwargs.pop("method_name", "hnsw")
|
||||
data_type = kwargs.pop("data_type", "float")
|
||||
space_type = kwargs.pop("space_type", "cosinesimil")
|
||||
|
||||
hnsw_m = kwargs.pop("hnsw_m", 24)
|
||||
hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
|
||||
ivfpq_m = kwargs.pop("ivfpq_m", dimension)
|
||||
nlist = kwargs.pop("nlist", 1000)
|
||||
centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if nlist >= 5000 else False)
|
||||
centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24)
|
||||
centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500)
|
||||
centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100)
|
||||
mapping = default_text_mapping(
|
||||
dimension,
|
||||
method_name,
|
||||
shards=shards,
|
||||
engine=engine,
|
||||
data_type=data_type,
|
||||
space_type=space_type,
|
||||
vector_field=vector_field,
|
||||
hnsw_m=hnsw_m,
|
||||
hnsw_ef_construction=hnsw_ef_construction,
|
||||
nlist=nlist,
|
||||
ivfpq_m=ivfpq_m,
|
||||
centroids_use_hnsw=centroids_use_hnsw,
|
||||
centroids_hnsw_m=centroids_hnsw_m,
|
||||
centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
|
||||
centroids_hnsw_ef_search=centroids_hnsw_ef_search,
|
||||
**kwargs,
|
||||
)
|
||||
self._client.indices.create(index=self._collection_name.lower(), body=mapping)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
# logger.info(f"create index success: {self._collection_name}")
|
||||
|
||||
|
||||
def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
|
||||
routing_field = kwargs.get("routing_field")
|
||||
excludes_from_source = kwargs.get("excludes_from_source")
|
||||
analyzer = kwargs.get("analyzer", "ik_max_word")
|
||||
text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
|
||||
engine = kwargs["engine"]
|
||||
shard = kwargs["shards"]
|
||||
space_type = kwargs["space_type"]
|
||||
data_type = kwargs["data_type"]
|
||||
vector_field = kwargs.get("vector_field", Field.VECTOR.value)
|
||||
|
||||
if method_name == "ivfpq":
|
||||
ivfpq_m = kwargs["ivfpq_m"]
|
||||
nlist = kwargs["nlist"]
|
||||
centroids_use_hnsw = True if nlist > 10000 else False
|
||||
centroids_hnsw_m = 24
|
||||
centroids_hnsw_ef_construct = 500
|
||||
centroids_hnsw_ef_search = 100
|
||||
parameters = {
|
||||
"m": ivfpq_m,
|
||||
"nlist": nlist,
|
||||
"centroids_use_hnsw": centroids_use_hnsw,
|
||||
"centroids_hnsw_m": centroids_hnsw_m,
|
||||
"centroids_hnsw_ef_construct": centroids_hnsw_ef_construct,
|
||||
"centroids_hnsw_ef_search": centroids_hnsw_ef_search,
|
||||
}
|
||||
elif method_name == "hnsw":
|
||||
neighbor = kwargs["hnsw_m"]
|
||||
ef_construction = kwargs["hnsw_ef_construction"]
|
||||
parameters = {"m": neighbor, "ef_construction": ef_construction}
|
||||
elif method_name == "flat":
|
||||
parameters = {}
|
||||
else:
|
||||
raise RuntimeError(f"unexpected method_name: {method_name}")
|
||||
|
||||
mapping = {
|
||||
"settings": {"index": {"number_of_shards": shard, "knn": True}},
|
||||
"mappings": {
|
||||
"properties": {
|
||||
vector_field: {
|
||||
"type": "knn_vector",
|
||||
"dimension": dimension,
|
||||
"data_type": data_type,
|
||||
"method": {
|
||||
"engine": engine,
|
||||
"name": method_name,
|
||||
"space_type": space_type,
|
||||
"parameters": parameters,
|
||||
},
|
||||
},
|
||||
text_field: {"type": "text", "analyzer": analyzer},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
if excludes_from_source:
|
||||
mapping["mappings"]["_source"] = {"excludes": excludes_from_source} # e.g. {"excludes": ["vector_field"]}
|
||||
|
||||
if method_name == "ivfpq" and routing_field is not None:
|
||||
mapping["settings"]["index"]["knn_routing"] = True
|
||||
mapping["settings"]["index"]["knn.offline.construction"] = True
|
||||
|
||||
if method_name == "flat" and routing_field is not None:
|
||||
mapping["settings"]["index"]["knn_routing"] = True
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def default_text_search_query(
|
||||
query_text: str,
|
||||
k: int = 4,
|
||||
text_field: str = Field.CONTENT_KEY.value,
|
||||
must: Optional[list[dict]] = None,
|
||||
must_not: Optional[list[dict]] = None,
|
||||
should: Optional[list[dict]] = None,
|
||||
minimum_should_match: int = 0,
|
||||
filters: Optional[list[dict]] = None,
|
||||
routing: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
if routing is not None:
|
||||
routing_field = kwargs.get("routing_field", "routing_field")
|
||||
query_clause = {
|
||||
"bool": {
|
||||
"must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}]
|
||||
}
|
||||
}
|
||||
else:
|
||||
query_clause = {"match": {text_field: query_text}}
|
||||
# build the simplest search_query when only query_text is specified
|
||||
if not must and not must_not and not should and not filters:
|
||||
search_query = {"size": k, "query": query_clause}
|
||||
return search_query
|
||||
|
||||
# build complex search_query when either of must/must_not/should/filter is specified
|
||||
if must:
|
||||
if not isinstance(must, list):
|
||||
raise RuntimeError(f"unexpected [must] clause with {type(filters)}")
|
||||
if query_clause not in must:
|
||||
must.append(query_clause)
|
||||
else:
|
||||
must = [query_clause]
|
||||
|
||||
boolean_query = {"must": must}
|
||||
|
||||
if must_not:
|
||||
if not isinstance(must_not, list):
|
||||
raise RuntimeError(f"unexpected [must_not] clause with {type(filters)}")
|
||||
boolean_query["must_not"] = must_not
|
||||
|
||||
if should:
|
||||
if not isinstance(should, list):
|
||||
raise RuntimeError(f"unexpected [should] clause with {type(filters)}")
|
||||
boolean_query["should"] = should
|
||||
if minimum_should_match != 0:
|
||||
boolean_query["minimum_should_match"] = minimum_should_match
|
||||
|
||||
if filters:
|
||||
if not isinstance(filters, list):
|
||||
raise RuntimeError(f"unexpected [filter] clause with {type(filters)}")
|
||||
boolean_query["filter"] = filters
|
||||
|
||||
search_query = {"size": k, "query": {"bool": boolean_query}}
|
||||
return search_query
|
||||
|
||||
|
||||
def default_vector_search_query(
|
||||
query_vector: list[float],
|
||||
k: int = 4,
|
||||
min_score: str = "0.0",
|
||||
ef_search: Optional[str] = None, # only for hnsw
|
||||
nprobe: Optional[str] = None, # "2000"
|
||||
reorder_factor: Optional[str] = None, # "20"
|
||||
client_refactor: Optional[str] = None, # "true"
|
||||
vector_field: str = Field.VECTOR.value,
|
||||
filters: Optional[list[dict]] = None,
|
||||
filter_type: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
if filters is not None:
|
||||
filter_type = "post_filter" if filter_type is None else filter_type
|
||||
if not isinstance(filter, list):
|
||||
raise RuntimeError(f"unexpected filter with {type(filters)}")
|
||||
final_ext = {"lvector": {}}
|
||||
if min_score != "0.0":
|
||||
final_ext["lvector"]["min_score"] = min_score
|
||||
if ef_search:
|
||||
final_ext["lvector"]["ef_search"] = ef_search
|
||||
if nprobe:
|
||||
final_ext["lvector"]["nprobe"] = nprobe
|
||||
if reorder_factor:
|
||||
final_ext["lvector"]["reorder_factor"] = reorder_factor
|
||||
if client_refactor:
|
||||
final_ext["lvector"]["client_refactor"] = client_refactor
|
||||
|
||||
search_query = {
|
||||
"size": k,
|
||||
"_source": True, # force return '_source'
|
||||
"query": {"knn": {vector_field: {"vector": query_vector, "k": k}}},
|
||||
}
|
||||
|
||||
if filters is not None:
|
||||
# when using filter, transform filter from List[Dict] to Dict as valid format
|
||||
filters = {"bool": {"must": filters}} if len(filters) > 1 else filters[0]
|
||||
search_query["query"]["knn"][vector_field]["filter"] = filters # filter should be Dict
|
||||
if filter_type:
|
||||
final_ext["lvector"]["filter_type"] = filter_type
|
||||
|
||||
if final_ext != {"lvector": {}}:
|
||||
search_query["ext"] = final_ext
|
||||
return search_query
|
||||
|
||||
|
||||
class LindormVectorStoreFactory(AbstractVectorFactory):
|
||||
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
|
||||
if dataset.index_struct_dict:
|
||||
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
|
||||
collection_name = class_prefix
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.LINDORM, collection_name))
|
||||
lindorm_config = LindormVectorStoreConfig(
|
||||
hosts=dify_config.LINDORM_URL,
|
||||
username=dify_config.LINDORM_USERNAME,
|
||||
password=dify_config.LINDORM_PASSWORD,
|
||||
)
|
||||
return LindormVectorStore(collection_name, lindorm_config)
|
|
@ -134,6 +134,10 @@ class Vector:
|
|||
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import TidbOnQdrantVectorFactory
|
||||
|
||||
return TidbOnQdrantVectorFactory
|
||||
case VectorType.LINDORM:
|
||||
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStoreFactory
|
||||
|
||||
return LindormVectorStoreFactory
|
||||
case VectorType.OCEANBASE:
|
||||
from core.rag.datasource.vdb.oceanbase.oceanbase_vector import OceanBaseVectorFactory
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@ class VectorType(str, Enum):
|
|||
TENCENT = "tencent"
|
||||
ORACLE = "oracle"
|
||||
ELASTICSEARCH = "elasticsearch"
|
||||
LINDORM = "lindorm"
|
||||
COUCHBASE = "couchbase"
|
||||
BAIDU = "baidu"
|
||||
VIKINGDB = "vikingdb"
|
||||
|
|
|
@ -14,6 +14,7 @@ import requests
|
|||
from docx import Document as DocxDocument
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
|
@ -86,7 +87,7 @@ class WordExtractor(BaseExtractor):
|
|||
image_count += 1
|
||||
if rel.is_external:
|
||||
url = rel.reltype
|
||||
response = requests.get(url, stream=True)
|
||||
response = ssrf_proxy.get(url, stream=True)
|
||||
if response.status_code == 200:
|
||||
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
|
||||
file_uuid = str(uuid.uuid4())
|
||||
|
|
|
@ -5,7 +5,12 @@ from typing import Any, Optional, Union
|
|||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator
|
||||
|
||||
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
|
||||
from core.entities.parameter_entities import (
|
||||
AppSelectorScope,
|
||||
CommonParameterType,
|
||||
ModelSelectorScope,
|
||||
ToolSelectorScope,
|
||||
)
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
@ -209,6 +214,9 @@ class ToolParameter(BaseModel):
|
|||
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||
FILE = CommonParameterType.FILE.value
|
||||
FILES = CommonParameterType.FILES.value
|
||||
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||
TOOL_SELECTOR = CommonParameterType.TOOL_SELECTOR.value
|
||||
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
|
||||
|
@ -258,11 +266,26 @@ class ToolParameter(BaseModel):
|
|||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
case ToolParameter.ToolParameterType.SYSTEM_FILES | ToolParameter.ToolParameterType.FILES:
|
||||
if not isinstance(value, list):
|
||||
return [value]
|
||||
return value
|
||||
case ToolParameter.ToolParameterType.FILE:
|
||||
if isinstance(value, list):
|
||||
if len(value) != 1:
|
||||
raise ValueError(
|
||||
"This parameter only accepts one file but got multiple files while invoking."
|
||||
)
|
||||
else:
|
||||
return value[0]
|
||||
return value
|
||||
case (
|
||||
ToolParameter.ToolParameterType.SYSTEM_FILES
|
||||
| ToolParameter.ToolParameterType.FILE
|
||||
| ToolParameter.ToolParameterType.FILES
|
||||
ToolParameter.ToolParameterType.TOOL_SELECTOR
|
||||
| ToolParameter.ToolParameterType.MODEL_SELECTOR
|
||||
| ToolParameter.ToolParameterType.APP_SELECTOR
|
||||
):
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError("The selector must be a dictionary.")
|
||||
return value
|
||||
case _:
|
||||
return str(value)
|
||||
|
@ -280,7 +303,7 @@ class ToolParameter(BaseModel):
|
|||
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
|
||||
placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user")
|
||||
type: ToolParameterType = Field(..., description="The type of the parameter")
|
||||
scope: AppSelectorScope | ModelConfigScope | None = None
|
||||
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
|
||||
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
|
||||
llm_description: Optional[str] = None
|
||||
required: Optional[bool] = False
|
||||
|
|
|
@ -23,6 +23,7 @@ class NodeRunMetadataKey(str, Enum):
|
|||
PARALLEL_START_NODE_ID = "parallel_start_node_id"
|
||||
PARENT_PARALLEL_ID = "parent_parallel_id"
|
||||
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
|
||||
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
|
||||
|
||||
|
||||
class NodeRunResult(BaseModel):
|
||||
|
|
|
@ -59,6 +59,7 @@ class BaseNodeEvent(GraphEngineEvent):
|
|||
|
||||
class NodeRunStartedEvent(BaseNodeEvent):
|
||||
predecessor_node_id: Optional[str] = None
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""predecessor node id"""
|
||||
|
||||
|
||||
|
@ -81,6 +82,10 @@ class NodeRunFailedEvent(BaseNodeEvent):
|
|||
error: str = Field(..., description="error")
|
||||
|
||||
|
||||
class NodeInIterationFailedEvent(BaseNodeEvent):
|
||||
error: str = Field(..., description="error")
|
||||
|
||||
|
||||
###########################################
|
||||
# Parallel Branch Events
|
||||
###########################################
|
||||
|
@ -129,6 +134,8 @@ class BaseIterationEvent(GraphEngineEvent):
|
|||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
|
||||
|
||||
class IterationRunStartedEvent(BaseIterationEvent):
|
||||
|
|
|
@ -4,6 +4,7 @@ import time
|
|||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
from copy import copy, deepcopy
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
@ -724,6 +725,16 @@ class GraphEngine:
|
|||
"""
|
||||
return time.perf_counter() - start_at > max_execution_time
|
||||
|
||||
def create_copy(self):
|
||||
"""
|
||||
create a graph engine copy
|
||||
:return: with a new variable pool instance of graph engine
|
||||
"""
|
||||
new_instance = copy(self)
|
||||
new_instance.graph_runtime_state = copy(self.graph_runtime_state)
|
||||
new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool)
|
||||
return new_instance
|
||||
|
||||
|
||||
class GraphRunFailedError(Exception):
|
||||
def __init__(self, error: str):
|
||||
|
|
|
@ -12,6 +12,12 @@ from core.workflow.nodes.code.entities import CodeNodeData
|
|||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .exc import (
|
||||
CodeNodeError,
|
||||
DepthLimitError,
|
||||
OutputValidationError,
|
||||
)
|
||||
|
||||
|
||||
class CodeNode(BaseNode[CodeNodeData]):
|
||||
_node_data_cls = CodeNodeData
|
||||
|
@ -60,7 +66,7 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
|
||||
# Transform result
|
||||
result = self._transform_result(result, self.node_data.outputs)
|
||||
except (CodeExecutionError, ValueError) as e:
|
||||
except (CodeExecutionError, CodeNodeError) as e:
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
|
||||
|
@ -76,10 +82,10 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
if value is None:
|
||||
return None
|
||||
else:
|
||||
raise ValueError(f"Output variable `{variable}` must be a string")
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a string")
|
||||
|
||||
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{variable}` must be"
|
||||
f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
|
||||
)
|
||||
|
@ -97,10 +103,10 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
if value is None:
|
||||
return None
|
||||
else:
|
||||
raise ValueError(f"Output variable `{variable}` must be a number")
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a number")
|
||||
|
||||
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"Output variable `{variable}` is out of range,"
|
||||
f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
|
||||
)
|
||||
|
@ -108,7 +114,7 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
if isinstance(value, float):
|
||||
# raise error if precision is too high
|
||||
if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"Output variable `{variable}` has too high precision,"
|
||||
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
|
||||
)
|
||||
|
@ -125,7 +131,7 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
:return:
|
||||
"""
|
||||
if depth > dify_config.CODE_MAX_DEPTH:
|
||||
raise ValueError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
|
||||
raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
|
||||
|
||||
transformed_result = {}
|
||||
if output_schema is None:
|
||||
|
@ -177,14 +183,14 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
depth=depth + 1,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}.{output_name} is not a valid array."
|
||||
f" make sure all elements are of the same type."
|
||||
)
|
||||
elif output_value is None:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Output {prefix}.{output_name} is not a valid type.")
|
||||
raise OutputValidationError(f"Output {prefix}.{output_name} is not a valid type.")
|
||||
|
||||
return result
|
||||
|
||||
|
@ -192,7 +198,7 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
for output_name, output_config in output_schema.items():
|
||||
dot = "." if prefix else ""
|
||||
if output_name not in result:
|
||||
raise ValueError(f"Output {prefix}{dot}{output_name} is missing.")
|
||||
raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.")
|
||||
|
||||
if output_config.type == "object":
|
||||
# check if output is object
|
||||
|
@ -200,7 +206,7 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
if isinstance(result.get(output_name), type(None)):
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an object,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
|
@ -228,13 +234,13 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
if isinstance(result[output_name], type(None)):
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
|
||||
)
|
||||
|
@ -249,13 +255,13 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
if isinstance(result[output_name], type(None)):
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
|
||||
)
|
||||
|
@ -270,13 +276,13 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
if isinstance(result[output_name], type(None)):
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
|
||||
)
|
||||
|
@ -286,7 +292,7 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
if value is None:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name}[{i}] is not an object,"
|
||||
f" got {type(value)} instead at index {i}."
|
||||
)
|
||||
|
@ -303,13 +309,13 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Output type {output_config.type} is not supported.")
|
||||
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
|
||||
|
||||
parameters_validated[output_name] = True
|
||||
|
||||
# check if all output parameters are validated
|
||||
if len(parameters_validated) != len(result):
|
||||
raise ValueError("Not all output parameters are validated.")
|
||||
raise CodeNodeError("Not all output parameters are validated.")
|
||||
|
||||
return transformed_result
|
||||
|
||||
|
|
16
api/core/workflow/nodes/code/exc.py
Normal file
16
api/core/workflow/nodes/code/exc.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
class CodeNodeError(ValueError):
|
||||
"""Base class for code node errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OutputValidationError(CodeNodeError):
|
||||
"""Raised when there is an output validation error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DepthLimitError(CodeNodeError):
|
||||
"""Raised when the depth limit is reached."""
|
||||
|
||||
pass
|
|
@ -1,4 +1,4 @@
|
|||
class DocumentExtractorError(Exception):
|
||||
class DocumentExtractorError(ValueError):
|
||||
"""Base exception for errors related to the DocumentExtractorNode."""
|
||||
|
||||
|
||||
|
|
|
@ -6,12 +6,14 @@ import docx
|
|||
import pandas as pd
|
||||
import pypdfium2
|
||||
import yaml
|
||||
from unstructured.partition.api import partition_via_api
|
||||
from unstructured.partition.email import partition_email
|
||||
from unstructured.partition.epub import partition_epub
|
||||
from unstructured.partition.msg import partition_msg
|
||||
from unstructured.partition.ppt import partition_ppt
|
||||
from unstructured.partition.pptx import partition_pptx
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File, FileTransferMethod, file_manager
|
||||
from core.helper import ssrf_proxy
|
||||
from core.variables import ArrayFileSegment
|
||||
|
@ -196,10 +198,8 @@ def _download_file_content(file: File) -> bytes:
|
|||
response = ssrf_proxy.get(file.remote_url)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
return file_manager.download(file)
|
||||
else:
|
||||
raise ValueError(f"Unsupported transfer method: {file.transfer_method}")
|
||||
return file_manager.download(file)
|
||||
except Exception as e:
|
||||
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
|
||||
|
||||
|
@ -263,6 +263,13 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
|
|||
def _extract_text_from_pptx(file_content: bytes) -> str:
|
||||
try:
|
||||
with io.BytesIO(file_content) as file:
|
||||
if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
|
||||
elements = partition_via_api(
|
||||
file=file,
|
||||
api_url=dify_config.UNSTRUCTURED_API_URL,
|
||||
api_key=dify_config.UNSTRUCTURED_API_KEY,
|
||||
)
|
||||
else:
|
||||
elements = partition_pptx(file=file)
|
||||
return "\n".join([getattr(element, "text", "") for element in elements])
|
||||
except Exception as e:
|
||||
|
|
18
api/core/workflow/nodes/http_request/exc.py
Normal file
18
api/core/workflow/nodes/http_request/exc.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
class HttpRequestNodeError(ValueError):
|
||||
"""Custom error for HTTP request node."""
|
||||
|
||||
|
||||
class AuthorizationConfigError(HttpRequestNodeError):
|
||||
"""Raised when authorization config is missing or invalid."""
|
||||
|
||||
|
||||
class FileFetchError(HttpRequestNodeError):
|
||||
"""Raised when a file cannot be fetched."""
|
||||
|
||||
|
||||
class InvalidHttpMethodError(HttpRequestNodeError):
|
||||
"""Raised when an invalid HTTP method is used."""
|
||||
|
||||
|
||||
class ResponseSizeError(HttpRequestNodeError):
|
||||
"""Raised when the response size exceeds the allowed threshold."""
|
|
@ -18,6 +18,12 @@ from .entities import (
|
|||
HttpRequestNodeTimeout,
|
||||
Response,
|
||||
)
|
||||
from .exc import (
|
||||
AuthorizationConfigError,
|
||||
FileFetchError,
|
||||
InvalidHttpMethodError,
|
||||
ResponseSizeError,
|
||||
)
|
||||
|
||||
BODY_TYPE_TO_CONTENT_TYPE = {
|
||||
"json": "application/json",
|
||||
|
@ -51,7 +57,7 @@ class Executor:
|
|||
# If authorization API key is present, convert the API key using the variable pool
|
||||
if node_data.authorization.type == "api-key":
|
||||
if node_data.authorization.config is None:
|
||||
raise ValueError("authorization config is required")
|
||||
raise AuthorizationConfigError("authorization config is required")
|
||||
node_data.authorization.config.api_key = variable_pool.convert_template(
|
||||
node_data.authorization.config.api_key
|
||||
).text
|
||||
|
@ -82,8 +88,10 @@ class Executor:
|
|||
self.url = self.variable_pool.convert_template(self.node_data.url).text
|
||||
|
||||
def _init_params(self):
|
||||
params = self.variable_pool.convert_template(self.node_data.params).text
|
||||
self.params = _plain_text_to_dict(params)
|
||||
params = _plain_text_to_dict(self.node_data.params)
|
||||
for key in params:
|
||||
params[key] = self.variable_pool.convert_template(params[key]).text
|
||||
self.params = params
|
||||
|
||||
def _init_headers(self):
|
||||
headers = self.variable_pool.convert_template(self.node_data.headers).text
|
||||
|
@ -116,7 +124,7 @@ class Executor:
|
|||
file_selector = data[0].file
|
||||
file_variable = self.variable_pool.get_file(file_selector)
|
||||
if file_variable is None:
|
||||
raise ValueError(f"cannot fetch file with selector {file_selector}")
|
||||
raise FileFetchError(f"cannot fetch file with selector {file_selector}")
|
||||
file = file_variable.value
|
||||
self.content = file_manager.download(file)
|
||||
case "x-www-form-urlencoded":
|
||||
|
@ -155,12 +163,12 @@ class Executor:
|
|||
headers = deepcopy(self.headers) or {}
|
||||
if self.auth.type == "api-key":
|
||||
if self.auth.config is None:
|
||||
raise ValueError("self.authorization config is required")
|
||||
raise AuthorizationConfigError("self.authorization config is required")
|
||||
if authorization.config is None:
|
||||
raise ValueError("authorization config is required")
|
||||
raise AuthorizationConfigError("authorization config is required")
|
||||
|
||||
if self.auth.config.api_key is None:
|
||||
raise ValueError("api_key is required")
|
||||
raise AuthorizationConfigError("api_key is required")
|
||||
|
||||
if not authorization.config.header:
|
||||
authorization.config.header = "Authorization"
|
||||
|
@ -183,7 +191,7 @@ class Executor:
|
|||
else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE
|
||||
)
|
||||
if executor_response.size > threshold_size:
|
||||
raise ValueError(
|
||||
raise ResponseSizeError(
|
||||
f'{"File" if executor_response.is_file else "Text"} size is too large,'
|
||||
f' max size is {threshold_size / 1024 / 1024:.2f} MB,'
|
||||
f' but current size is {executor_response.readable_size}.'
|
||||
|
@ -196,7 +204,7 @@ class Executor:
|
|||
do http request depending on api bundle
|
||||
"""
|
||||
if self.method not in {"get", "head", "post", "put", "delete", "patch"}:
|
||||
raise ValueError(f"Invalid http method {self.method}")
|
||||
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
|
||||
|
||||
request_args = {
|
||||
"url": self.url,
|
||||
|
|
|
@ -20,6 +20,7 @@ from .entities import (
|
|||
HttpRequestNodeTimeout,
|
||||
Response,
|
||||
)
|
||||
from .exc import HttpRequestNodeError
|
||||
|
||||
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
|
||||
connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
|
||||
|
@ -77,7 +78,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
|||
"request": http_executor.to_log(),
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
except HttpRequestNodeError as e:
|
||||
logger.warning(f"http request node {self.node_id} failed to run: {e}")
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
@ -5,6 +6,12 @@ from pydantic import Field
|
|||
from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
|
||||
|
||||
|
||||
class ErrorHandleMode(str, Enum):
|
||||
TERMINATED = "terminated"
|
||||
CONTINUE_ON_ERROR = "continue-on-error"
|
||||
REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output"
|
||||
|
||||
|
||||
class IterationNodeData(BaseIterationNodeData):
|
||||
"""
|
||||
Iteration Node Data.
|
||||
|
@ -13,6 +20,9 @@ class IterationNodeData(BaseIterationNodeData):
|
|||
parent_loop_id: Optional[str] = None # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
is_parallel: bool = False # open the parallel mode or not
|
||||
parallel_nums: int = 10 # the numbers of parallel
|
||||
error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error
|
||||
|
||||
|
||||
class IterationStartNodeData(BaseNodeData):
|
||||
|
|
|
@ -1,12 +1,20 @@
|
|||
import logging
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from concurrent.futures import Future, wait
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
from queue import Empty, Queue
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.variables import IntegerSegment
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.node_entities import (
|
||||
NodeRunMetadataKey,
|
||||
NodeRunResult,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseGraphEvent,
|
||||
BaseNodeEvent,
|
||||
|
@ -17,6 +25,9 @@ from core.workflow.graph_engine.entities.event import (
|
|||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
NodeInIterationFailedEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
@ -24,9 +35,11 @@ from core.workflow.graph_engine.entities.graph import Graph
|
|||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -38,6 +51,17 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|||
_node_data_cls = IterationNodeData
|
||||
_node_type = NodeType.ITERATION
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
return {
|
||||
"type": "iteration",
|
||||
"config": {
|
||||
"is_parallel": False,
|
||||
"parallel_nums": 10,
|
||||
"error_handle_mode": ErrorHandleMode.TERMINATED.value,
|
||||
},
|
||||
}
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
Run the node.
|
||||
|
@ -83,7 +107,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|||
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
|
||||
|
||||
# init graph engine
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=self.tenant_id,
|
||||
|
@ -123,108 +147,64 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|||
index=0,
|
||||
pre_iteration_output=None,
|
||||
)
|
||||
|
||||
outputs: list[Any] = []
|
||||
try:
|
||||
for _ in range(len(iterator_list_value)):
|
||||
# run workflow
|
||||
rst = graph_engine.run()
|
||||
for event in rst:
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
|
||||
event.in_iteration_id = self.node_id
|
||||
|
||||
if (
|
||||
isinstance(event, BaseNodeEvent)
|
||||
and event.node_type == NodeType.ITERATION_START
|
||||
and not isinstance(event, NodeRunStreamChunkEvent)
|
||||
):
|
||||
if self.node_data.is_parallel:
|
||||
futures: list[Future] = []
|
||||
q = Queue()
|
||||
thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
future: Future = thread_pool.submit(
|
||||
self._run_single_iter_parallel,
|
||||
current_app._get_current_object(),
|
||||
q,
|
||||
iterator_list_value,
|
||||
inputs,
|
||||
outputs,
|
||||
start_at,
|
||||
graph_engine,
|
||||
iteration_graph,
|
||||
index,
|
||||
item,
|
||||
)
|
||||
future.add_done_callback(thread_pool.task_done_callback)
|
||||
futures.append(future)
|
||||
succeeded_count = 0
|
||||
while True:
|
||||
try:
|
||||
event = q.get(timeout=1)
|
||||
if event is None:
|
||||
break
|
||||
if isinstance(event, IterationRunNextEvent):
|
||||
succeeded_count += 1
|
||||
if succeeded_count == len(futures):
|
||||
q.put(None)
|
||||
yield event
|
||||
if isinstance(event, RunCompletedEvent):
|
||||
q.put(None)
|
||||
for f in futures:
|
||||
if not f.done():
|
||||
f.cancel()
|
||||
yield event
|
||||
if isinstance(event, IterationRunFailedEvent):
|
||||
q.put(None)
|
||||
yield event
|
||||
except Empty:
|
||||
continue
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
if event.route_node_state.node_run_result:
|
||||
metadata = event.route_node_state.node_run_result.metadata
|
||||
if not metadata:
|
||||
metadata = {}
|
||||
|
||||
if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
||||
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
|
||||
index_variable = variable_pool.get([self.node_id, "index"])
|
||||
if not isinstance(index_variable, IntegerSegment):
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=f"Invalid index variable type: {type(index_variable)}",
|
||||
)
|
||||
)
|
||||
return
|
||||
metadata[NodeRunMetadataKey.ITERATION_INDEX] = index_variable.value
|
||||
event.route_node_state.node_run_result.metadata = metadata
|
||||
|
||||
yield event
|
||||
elif isinstance(event, BaseGraphEvent):
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
# iteration run failed
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": jsonable_encoder(outputs)},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
return
|
||||
# wait all threads
|
||||
wait(futures)
|
||||
else:
|
||||
event = cast(InNodeEvent, event)
|
||||
yield event
|
||||
|
||||
# append to iteration output variable list
|
||||
current_iteration_output_variable = variable_pool.get(self.node_data.output_selector)
|
||||
if current_iteration_output_variable is None:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=f"Iteration output variable {self.node_data.output_selector} not found",
|
||||
for _ in range(len(iterator_list_value)):
|
||||
yield from self._run_single_iter(
|
||||
iterator_list_value,
|
||||
variable_pool,
|
||||
inputs,
|
||||
outputs,
|
||||
start_at,
|
||||
graph_engine,
|
||||
iteration_graph,
|
||||
)
|
||||
)
|
||||
return
|
||||
current_iteration_output = current_iteration_output_variable.to_object()
|
||||
outputs.append(current_iteration_output)
|
||||
|
||||
# remove all nodes outputs from variable pool
|
||||
for node_id in iteration_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
|
||||
# move to next iteration
|
||||
current_index_variable = variable_pool.get([self.node_id, "index"])
|
||||
if not isinstance(current_index_variable, IntegerSegment):
|
||||
raise ValueError(f"iteration {self.node_id} current index not found")
|
||||
|
||||
next_index = current_index_variable.value + 1
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
index=next_index,
|
||||
pre_iteration_output=jsonable_encoder(current_iteration_output),
|
||||
)
|
||||
|
||||
yield IterationRunSucceededEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
|
@ -330,3 +310,231 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|||
}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
def _handle_event_metadata(
|
||||
self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str
|
||||
) -> NodeRunStartedEvent | BaseNodeEvent:
|
||||
"""
|
||||
add iteration metadata to event.
|
||||
"""
|
||||
if not isinstance(event, BaseNodeEvent):
|
||||
return event
|
||||
if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
|
||||
event.parallel_mode_run_id = parallel_mode_run_id
|
||||
return event
|
||||
if event.route_node_state.node_run_result:
|
||||
metadata = event.route_node_state.node_run_result.metadata
|
||||
if not metadata:
|
||||
metadata = {}
|
||||
|
||||
if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
||||
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
|
||||
if self.node_data.is_parallel:
|
||||
metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
|
||||
else:
|
||||
metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index
|
||||
event.route_node_state.node_run_result.metadata = metadata
|
||||
return event
|
||||
|
||||
def _run_single_iter(
|
||||
self,
|
||||
iterator_list_value: list[str],
|
||||
variable_pool: VariablePool,
|
||||
inputs: dict[str, list],
|
||||
outputs: list,
|
||||
start_at: datetime,
|
||||
graph_engine: "GraphEngine",
|
||||
iteration_graph: Graph,
|
||||
parallel_mode_run_id: Optional[str] = None,
|
||||
) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
run single iteration
|
||||
"""
|
||||
try:
|
||||
rst = graph_engine.run()
|
||||
# get current iteration index
|
||||
current_index = variable_pool.get([self.node_id, "index"]).value
|
||||
next_index = int(current_index) + 1
|
||||
|
||||
if current_index is None:
|
||||
raise ValueError(f"iteration {self.node_id} current index not found")
|
||||
for event in rst:
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
|
||||
event.in_iteration_id = self.node_id
|
||||
|
||||
if (
|
||||
isinstance(event, BaseNodeEvent)
|
||||
and event.node_type == NodeType.ITERATION_START
|
||||
and not isinstance(event, NodeRunStreamChunkEvent)
|
||||
):
|
||||
continue
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
yield self._handle_event_metadata(event, current_index, parallel_mode_run_id)
|
||||
elif isinstance(event, BaseGraphEvent):
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
# iteration run failed
|
||||
if self.node_data.is_parallel:
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": jsonable_encoder(outputs)},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
)
|
||||
else:
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": jsonable_encoder(outputs)},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
return
|
||||
else:
|
||||
event = cast(InNodeEvent, event)
|
||||
metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id)
|
||||
if isinstance(event, NodeRunFailedEvent):
|
||||
if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
|
||||
yield NodeInIterationFailedEvent(
|
||||
**metadata_event.model_dump(),
|
||||
)
|
||||
outputs.insert(current_index, None)
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
index=next_index,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
pre_iteration_output=None,
|
||||
)
|
||||
return
|
||||
elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
yield NodeInIterationFailedEvent(
|
||||
**metadata_event.model_dump(),
|
||||
)
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
index=next_index,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
pre_iteration_output=None,
|
||||
)
|
||||
return
|
||||
elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": None},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
)
|
||||
yield metadata_event
|
||||
|
||||
current_iteration_output = variable_pool.get(self.node_data.output_selector).value
|
||||
outputs.insert(current_index, current_iteration_output)
|
||||
# remove all nodes outputs from variable pool
|
||||
for node_id in iteration_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
|
||||
# move to next iteration
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
index=next_index,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Iteration run failed:{str(e)}")
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": None},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=str(e),
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
|
||||
def _run_single_iter_parallel(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
q: Queue,
|
||||
iterator_list_value: list[str],
|
||||
inputs: dict[str, list],
|
||||
outputs: list,
|
||||
start_at: datetime,
|
||||
graph_engine: "GraphEngine",
|
||||
iteration_graph: Graph,
|
||||
index: int,
|
||||
item: Any,
|
||||
) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
run single iteration in parallel mode
|
||||
"""
|
||||
with flask_app.app_context():
|
||||
parallel_mode_run_id = uuid.uuid4().hex
|
||||
graph_engine_copy = graph_engine.create_copy()
|
||||
variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool
|
||||
variable_pool_copy.add([self.node_id, "index"], index)
|
||||
variable_pool_copy.add([self.node_id, "item"], item)
|
||||
for event in self._run_single_iter(
|
||||
iterator_list_value=iterator_list_value,
|
||||
variable_pool=variable_pool_copy,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
start_at=start_at,
|
||||
graph_engine=graph_engine_copy,
|
||||
iteration_graph=iteration_graph,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
):
|
||||
q.put(event)
|
||||
|
|
16
api/core/workflow/nodes/list_operator/exc.py
Normal file
16
api/core/workflow/nodes/list_operator/exc.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
class ListOperatorError(ValueError):
|
||||
"""Base class for all ListOperator errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvalidFilterValueError(ListOperatorError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidKeyError(ListOperatorError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidConditionError(ListOperatorError):
|
||||
pass
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Callable, Sequence
|
||||
from typing import Literal
|
||||
from typing import Literal, Union
|
||||
|
||||
from core.file import File
|
||||
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
||||
|
@ -9,6 +9,7 @@ from core.workflow.nodes.enums import NodeType
|
|||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import ListOperatorNodeData
|
||||
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
|
||||
|
||||
|
||||
class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
||||
|
@ -26,7 +27,17 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
|||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
|
||||
)
|
||||
if variable.value and not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
|
||||
if not variable.value:
|
||||
inputs = {"variable": []}
|
||||
process_data = {"variable": []}
|
||||
outputs = {"result": [], "first_record": None, "last_record": None}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
)
|
||||
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
|
||||
error_message = (
|
||||
f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
|
||||
"or ArrayStringSegment"
|
||||
|
@ -36,23 +47,59 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
|||
)
|
||||
|
||||
if isinstance(variable, ArrayFileSegment):
|
||||
inputs = {"variable": [item.to_dict() for item in variable.value]}
|
||||
process_data["variable"] = [item.to_dict() for item in variable.value]
|
||||
else:
|
||||
inputs = {"variable": variable.value}
|
||||
process_data["variable"] = variable.value
|
||||
|
||||
try:
|
||||
# Filter
|
||||
if self.node_data.filter_by.enabled:
|
||||
variable = self._apply_filter(variable)
|
||||
|
||||
# Order
|
||||
if self.node_data.order_by.enabled:
|
||||
variable = self._apply_order(variable)
|
||||
|
||||
# Slice
|
||||
if self.node_data.limit.enabled:
|
||||
variable = self._apply_slice(variable)
|
||||
|
||||
outputs = {
|
||||
"result": variable.value,
|
||||
"first_record": variable.value[0] if variable.value else None,
|
||||
"last_record": variable.value[-1] if variable.value else None,
|
||||
}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
)
|
||||
except ListOperatorError as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
def _apply_filter(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
for condition in self.node_data.filter_by.conditions:
|
||||
if isinstance(variable, ArrayStringSegment):
|
||||
if not isinstance(condition.value, str):
|
||||
raise ValueError(f"Invalid filter value: {condition.value}")
|
||||
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
|
||||
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
|
||||
filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayNumberSegment):
|
||||
if not isinstance(condition.value, str):
|
||||
raise ValueError(f"Invalid filter value: {condition.value}")
|
||||
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
|
||||
value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text
|
||||
filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value))
|
||||
result = list(filter(filter_func, variable.value))
|
||||
|
@ -69,9 +116,11 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
|||
)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
return variable
|
||||
|
||||
# Order
|
||||
if self.node_data.order_by.enabled:
|
||||
def _apply_order(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
if isinstance(variable, ArrayStringSegment):
|
||||
result = _order_string(order=self.node_data.order_by.value, array=variable.value)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
|
@ -83,23 +132,13 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
|
|||
order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
|
||||
)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
return variable
|
||||
|
||||
# Slice
|
||||
if self.node_data.limit.enabled:
|
||||
def _apply_slice(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
result = variable.value[: self.node_data.limit.size]
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
|
||||
outputs = {
|
||||
"result": variable.value,
|
||||
"first_record": variable.value[0] if variable.value else None,
|
||||
"last_record": variable.value[-1] if variable.value else None,
|
||||
}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
)
|
||||
return variable.model_copy(update={"value": result})
|
||||
|
||||
|
||||
def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
|
||||
|
@ -107,7 +146,7 @@ def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]:
|
|||
case "size":
|
||||
return lambda x: x.size
|
||||
case _:
|
||||
raise ValueError(f"Invalid key: {key}")
|
||||
raise InvalidKeyError(f"Invalid key: {key}")
|
||||
|
||||
|
||||
def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
|
||||
|
@ -118,14 +157,14 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
|
|||
return lambda x: x.type
|
||||
case "extension":
|
||||
return lambda x: x.extension or ""
|
||||
case "mimetype":
|
||||
case "mime_type":
|
||||
return lambda x: x.mime_type or ""
|
||||
case "transfer_method":
|
||||
return lambda x: x.transfer_method
|
||||
case "url":
|
||||
return lambda x: x.remote_url or ""
|
||||
case _:
|
||||
raise ValueError(f"Invalid key: {key}")
|
||||
raise InvalidKeyError(f"Invalid key: {key}")
|
||||
|
||||
|
||||
def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]:
|
||||
|
@ -151,7 +190,7 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
|
|||
case "not empty":
|
||||
return lambda x: x != ""
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {condition}")
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]:
|
||||
|
@ -161,7 +200,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
|
|||
case "not in":
|
||||
return lambda x: not _in(value)(x)
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {condition}")
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]:
|
||||
|
@ -179,7 +218,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
|
|||
case "≥":
|
||||
return _ge(value)
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {condition}")
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
|
||||
|
@ -193,7 +232,7 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
|
|||
extract_func = _get_file_extract_number_func(key=key)
|
||||
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
|
||||
else:
|
||||
raise ValueError(f"Invalid key: {key}")
|
||||
raise InvalidKeyError(f"Invalid key: {key}")
|
||||
|
||||
|
||||
def _contains(value: str):
|
||||
|
@ -256,4 +295,4 @@ def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Seq
|
|||
extract_func = _get_file_extract_number_func(key=order_by)
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
|
||||
else:
|
||||
raise ValueError(f"Invalid order key: {order_by}")
|
||||
raise InvalidKeyError(f"Invalid order key: {order_by}")
|
||||
|
|
26
api/core/workflow/nodes/llm/exc.py
Normal file
26
api/core/workflow/nodes/llm/exc.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
class LLMNodeError(ValueError):
|
||||
"""Base class for LLM Node errors."""
|
||||
|
||||
|
||||
class VariableNotFoundError(LLMNodeError):
|
||||
"""Raised when a required variable is not found."""
|
||||
|
||||
|
||||
class InvalidContextStructureError(LLMNodeError):
|
||||
"""Raised when the context structure is invalid."""
|
||||
|
||||
|
||||
class InvalidVariableTypeError(LLMNodeError):
|
||||
"""Raised when the variable type is invalid."""
|
||||
|
||||
|
||||
class ModelNotExistError(LLMNodeError):
|
||||
"""Raised when the specified model does not exist."""
|
||||
|
||||
|
||||
class LLMModeRequiredError(LLMNodeError):
|
||||
"""Raised when LLM mode is required but not provided."""
|
||||
|
||||
|
||||
class NoPromptFoundError(LLMNodeError):
|
||||
"""Raised when no prompt is found in the LLM configuration."""
|
|
@ -56,6 +56,15 @@ from .entities import (
|
|||
LLMNodeData,
|
||||
ModelConfig,
|
||||
)
|
||||
from .exc import (
|
||||
InvalidContextStructureError,
|
||||
InvalidVariableTypeError,
|
||||
LLMModeRequiredError,
|
||||
LLMNodeError,
|
||||
ModelNotExistError,
|
||||
NoPromptFoundError,
|
||||
VariableNotFoundError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
|
@ -103,7 +112,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
yield event
|
||||
|
||||
if context:
|
||||
node_inputs["#context#"] = context # type: ignore
|
||||
node_inputs["#context#"] = context
|
||||
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(self.node_data.model)
|
||||
|
@ -115,7 +124,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
if self.node_data.memory:
|
||||
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
|
||||
if not query:
|
||||
raise ValueError("Query not found")
|
||||
raise VariableNotFoundError("Query not found")
|
||||
query = query.text
|
||||
else:
|
||||
query = None
|
||||
|
@ -161,7 +170,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
usage = event.usage
|
||||
finish_reason = event.finish_reason
|
||||
break
|
||||
except Exception as e:
|
||||
except LLMNodeError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
|
@ -275,7 +284,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
variable_name = variable_selector.variable
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||
if variable is None:
|
||||
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
|
||||
|
||||
def parse_dict(input_dict: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
|
@ -325,7 +334,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
for variable_selector in variable_selectors:
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||
if variable is None:
|
||||
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
|
||||
if isinstance(variable, NoneSegment):
|
||||
inputs[variable_selector.variable] = ""
|
||||
inputs[variable_selector.variable] = variable.to_object()
|
||||
|
@ -338,7 +347,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
for variable_selector in query_variable_selectors:
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||
if variable is None:
|
||||
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
|
||||
if isinstance(variable, NoneSegment):
|
||||
continue
|
||||
inputs[variable_selector.variable] = variable.to_object()
|
||||
|
@ -355,7 +364,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
return variable.value
|
||||
elif isinstance(variable, NoneSegment | ArrayAnySegment):
|
||||
return []
|
||||
raise ValueError(f"Invalid variable type: {type(variable)}")
|
||||
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
|
||||
|
||||
def _fetch_context(self, node_data: LLMNodeData):
|
||||
if not node_data.context.enabled:
|
||||
|
@ -376,7 +385,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
context_str += item + "\n"
|
||||
else:
|
||||
if "content" not in item:
|
||||
raise ValueError(f"Invalid context structure: {item}")
|
||||
raise InvalidContextStructureError(f"Invalid context structure: {item}")
|
||||
|
||||
context_str += item["content"] + "\n"
|
||||
|
||||
|
@ -441,7 +450,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
)
|
||||
|
||||
if provider_model is None:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
raise ModelNotExistError(f"Model {model_name} not exist.")
|
||||
|
||||
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
|
@ -460,12 +469,12 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
# get model mode
|
||||
model_mode = node_data_model.mode
|
||||
if not model_mode:
|
||||
raise ValueError("LLM mode is required.")
|
||||
raise LLMModeRequiredError("LLM mode is required.")
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
raise ModelNotExistError(f"Model {model_name} not exist.")
|
||||
|
||||
return model_instance, ModelConfigWithCredentialsEntity(
|
||||
provider=provider_name,
|
||||
|
@ -564,7 +573,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
filtered_prompt_messages.append(prompt_message)
|
||||
|
||||
if not filtered_prompt_messages:
|
||||
raise ValueError(
|
||||
raise NoPromptFoundError(
|
||||
"No prompt found in the LLM configuration. "
|
||||
"Please ensure a prompt is properly configured before proceeding."
|
||||
)
|
||||
|
@ -636,7 +645,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
else:
|
||||
raise ValueError(f"Invalid prompt template type: {type(prompt_template)}")
|
||||
raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
|
||||
|
||||
variable_mapping = {}
|
||||
for variable_selector in variable_selectors:
|
||||
|
|
50
api/core/workflow/nodes/parameter_extractor/exc.py
Normal file
50
api/core/workflow/nodes/parameter_extractor/exc.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
class ParameterExtractorNodeError(ValueError):
|
||||
"""Base error for ParameterExtractorNode."""
|
||||
|
||||
|
||||
class InvalidModelTypeError(ParameterExtractorNodeError):
|
||||
"""Raised when the model is not a Large Language Model."""
|
||||
|
||||
|
||||
class ModelSchemaNotFoundError(ParameterExtractorNodeError):
|
||||
"""Raised when the model schema is not found."""
|
||||
|
||||
|
||||
class InvalidInvokeResultError(ParameterExtractorNodeError):
|
||||
"""Raised when the invoke result is invalid."""
|
||||
|
||||
|
||||
class InvalidTextContentTypeError(ParameterExtractorNodeError):
|
||||
"""Raised when the text content type is invalid."""
|
||||
|
||||
|
||||
class InvalidNumberOfParametersError(ParameterExtractorNodeError):
|
||||
"""Raised when the number of parameters is invalid."""
|
||||
|
||||
|
||||
class RequiredParameterMissingError(ParameterExtractorNodeError):
|
||||
"""Raised when a required parameter is missing."""
|
||||
|
||||
|
||||
class InvalidSelectValueError(ParameterExtractorNodeError):
|
||||
"""Raised when a select value is invalid."""
|
||||
|
||||
|
||||
class InvalidNumberValueError(ParameterExtractorNodeError):
|
||||
"""Raised when a number value is invalid."""
|
||||
|
||||
|
||||
class InvalidBoolValueError(ParameterExtractorNodeError):
|
||||
"""Raised when a bool value is invalid."""
|
||||
|
||||
|
||||
class InvalidStringValueError(ParameterExtractorNodeError):
|
||||
"""Raised when a string value is invalid."""
|
||||
|
||||
|
||||
class InvalidArrayValueError(ParameterExtractorNodeError):
|
||||
"""Raised when an array value is invalid."""
|
||||
|
||||
|
||||
class InvalidModelModeError(ParameterExtractorNodeError):
|
||||
"""Raised when the model mode is invalid."""
|
|
@ -32,6 +32,21 @@ from extensions.ext_database import db
|
|||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import ParameterExtractorNodeData
|
||||
from .exc import (
|
||||
InvalidArrayValueError,
|
||||
InvalidBoolValueError,
|
||||
InvalidInvokeResultError,
|
||||
InvalidModelModeError,
|
||||
InvalidModelTypeError,
|
||||
InvalidNumberOfParametersError,
|
||||
InvalidNumberValueError,
|
||||
InvalidSelectValueError,
|
||||
InvalidStringValueError,
|
||||
InvalidTextContentTypeError,
|
||||
ModelSchemaNotFoundError,
|
||||
ParameterExtractorNodeError,
|
||||
RequiredParameterMissingError,
|
||||
)
|
||||
from .prompts import (
|
||||
CHAT_EXAMPLE,
|
||||
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
|
||||
|
@ -85,7 +100,7 @@ class ParameterExtractorNode(LLMNode):
|
|||
|
||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
|
||||
raise ValueError("Model is not a Large Language Model")
|
||||
raise InvalidModelTypeError("Model is not a Large Language Model")
|
||||
|
||||
llm_model = model_instance.model_type_instance
|
||||
model_schema = llm_model.get_model_schema(
|
||||
|
@ -93,7 +108,7 @@ class ParameterExtractorNode(LLMNode):
|
|||
credentials=model_config.credentials,
|
||||
)
|
||||
if not model_schema:
|
||||
raise ValueError("Model schema not found")
|
||||
raise ModelSchemaNotFoundError("Model schema not found")
|
||||
|
||||
# fetch memory
|
||||
memory = self._fetch_memory(
|
||||
|
@ -155,7 +170,7 @@ class ParameterExtractorNode(LLMNode):
|
|||
process_data["usage"] = jsonable_encoder(usage)
|
||||
process_data["tool_call"] = jsonable_encoder(tool_call)
|
||||
process_data["llm_text"] = text
|
||||
except Exception as e:
|
||||
except ParameterExtractorNodeError as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=inputs,
|
||||
|
@ -177,7 +192,7 @@ class ParameterExtractorNode(LLMNode):
|
|||
|
||||
try:
|
||||
result = self._validate_result(data=node_data, result=result or {})
|
||||
except Exception as e:
|
||||
except ParameterExtractorNodeError as e:
|
||||
error = str(e)
|
||||
|
||||
# transform result into standard format
|
||||
|
@ -217,11 +232,11 @@ class ParameterExtractorNode(LLMNode):
|
|||
|
||||
# handle invoke result
|
||||
if not isinstance(invoke_result, LLMResult):
|
||||
raise ValueError(f"Invalid invoke result: {invoke_result}")
|
||||
raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}")
|
||||
|
||||
text = invoke_result.message.content
|
||||
if not isinstance(text, str):
|
||||
raise ValueError(f"Invalid text content type: {type(text)}. Expected str.")
|
||||
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
|
||||
|
||||
usage = invoke_result.usage
|
||||
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
|
||||
|
@ -344,7 +359,7 @@ class ParameterExtractorNode(LLMNode):
|
|||
files=files,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid model mode: {model_mode}")
|
||||
raise InvalidModelModeError(f"Invalid model mode: {model_mode}")
|
||||
|
||||
def _generate_prompt_engineering_completion_prompt(
|
||||
self,
|
||||
|
@ -449,36 +464,36 @@ class ParameterExtractorNode(LLMNode):
|
|||
Validate result.
|
||||
"""
|
||||
if len(data.parameters) != len(result):
|
||||
raise ValueError("Invalid number of parameters")
|
||||
raise InvalidNumberOfParametersError("Invalid number of parameters")
|
||||
|
||||
for parameter in data.parameters:
|
||||
if parameter.required and parameter.name not in result:
|
||||
raise ValueError(f"Parameter {parameter.name} is required")
|
||||
raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
|
||||
|
||||
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
|
||||
raise ValueError(f"Invalid `select` value for parameter {parameter.name}")
|
||||
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
|
||||
raise ValueError(f"Invalid `number` value for parameter {parameter.name}")
|
||||
raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
|
||||
raise ValueError(f"Invalid `bool` value for parameter {parameter.name}")
|
||||
raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
|
||||
raise ValueError(f"Invalid `string` value for parameter {parameter.name}")
|
||||
raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type.startswith("array"):
|
||||
parameters = result.get(parameter.name)
|
||||
if not isinstance(parameters, list):
|
||||
raise ValueError(f"Invalid `array` value for parameter {parameter.name}")
|
||||
raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}")
|
||||
nested_type = parameter.type[6:-1]
|
||||
for item in parameters:
|
||||
if nested_type == "number" and not isinstance(item, int | float):
|
||||
raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
|
||||
raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
|
||||
if nested_type == "string" and not isinstance(item, str):
|
||||
raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
|
||||
raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
|
||||
if nested_type == "object" and not isinstance(item, dict):
|
||||
raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
|
||||
raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
|
||||
return result
|
||||
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
|
@ -634,7 +649,7 @@ class ParameterExtractorNode(LLMNode):
|
|||
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
|
||||
return [system_prompt_messages, user_prompt_message]
|
||||
else:
|
||||
raise ValueError(f"Model mode {model_mode} not support.")
|
||||
raise InvalidModelModeError(f"Model mode {model_mode} not support.")
|
||||
|
||||
def _get_prompt_engineering_prompt_template(
|
||||
self,
|
||||
|
@ -669,7 +684,7 @@ class ParameterExtractorNode(LLMNode):
|
|||
.replace("}γγγ", "")
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Model mode {model_mode} not support.")
|
||||
raise InvalidModelModeError(f"Model mode {model_mode} not support.")
|
||||
|
||||
def _calculate_rest_token(
|
||||
self,
|
||||
|
@ -683,12 +698,12 @@ class ParameterExtractorNode(LLMNode):
|
|||
|
||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
|
||||
raise ValueError("Model is not a Large Language Model")
|
||||
raise InvalidModelTypeError("Model is not a Large Language Model")
|
||||
|
||||
llm_model = model_instance.model_type_instance
|
||||
model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
|
||||
if not model_schema:
|
||||
raise ValueError("Model schema not found")
|
||||
raise ModelSchemaNotFoundError("Model schema not found")
|
||||
|
||||
if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
|
||||
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
|
||||
|
|
|
@ -91,6 +91,8 @@ def build_segment(value: Any, /) -> Segment:
|
|||
return ArrayObjectSegment(value=value)
|
||||
case SegmentType.FILE:
|
||||
return ArrayFileSegment(value=value)
|
||||
case SegmentType.NONE:
|
||||
return ArrayAnySegment(value=value)
|
||||
case _:
|
||||
raise ValueError(f"not supported value {value}")
|
||||
raise ValueError(f"not supported value {value}")
|
||||
|
|
|
@ -8,6 +8,7 @@ upload_config_fields = {
|
|||
"image_file_size_limit": fields.Integer,
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"workflow_file_upload_limit": fields.Integer,
|
||||
}
|
||||
|
||||
file_fields = {
|
||||
|
|
|
@ -6,7 +6,6 @@ from .model import (
|
|||
AppMode,
|
||||
Conversation,
|
||||
EndUser,
|
||||
FileUploadConfig,
|
||||
InstalledApp,
|
||||
Message,
|
||||
MessageAnnotation,
|
||||
|
@ -50,6 +49,5 @@ __all__ = [
|
|||
"Tenant",
|
||||
"Conversation",
|
||||
"MessageAnnotation",
|
||||
"FileUploadConfig",
|
||||
"ToolFile",
|
||||
]
|
||||
|
|
|
@ -121,7 +121,7 @@ class App(Base):
|
|||
return site
|
||||
|
||||
@property
|
||||
def app_model_config(self) -> Optional["AppModelConfig"]:
|
||||
def app_model_config(self):
|
||||
if self.app_model_config_id:
|
||||
return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first()
|
||||
|
||||
|
@ -1320,7 +1320,7 @@ class Site(Base):
|
|||
privacy_policy = db.Column(db.String(255))
|
||||
show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
|
||||
_custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="")
|
||||
customize_domain = db.Column(db.String(255))
|
||||
customize_token_strategy = db.Column(db.String(255), nullable=False)
|
||||
prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
|
@ -1331,6 +1331,16 @@ class Site(Base):
|
|||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
code = db.Column(db.String(255))
|
||||
|
||||
@property
|
||||
def custom_disclaimer(self):
|
||||
return self._custom_disclaimer
|
||||
|
||||
@custom_disclaimer.setter
|
||||
def custom_disclaimer(self, value: str):
|
||||
if len(value) > 512:
|
||||
raise ValueError("Custom disclaimer cannot exceed 512 characters.")
|
||||
self._custom_disclaimer = value
|
||||
|
||||
@staticmethod
|
||||
def generate_code(n):
|
||||
while True:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
|
@ -111,7 +111,9 @@ class Workflow(Base):
|
|||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
updated_by: Mapped[Optional[str]] = mapped_column(StringUUID)
|
||||
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime, nullable=False, default=datetime.now(tz=timezone.utc), server_onupdate=func.current_timestamp()
|
||||
)
|
||||
_environment_variables: Mapped[str] = mapped_column(
|
||||
"environment_variables", db.Text, nullable=False, server_default="{}"
|
||||
)
|
||||
|
|
70
api/poetry.lock
generated
70
api/poetry.lock
generated
|
@ -2532,6 +2532,19 @@ files = [
|
|||
{file = "filetype-1.2.0.tar.gz", hash = "sha256:66b56cd6474bf41d8c54660347d37afcc3f7d1970648de365c102ef77548aadb"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fire"
|
||||
version = "0.7.0"
|
||||
description = "A library for automatically generating command line interfaces."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "fire-0.7.0.tar.gz", hash = "sha256:961550f07936eaf65ad1dc8360f2b2bf8408fad46abbfa4d2a3794f8d2a95cdf"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
termcolor = "*"
|
||||
|
||||
[[package]]
|
||||
name = "flasgger"
|
||||
version = "0.9.7.1"
|
||||
|
@ -2697,6 +2710,19 @@ files = [
|
|||
{file = "flatbuffers-24.3.25.tar.gz", hash = "sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fontmeta"
|
||||
version = "1.6.1"
|
||||
description = "An Utility to get ttf/otf font metadata"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "fontmeta-1.6.1.tar.gz", hash = "sha256:837e5bc4da879394b41bda1428a8a480eb7c4e993799a93cfb582bab771a9c24"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
fonttools = "*"
|
||||
|
||||
[[package]]
|
||||
name = "fonttools"
|
||||
version = "4.54.1"
|
||||
|
@ -5279,6 +5305,22 @@ files = [
|
|||
{file = "monotonic-1.6.tar.gz", hash = "sha256:3a55207bcfed53ddd5c5bae174524062935efed17792e9de2ad0205ce9ad63f7"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mplfonts"
|
||||
version = "0.0.8"
|
||||
description = "Fonts manager for matplotlib"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "mplfonts-0.0.8-py3-none-any.whl", hash = "sha256:b2182e5b0baa216cf016dec19942740e5b48956415708ad2d465e03952112ec1"},
|
||||
{file = "mplfonts-0.0.8.tar.gz", hash = "sha256:0abcb2fc0605645e1e7561c6923014d856f11676899b33b4d89757843f5e7c22"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
fire = ">=0.4.0"
|
||||
fontmeta = ">=1.6.1"
|
||||
matplotlib = ">=3.4"
|
||||
|
||||
[[package]]
|
||||
name = "mpmath"
|
||||
version = "1.3.0"
|
||||
|
@ -9300,6 +9342,20 @@ files = [
|
|||
[package.dependencies]
|
||||
tencentcloud-sdk-python-common = "3.0.1257"
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
version = "2.5.0"
|
||||
description = "ANSI color formatting for output in terminal"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "termcolor-2.5.0-py3-none-any.whl", hash = "sha256:37b17b5fc1e604945c2642c872a3764b5d547a48009871aea3edd3afa180afb8"},
|
||||
{file = "termcolor-2.5.0.tar.gz", hash = "sha256:998d8d27da6d48442e8e1f016119076b690d962507531df4890fcd2db2ef8a6f"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
tests = ["pytest", "pytest-cov"]
|
||||
|
||||
[[package]]
|
||||
name = "threadpoolctl"
|
||||
version = "3.5.0"
|
||||
|
@ -10046,13 +10102,13 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "vanna"
|
||||
version = "0.7.3"
|
||||
version = "0.7.5"
|
||||
description = "Generate SQL queries from natural language"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "vanna-0.7.3-py3-none-any.whl", hash = "sha256:82ba39e5d6c503d1c8cca60835ed401d20ec3a3da98d487f529901dcb30061d6"},
|
||||
{file = "vanna-0.7.3.tar.gz", hash = "sha256:4590dd94d2fe180b4efc7a83c867b73144ef58794018910dc226857cfb703077"},
|
||||
{file = "vanna-0.7.5-py3-none-any.whl", hash = "sha256:07458c7befa49de517a8760c2d80a13147278b484c515d49a906acc88edcb835"},
|
||||
{file = "vanna-0.7.5.tar.gz", hash = "sha256:2fdffc58832898e4fc8e93c45b173424db59a22773b22ca348640161d391eacf"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -10073,7 +10129,7 @@ sqlparse = "*"
|
|||
tabulate = "*"
|
||||
|
||||
[package.extras]
|
||||
all = ["PyMySQL", "anthropic", "azure-common", "azure-identity", "azure-search-documents", "chromadb", "db-dtypes", "duckdb", "fastembed", "google-cloud-aiplatform", "google-cloud-bigquery", "google-generativeai", "httpx", "marqo", "mistralai (>=1.0.0)", "ollama", "openai", "opensearch-dsl", "opensearch-py", "pinecone-client", "psycopg2-binary", "pymilvus[model]", "qdrant-client", "qianfan", "snowflake-connector-python", "transformers", "weaviate-client", "zhipuai"]
|
||||
all = ["PyMySQL", "anthropic", "azure-common", "azure-identity", "azure-search-documents", "boto", "boto3", "botocore", "chromadb", "db-dtypes", "duckdb", "faiss-cpu", "fastembed", "google-cloud-aiplatform", "google-cloud-bigquery", "google-generativeai", "httpx", "langchain_core", "langchain_postgres", "marqo", "mistralai (>=1.0.0)", "ollama", "openai", "opensearch-dsl", "opensearch-py", "pinecone-client", "psycopg2-binary", "pymilvus[model]", "qdrant-client", "qianfan", "snowflake-connector-python", "transformers", "weaviate-client", "xinference-client", "zhipuai"]
|
||||
anthropic = ["anthropic"]
|
||||
azuresearch = ["azure-common", "azure-identity", "azure-search-documents", "fastembed"]
|
||||
bedrock = ["boto3", "botocore"]
|
||||
|
@ -10081,6 +10137,8 @@ bigquery = ["google-cloud-bigquery"]
|
|||
chromadb = ["chromadb"]
|
||||
clickhouse = ["clickhouse_connect"]
|
||||
duckdb = ["duckdb"]
|
||||
faiss-cpu = ["faiss-cpu"]
|
||||
faiss-gpu = ["faiss-gpu"]
|
||||
gemini = ["google-generativeai"]
|
||||
google = ["google-cloud-aiplatform", "google-generativeai"]
|
||||
hf = ["transformers"]
|
||||
|
@ -10091,6 +10149,7 @@ mysql = ["PyMySQL"]
|
|||
ollama = ["httpx", "ollama"]
|
||||
openai = ["openai"]
|
||||
opensearch = ["opensearch-dsl", "opensearch-py"]
|
||||
pgvector = ["langchain-postgres (>=0.0.12)"]
|
||||
pinecone = ["fastembed", "pinecone-client"]
|
||||
postgres = ["db-dtypes", "psycopg2-binary"]
|
||||
qdrant = ["fastembed", "qdrant-client"]
|
||||
|
@ -10099,6 +10158,7 @@ snowflake = ["snowflake-connector-python"]
|
|||
test = ["tox"]
|
||||
vllm = ["vllm"]
|
||||
weaviate = ["weaviate-client"]
|
||||
xinference-client = ["xinference-client"]
|
||||
zhipuai = ["zhipuai"]
|
||||
|
||||
[[package]]
|
||||
|
@ -10940,4 +11000,4 @@ cffi = ["cffi (>=1.11)"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "ef927b98c33d704d680e08db0e5c7d9a4e05454c66fcd6a5f656a65eb08e886b"
|
||||
content-hash = "e4794898403da4ad7b51f248a6c07632a949114c1b569406d3aa6a94c62510a5"
|
||||
|
|
|
@ -206,13 +206,14 @@ cloudscraper = "1.2.71"
|
|||
duckduckgo-search = "~6.3.0"
|
||||
jsonpath-ng = "1.6.1"
|
||||
matplotlib = "~3.8.2"
|
||||
mplfonts = "~0.0.8"
|
||||
newspaper3k = "0.2.8"
|
||||
nltk = "3.9.1"
|
||||
numexpr = "~2.9.0"
|
||||
pydub = "~0.25.1"
|
||||
qrcode = "~7.4.2"
|
||||
twilio = "~9.0.4"
|
||||
vanna = { version = "0.7.3", extras = ["postgres", "mysql", "clickhouse", "duckdb"] }
|
||||
vanna = { version = "0.7.5", extras = ["postgres", "mysql", "clickhouse", "duckdb", "oracle"] }
|
||||
wikipedia = "1.4.0"
|
||||
yfinance = "~0.2.40"
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ from models.dataset import Embedding
|
|||
@app.celery.task(queue="dataset")
|
||||
def clean_embedding_cache_task():
|
||||
click.echo(click.style("Start clean embedding cache.", fg="green"))
|
||||
clean_days = int(dify_config.CLEAN_DAY_SETTING)
|
||||
clean_days = int(dify_config.PLAN_SANDBOX_CLEAN_DAY_SETTING)
|
||||
start_at = time.perf_counter()
|
||||
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days)
|
||||
while True:
|
||||
|
|
|
@ -986,9 +986,6 @@ class DocumentService:
|
|||
raise NotFound("Document not found")
|
||||
if document.display_status != "available":
|
||||
raise ValueError("Document is not available")
|
||||
# update document name
|
||||
if document_data.get("name"):
|
||||
document.name = document_data["name"]
|
||||
# save process rule
|
||||
if document_data.get("process_rule"):
|
||||
process_rule = document_data["process_rule"]
|
||||
|
@ -1065,6 +1062,10 @@ class DocumentService:
|
|||
document.data_source_type = document_data["data_source"]["type"]
|
||||
document.data_source_info = json.dumps(data_source_info)
|
||||
document.name = file_name
|
||||
|
||||
# update document name
|
||||
if document_data.get("name"):
|
||||
document.name = document_data["name"]
|
||||
# update document to be waiting
|
||||
document.indexing_status = "waiting"
|
||||
document.completed_at = None
|
||||
|
|
|
@ -62,6 +62,37 @@ class BuiltinToolManageService:
|
|||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
|
||||
"""
|
||||
get builtin tool provider info
|
||||
"""
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
tool_provider_configurations = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
# check if user has added the provider
|
||||
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
|
||||
|
||||
credentials = {}
|
||||
if builtin_provider is not None:
|
||||
# get credentials
|
||||
credentials = builtin_provider.credentials
|
||||
credentials = tool_provider_configurations.decrypt(credentials)
|
||||
|
||||
entity = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=builtin_provider,
|
||||
decrypt_credentials=True,
|
||||
)
|
||||
|
||||
entity.original_credentials = {}
|
||||
|
||||
return entity
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
|
||||
"""
|
||||
|
@ -255,6 +286,7 @@ class BuiltinToolManageService:
|
|||
@staticmethod
|
||||
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
|
||||
try:
|
||||
full_provider_name = provider_name
|
||||
provider_id_entity = ToolProviderID(provider_name)
|
||||
provider_name = provider_id_entity.provider_name
|
||||
if provider_id_entity.organization != "langgenius":
|
||||
|
@ -264,7 +296,8 @@ class BuiltinToolManageService:
|
|||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == provider_name) | (BuiltinToolProvider.provider == provider_name),
|
||||
(BuiltinToolProvider.provider == provider_name)
|
||||
| (BuiltinToolProvider.provider == full_provider_name),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
|
|
@ -96,5 +96,13 @@ VESSL_AI_MODEL_NAME=
|
|||
VESSL_AI_API_KEY=
|
||||
VESSL_AI_ENDPOINT_URL=
|
||||
|
||||
# GPUStack Credentials
|
||||
GPUSTACK_SERVER_URL=
|
||||
GPUSTACK_API_KEY=
|
||||
|
||||
# Gitee AI Credentials
|
||||
GITEE_AI_API_KEY=
|
||||
|
||||
# xAI Credentials
|
||||
XAI_API_KEY=
|
||||
XAI_API_BASE=
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.gpustack.text_embedding.text_embedding import (
|
||||
GPUStackTextEmbeddingModel,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = GPUStackTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="bge-m3",
|
||||
credentials={
|
||||
"endpoint_url": "invalid_url",
|
||||
"api_key": "invalid_api_key",
|
||||
},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="bge-m3",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = GPUStackTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="bge-m3",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
"context_size": 8192,
|
||||
},
|
||||
texts=["hello", "world"],
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 7
|
162
api/tests/integration_tests/model_runtime/gpustack/test_llm.py
Normal file
162
api/tests/integration_tests/model_runtime/gpustack/test_llm.py
Normal file
|
@ -0,0 +1,162 @@
|
|||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.gpustack.llm.llm import GPUStackLanguageModel
|
||||
|
||||
|
||||
def test_validate_credentials_for_chat_model():
|
||||
model = GPUStackLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="llama-3.2-1b-instruct",
|
||||
credentials={
|
||||
"endpoint_url": "invalid_url",
|
||||
"api_key": "invalid_api_key",
|
||||
"mode": "chat",
|
||||
},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="llama-3.2-1b-instruct",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
"mode": "chat",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_completion_model():
|
||||
model = GPUStackLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="llama-3.2-1b-instruct",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
"mode": "completion",
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="ping")],
|
||||
model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
|
||||
stop=[],
|
||||
user="abc-123",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
|
||||
def test_invoke_chat_model():
|
||||
model = GPUStackLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="llama-3.2-1b-instruct",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="ping")],
|
||||
model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
|
||||
stop=[],
|
||||
user="abc-123",
|
||||
stream=False,
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
|
||||
def test_invoke_stream_chat_model():
|
||||
model = GPUStackLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="llama-3.2-1b-instruct",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||
model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
|
||||
stop=["you"],
|
||||
stream=True,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = GPUStackLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="????",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name="get_current_weather",
|
||||
description="Get the current weather in a given location",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
assert isinstance(num_tokens, int)
|
||||
assert num_tokens == 80
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="????",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||
)
|
||||
|
||||
assert isinstance(num_tokens, int)
|
||||
assert num_tokens == 10
|
|
@ -0,0 +1,107 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.gpustack.rerank.rerank import (
|
||||
GPUStackRerankModel,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_credentials_for_rerank_model():
|
||||
model = GPUStackRerankModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model="bge-reranker-v2-m3",
|
||||
credentials={
|
||||
"endpoint_url": "invalid_url",
|
||||
"api_key": "invalid_api_key",
|
||||
},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="bge-reranker-v2-m3",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_rerank_model():
|
||||
model = GPUStackRerankModel()
|
||||
|
||||
response = model.invoke(
|
||||
model="bge-reranker-v2-m3",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
},
|
||||
query="Organic skincare products for sensitive skin",
|
||||
docs=[
|
||||
"Eco-friendly kitchenware for modern homes",
|
||||
"Biodegradable cleaning supplies for eco-conscious consumers",
|
||||
"Organic cotton baby clothes for sensitive skin",
|
||||
"Natural organic skincare range for sensitive skin",
|
||||
"Tech gadgets for smart homes: 2024 edition",
|
||||
"Sustainable gardening tools and compost solutions",
|
||||
"Sensitive skin-friendly facial cleansers and toners",
|
||||
"Organic food wraps and storage solutions",
|
||||
"Yoga mats made from recycled materials",
|
||||
],
|
||||
top_n=3,
|
||||
score_threshold=-0.75,
|
||||
user="abc-123",
|
||||
)
|
||||
|
||||
assert isinstance(response, RerankResult)
|
||||
assert len(response.docs) == 3
|
||||
|
||||
|
||||
def test__invoke():
|
||||
model = GPUStackRerankModel()
|
||||
|
||||
# Test case 1: Empty docs
|
||||
result = model._invoke(
|
||||
model="bge-reranker-v2-m3",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
},
|
||||
query="Organic skincare products for sensitive skin",
|
||||
docs=[],
|
||||
top_n=3,
|
||||
score_threshold=0.75,
|
||||
user="abc-123",
|
||||
)
|
||||
assert isinstance(result, RerankResult)
|
||||
assert len(result.docs) == 0
|
||||
|
||||
# Test case 2: Expected docs
|
||||
result = model._invoke(
|
||||
model="bge-reranker-v2-m3",
|
||||
credentials={
|
||||
"endpoint_url": os.environ.get("GPUSTACK_SERVER_URL"),
|
||||
"api_key": os.environ.get("GPUSTACK_API_KEY"),
|
||||
},
|
||||
query="Organic skincare products for sensitive skin",
|
||||
docs=[
|
||||
"Eco-friendly kitchenware for modern homes",
|
||||
"Biodegradable cleaning supplies for eco-conscious consumers",
|
||||
"Organic cotton baby clothes for sensitive skin",
|
||||
"Natural organic skincare range for sensitive skin",
|
||||
"Tech gadgets for smart homes: 2024 edition",
|
||||
"Sustainable gardening tools and compost solutions",
|
||||
"Sensitive skin-friendly facial cleansers and toners",
|
||||
"Organic food wraps and storage solutions",
|
||||
"Yoga mats made from recycled materials",
|
||||
],
|
||||
top_n=3,
|
||||
score_threshold=-0.75,
|
||||
user="abc-123",
|
||||
)
|
||||
assert isinstance(result, RerankResult)
|
||||
assert len(result.docs) == 3
|
||||
assert all(isinstance(doc, RerankDocument) for doc in result.docs)
|
204
api/tests/integration_tests/model_runtime/x/test_llm.py
Normal file
204
api/tests/integration_tests/model_runtime/x/test_llm.py
Normal file
|
@ -0,0 +1,204 @@
|
|||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.x.llm.llm import XAILargeLanguageModel
|
||||
|
||||
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
|
||||
def test_predefined_models():
|
||||
model = XAILargeLanguageModel()
|
||||
model_schemas = model.predefined_models()
|
||||
|
||||
assert len(model_schemas) >= 1
|
||||
assert isinstance(model_schemas[0], AIModelEntity)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_validate_credentials_for_chat_model(setup_openai_mock):
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
# model name to gpt-3.5-turbo because of mocking
|
||||
model.validate_credentials(
|
||||
model="gpt-3.5-turbo",
|
||||
credentials={"api_key": "invalid_key", "endpoint_url": os.environ.get("XAI_API_BASE"), "mode": "chat"},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="grok-beta",
|
||||
credentials={
|
||||
"api_key": os.environ.get("XAI_API_KEY"),
|
||||
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||
"mode": "chat",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_chat_model(setup_openai_mock):
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="grok-beta",
|
||||
credentials={
|
||||
"api_key": os.environ.get("XAI_API_KEY"),
|
||||
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={
|
||||
"temperature": 0.0,
|
||||
"top_p": 1.0,
|
||||
"presence_penalty": 0.0,
|
||||
"frequency_penalty": 0.0,
|
||||
"max_tokens": 10,
|
||||
},
|
||||
stop=["How"],
|
||||
stream=False,
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert len(result.message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_chat_model_with_tools(setup_openai_mock):
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="grok-beta",
|
||||
credentials={
|
||||
"api_key": os.environ.get("XAI_API_KEY"),
|
||||
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(
|
||||
content="what's the weather today in London?",
|
||||
),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name="get_weather",
|
||||
description="Determine weather in my location",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
|
||||
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
),
|
||||
PromptMessageTool(
|
||||
name="get_stock_price",
|
||||
description="Get the current stock price",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
|
||||
"required": ["symbol"],
|
||||
},
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert isinstance(result.message, AssistantPromptMessage)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_stream_chat_model(setup_openai_mock):
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="grok-beta",
|
||||
credentials={
|
||||
"api_key": os.environ.get("XAI_API_KEY"),
|
||||
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||
stream=True,
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, Generator)
|
||||
|
||||
for chunk in result:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
if chunk.delta.finish_reason is not None:
|
||||
assert chunk.delta.usage is not None
|
||||
assert chunk.delta.usage.completion_tokens > 0
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="grok-beta",
|
||||
credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
|
||||
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||
)
|
||||
|
||||
assert num_tokens == 10
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="grok-beta",
|
||||
credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name="get_weather",
|
||||
description="Determine weather in my location",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
|
||||
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
assert num_tokens == 77
|
0
api/tests/integration_tests/vdb/lindorm/__init__.py
Normal file
0
api/tests/integration_tests/vdb/lindorm/__init__.py
Normal file
35
api/tests/integration_tests/vdb/lindorm/test_lindorm.py
Normal file
35
api/tests/integration_tests/vdb/lindorm/test_lindorm.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
import environs
|
||||
|
||||
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig
|
||||
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis
|
||||
|
||||
env = environs.Env()
|
||||
|
||||
|
||||
class Config:
|
||||
SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-*************-proxy-search-pub.lindorm.aliyuncs.com:30070")
|
||||
SEARCH_USERNAME = env.str("SEARCH_USERNAME", "ADMIN")
|
||||
SEARCH_PWD = env.str("SEARCH_PWD", "PWD")
|
||||
|
||||
|
||||
class TestLindormVectorStore(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vector = LindormVectorStore(
|
||||
collection_name=self.collection_name,
|
||||
config=LindormVectorStoreConfig(
|
||||
hosts=Config.SEARCH_ENDPOINT,
|
||||
username=Config.SEARCH_USERNAME,
|
||||
password=Config.SEARCH_PWD,
|
||||
),
|
||||
)
|
||||
|
||||
def get_ids_by_metadata_field(self):
|
||||
ids = self.vector.get_ids_by_metadata_field(key="doc_id", value=self.example_doc_id)
|
||||
assert ids is not None
|
||||
assert len(ids) == 1
|
||||
assert ids[0] == self.example_doc_id
|
||||
|
||||
|
||||
def test_lindorm_vector(setup_mock_redis):
|
||||
TestLindormVectorStore().run_all_tests()
|
|
@ -0,0 +1,52 @@
|
|||
import pytest
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
|
||||
|
||||
def test_validate_inputs_with_zero():
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
var = VariableEntity(
|
||||
variable="test_var",
|
||||
label="test_var",
|
||||
type=VariableEntityType.NUMBER,
|
||||
required=True,
|
||||
)
|
||||
|
||||
# Test with input 0
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var,
|
||||
value=0,
|
||||
)
|
||||
|
||||
assert result == 0
|
||||
|
||||
# Test with input "0" (string)
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var,
|
||||
value="0",
|
||||
)
|
||||
|
||||
assert result == 0
|
||||
|
||||
|
||||
def test_validate_input_with_none_for_required_variable():
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
for var_type in VariableEntityType:
|
||||
var = VariableEntity(
|
||||
variable="test_var",
|
||||
label="test_var",
|
||||
type=var_type,
|
||||
required=True,
|
||||
)
|
||||
|
||||
# Test with input None
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
base_app_generator._validate_inputs(
|
||||
variable_entity=var,
|
||||
value=None,
|
||||
)
|
||||
|
||||
assert str(exc_info.value) == "test_var is required in input form"
|
|
@ -13,6 +13,7 @@ from core.variables import (
|
|||
StringVariable,
|
||||
)
|
||||
from core.variables.exc import VariableError
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from factories import variable_factory
|
||||
|
||||
|
||||
|
@ -156,3 +157,9 @@ def test_variable_cannot_large_than_200_kb():
|
|||
"value": "a" * 1024 * 201,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_array_none_variable():
|
||||
var = variable_factory.build_segment([None, None, None, None])
|
||||
assert isinstance(var, ArrayAnySegment)
|
||||
assert var.value == [None, None, None, None]
|
||||
|
|
|
@ -0,0 +1,198 @@
|
|||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.http_request import (
|
||||
BodyData,
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeBody,
|
||||
HttpRequestNodeData,
|
||||
)
|
||||
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_number_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "number"], 42)
|
||||
|
||||
# Prepare the node data
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Number Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value='{"number": {{#pre_node_id.number#}}}',
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize the Executor
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Check the executor's data
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.json == {"number": 42}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
|
||||
# Check the raw request (to_log method)
|
||||
raw_request = executor.to_log()
|
||||
assert "POST /data HTTP/1.1" in raw_request
|
||||
assert "Host: api.example.com" in raw_request
|
||||
assert "Content-Type: application/json" in raw_request
|
||||
assert '{"number": 42}' in raw_request
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_object_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
||||
|
||||
# Prepare the node data
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Object Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value="{{#pre_node_id.object#}}",
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize the Executor
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Check the executor's data
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
|
||||
# Check the raw request (to_log method)
|
||||
raw_request = executor.to_log()
|
||||
assert "POST /data HTTP/1.1" in raw_request
|
||||
assert "Host: api.example.com" in raw_request
|
||||
assert "Content-Type: application/json" in raw_request
|
||||
assert '"name": "John Doe"' in raw_request
|
||||
assert '"age": 30' in raw_request
|
||||
assert '"email": "john@example.com"' in raw_request
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_nested_object_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
||||
|
||||
# Prepare the node data
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Nested Object Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value='{"object": {{#pre_node_id.object#}}}',
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize the Executor
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Check the executor's data
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
|
||||
# Check the raw request (to_log method)
|
||||
raw_request = executor.to_log()
|
||||
assert "POST /data HTTP/1.1" in raw_request
|
||||
assert "Host: api.example.com" in raw_request
|
||||
assert "Content-Type: application/json" in raw_request
|
||||
assert '"object": {' in raw_request
|
||||
assert '"name": "John Doe"' in raw_request
|
||||
assert '"age": 30' in raw_request
|
||||
assert '"email": "john@example.com"' in raw_request
|
||||
|
||||
|
||||
def test_extract_selectors_from_template_with_newline():
|
||||
variable_pool = VariablePool()
|
||||
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Nested Object Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="test: {{#node_id.custom_query#}}",
|
||||
body=HttpRequestNodeBody(
|
||||
type="none",
|
||||
data=[],
|
||||
),
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
assert executor.params == {"test": "line1\nline2"}
|
|
@ -1,5 +1,3 @@
|
|||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
@ -16,8 +14,7 @@ from core.workflow.nodes.http_request import (
|
|||
HttpRequestNodeBody,
|
||||
HttpRequestNodeData,
|
||||
)
|
||||
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
|
||||
from core.workflow.nodes.http_request.executor import Executor, _plain_text_to_dict
|
||||
from core.workflow.nodes.http_request.executor import _plain_text_to_dict
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
@ -203,167 +200,3 @@ def test_http_request_node_form_with_file(monkeypatch):
|
|||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["body"] == ""
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_number_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "number"], 42)
|
||||
|
||||
# Prepare the node data
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Number Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value='{"number": {{#pre_node_id.number#}}}',
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize the Executor
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Check the executor's data
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.json == {"number": 42}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
|
||||
# Check the raw request (to_log method)
|
||||
raw_request = executor.to_log()
|
||||
assert "POST /data HTTP/1.1" in raw_request
|
||||
assert "Host: api.example.com" in raw_request
|
||||
assert "Content-Type: application/json" in raw_request
|
||||
assert '{"number": 42}' in raw_request
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_object_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
||||
|
||||
# Prepare the node data
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Object Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value="{{#pre_node_id.object#}}",
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize the Executor
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Check the executor's data
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
|
||||
# Check the raw request (to_log method)
|
||||
raw_request = executor.to_log()
|
||||
assert "POST /data HTTP/1.1" in raw_request
|
||||
assert "Host: api.example.com" in raw_request
|
||||
assert "Content-Type: application/json" in raw_request
|
||||
assert '"name": "John Doe"' in raw_request
|
||||
assert '"age": 30' in raw_request
|
||||
assert '"email": "john@example.com"' in raw_request
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_nested_object_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
||||
|
||||
# Prepare the node data
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Nested Object Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value='{"object": {{#pre_node_id.object#}}}',
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize the Executor
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Check the executor's data
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
|
||||
# Check the raw request (to_log method)
|
||||
raw_request = executor.to_log()
|
||||
assert "POST /data HTTP/1.1" in raw_request
|
||||
assert "Host: api.example.com" in raw_request
|
||||
assert "Content-Type: application/json" in raw_request
|
||||
assert '"object": {' in raw_request
|
||||
assert '"name": "John Doe"' in raw_request
|
||||
assert '"age": 30' in raw_request
|
||||
assert '"email": "john@example.com"' in raw_request
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user