mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
merge main
This commit is contained in:
commit
02f494c0de
8
.github/workflows/style.yml
vendored
8
.github/workflows/style.yml
vendored
|
@ -20,7 +20,7 @@ jobs:
|
|||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v44
|
||||
uses: tj-actions/changed-files@v45
|
||||
with:
|
||||
files: api/**
|
||||
|
||||
|
@ -66,7 +66,7 @@ jobs:
|
|||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v44
|
||||
uses: tj-actions/changed-files@v45
|
||||
with:
|
||||
files: web/**
|
||||
|
||||
|
@ -97,7 +97,7 @@ jobs:
|
|||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v44
|
||||
uses: tj-actions/changed-files@v45
|
||||
with:
|
||||
files: |
|
||||
**.sh
|
||||
|
@ -107,7 +107,7 @@ jobs:
|
|||
dev/**
|
||||
|
||||
- name: Super-linter
|
||||
uses: super-linter/super-linter/slim@v6
|
||||
uses: super-linter/super-linter/slim@v7
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
env:
|
||||
BASH_SEVERITY: warning
|
||||
|
|
54
.github/workflows/translate-i18n-base-on-english.yml
vendored
Normal file
54
.github/workflows/translate-i18n-base-on-english.yml
vendored
Normal file
|
@ -0,0 +1,54 @@
|
|||
name: Check i18n Files and Create PR
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [closed]
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
check-and-update:
|
||||
if: github.event.pull_request.merged == true
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: web
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 2 # last 2 commits
|
||||
|
||||
- name: Check for file changes in i18n/en-US
|
||||
id: check_files
|
||||
run: |
|
||||
recent_commit_sha=$(git rev-parse HEAD)
|
||||
second_recent_commit_sha=$(git rev-parse HEAD~1)
|
||||
changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts')
|
||||
echo "Changed files: $changed_files"
|
||||
if [ -n "$changed_files" ]; then
|
||||
echo "FILES_CHANGED=true" >> $GITHUB_ENV
|
||||
else
|
||||
echo "FILES_CHANGED=false" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Set up Node.js
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
uses: actions/setup-node@v2
|
||||
with:
|
||||
node-version: 'lts/*'
|
||||
|
||||
- name: Install dependencies
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
run: yarn install --frozen-lockfile
|
||||
|
||||
- name: Run npm script
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
run: npm run auto-gen-i18n
|
||||
|
||||
- name: Create Pull Request
|
||||
if: env.FILES_CHANGED == 'true'
|
||||
uses: peter-evans/create-pull-request@v6
|
||||
with:
|
||||
commit-message: Update i18n files based on en-US changes
|
||||
title: 'chore: translate i18n files'
|
||||
body: This PR was automatically created to update i18n files based on changes in en-US locale.
|
||||
branch: chore/automated-i18n-updates
|
|
@ -8,7 +8,7 @@ In terms of licensing, please take a minute to read our short [License and Contr
|
|||
|
||||
## Before you jump in
|
||||
|
||||
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
|
||||
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
|
||||
|
||||
### Feature requests:
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
|
||||
## 在开始之前
|
||||
|
||||
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:closed)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:
|
||||
[查找](https://github.com/langgenius/dify/issues?q=is:issue+is:open)现有问题,或 [创建](https://github.com/langgenius/dify/issues/new/choose) 一个新问题。我们将问题分为两类:
|
||||
|
||||
### 功能请求:
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ Dify にコントリビュートしたいとお考えなのですね。それは
|
|||
|
||||
## 飛び込む前に
|
||||
|
||||
[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。
|
||||
[既存の Issue](https://github.com/langgenius/dify/issues?q=is:issue+is:open) を探すか、[新しい Issue](https://github.com/langgenius/dify/issues/new/choose) を作成してください。私たちは Issue を 2 つのタイプに分類しています。
|
||||
|
||||
### 機能リクエスト
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ Về vấn đề cấp phép, xin vui lòng dành chút thời gian đọc qua [
|
|||
|
||||
## Trước khi bắt đầu
|
||||
|
||||
[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:closed) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại:
|
||||
[Tìm kiếm](https://github.com/langgenius/dify/issues?q=is:issue+is:open) một vấn đề hiện có, hoặc [tạo mới](https://github.com/langgenius/dify/issues/new/choose) một vấn đề. Chúng tôi phân loại các vấn đề thành 2 loại:
|
||||
|
||||
### Yêu cầu tính năng:
|
||||
|
||||
|
|
2
LICENSE
2
LICENSE
|
@ -4,7 +4,7 @@ Dify is licensed under the Apache License 2.0, with the following additional con
|
|||
|
||||
1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer:
|
||||
|
||||
a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
|
||||
a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
|
||||
- Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
|
||||
|
||||
b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components.
|
||||
|
|
|
@ -39,7 +39,7 @@ DB_DATABASE=dify
|
|||
|
||||
# Storage configuration
|
||||
# use for store upload files, private keys...
|
||||
# storage type: local, s3, azure-blob, google-storage
|
||||
# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos
|
||||
STORAGE_TYPE=local
|
||||
STORAGE_LOCAL_PATH=storage
|
||||
S3_USE_AWS_MANAGED_IAM=false
|
||||
|
@ -60,7 +60,8 @@ ALIYUN_OSS_SECRET_KEY=your-secret-key
|
|||
ALIYUN_OSS_ENDPOINT=your-endpoint
|
||||
ALIYUN_OSS_AUTH_VERSION=v1
|
||||
ALIYUN_OSS_REGION=your-region
|
||||
|
||||
# Don't start with '/'. OSS doesn't support leading slash in object names.
|
||||
ALIYUN_OSS_PATH=your-path
|
||||
# Google Storage configuration
|
||||
GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name
|
||||
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
|
||||
|
@ -72,6 +73,12 @@ TENCENT_COS_SECRET_ID=your-secret-id
|
|||
TENCENT_COS_REGION=your-region
|
||||
TENCENT_COS_SCHEME=your-scheme
|
||||
|
||||
# Huawei OBS Storage Configuration
|
||||
HUAWEI_OBS_BUCKET_NAME=your-bucket-name
|
||||
HUAWEI_OBS_SECRET_KEY=your-secret-key
|
||||
HUAWEI_OBS_ACCESS_KEY=your-access-key
|
||||
HUAWEI_OBS_SERVER=your-server-url
|
||||
|
||||
# OCI Storage configuration
|
||||
OCI_ENDPOINT=your-endpoint
|
||||
OCI_BUCKET_NAME=your-bucket-name
|
||||
|
@ -79,6 +86,13 @@ OCI_ACCESS_KEY=your-access-key
|
|||
OCI_SECRET_KEY=your-secret-key
|
||||
OCI_REGION=your-region
|
||||
|
||||
# Volcengine tos Storage configuration
|
||||
VOLCENGINE_TOS_ENDPOINT=your-endpoint
|
||||
VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name
|
||||
VOLCENGINE_TOS_ACCESS_KEY=your-access-key
|
||||
VOLCENGINE_TOS_SECRET_KEY=your-secret-key
|
||||
VOLCENGINE_TOS_REGION=your-region
|
||||
|
||||
# CORS configuration
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
|
@ -100,11 +114,10 @@ QDRANT_GRPC_ENABLED=false
|
|||
QDRANT_GRPC_PORT=6334
|
||||
|
||||
# Milvus configuration
|
||||
MILVUS_HOST=127.0.0.1
|
||||
MILVUS_PORT=19530
|
||||
MILVUS_URI=http://127.0.0.1:19530
|
||||
MILVUS_TOKEN=
|
||||
MILVUS_USER=root
|
||||
MILVUS_PASSWORD=Milvus
|
||||
MILVUS_SECURE=false
|
||||
|
||||
# MyScale configuration
|
||||
MYSCALE_HOST=127.0.0.1
|
||||
|
|
|
@ -55,7 +55,7 @@ RUN apt-get update \
|
|||
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
|
||||
&& apt-get update \
|
||||
# For Security
|
||||
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-1 libldap-2.5-0=2.5.18+dfsg-2 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
|
||||
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
|
||||
&& apt-get autoremove -y \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
|
|
@ -559,8 +559,9 @@ def add_qdrant_doc_id_index(field: str):
|
|||
|
||||
@click.command("create-tenant", help="Create account and tenant.")
|
||||
@click.option("--email", prompt=True, help="The email address of the tenant account.")
|
||||
@click.option("--name", prompt=True, help="The workspace name of the tenant account.")
|
||||
@click.option("--language", prompt=True, help="Account language, default: en-US.")
|
||||
def create_tenant(email: str, language: Optional[str] = None):
|
||||
def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None):
|
||||
"""
|
||||
Create tenant account
|
||||
"""
|
||||
|
@ -580,13 +581,15 @@ def create_tenant(email: str, language: Optional[str] = None):
|
|||
if language not in languages:
|
||||
language = "en-US"
|
||||
|
||||
name = name.strip()
|
||||
|
||||
# generate random password
|
||||
new_password = secrets.token_urlsafe(16)
|
||||
|
||||
# register account
|
||||
account = RegisterService.register(email=email, name=account_name, password=new_password, language=language)
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name)
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Optional
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from pydantic import AliasChoices, Field, NegativeInt, NonNegativeInt, PositiveInt, computed_field
|
||||
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from configs.feature.hosted_service import HostedServiceConfig
|
||||
|
@ -45,8 +45,8 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
|||
Code Execution Sandbox configs
|
||||
"""
|
||||
|
||||
CODE_EXECUTION_ENDPOINT: str = Field(
|
||||
description="endpoint URL of code execution servcie",
|
||||
CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
|
||||
description="endpoint URL of code execution service",
|
||||
default="http://sandbox:8194",
|
||||
)
|
||||
|
||||
|
@ -55,6 +55,21 @@ class CodeExecutionSandboxConfig(BaseSettings):
|
|||
default="dify-sandbox",
|
||||
)
|
||||
|
||||
CODE_EXECUTION_CONNECT_TIMEOUT: Optional[float] = Field(
|
||||
description="connect timeout in seconds for code execution request",
|
||||
default=10.0,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_READ_TIMEOUT: Optional[float] = Field(
|
||||
description="read timeout in seconds for code execution request",
|
||||
default=60.0,
|
||||
)
|
||||
|
||||
CODE_EXECUTION_WRITE_TIMEOUT: Optional[float] = Field(
|
||||
description="write timeout in seconds for code execution request",
|
||||
default=10.0,
|
||||
)
|
||||
|
||||
CODE_MAX_NUMBER: PositiveInt = Field(
|
||||
description="max depth for code execution",
|
||||
default=9223372036854775807,
|
||||
|
@ -202,20 +217,17 @@ class HttpConfig(BaseSettings):
|
|||
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
|
||||
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: NonNegativeInt = Field(
|
||||
description="",
|
||||
default=300,
|
||||
)
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: Annotated[
|
||||
PositiveInt, Field(ge=10, description="connect timeout in seconds for HTTP request")
|
||||
] = 10
|
||||
|
||||
HTTP_REQUEST_MAX_READ_TIMEOUT: NonNegativeInt = Field(
|
||||
description="",
|
||||
default=600,
|
||||
)
|
||||
HTTP_REQUEST_MAX_READ_TIMEOUT: Annotated[
|
||||
PositiveInt, Field(ge=60, description="read timeout in seconds for HTTP request")
|
||||
] = 60
|
||||
|
||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT: NonNegativeInt = Field(
|
||||
description="",
|
||||
default=600,
|
||||
)
|
||||
HTTP_REQUEST_MAX_WRITE_TIMEOUT: Annotated[
|
||||
PositiveInt, Field(ge=10, description="read timeout in seconds for HTTP request")
|
||||
] = 20
|
||||
|
||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: PositiveInt = Field(
|
||||
description="",
|
||||
|
@ -403,7 +415,7 @@ class MailConfig(BaseSettings):
|
|||
"""
|
||||
|
||||
MAIL_TYPE: Optional[str] = Field(
|
||||
description="Mail provider type name, default to None, availabile values are `smtp` and `resend`.",
|
||||
description="Mail provider type name, default to None, available values are `smtp` and `resend`.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Any, Optional
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
from pydantic import Field, NonNegativeInt, PositiveInt, computed_field
|
||||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from configs.middleware.cache.redis_config import RedisConfig
|
||||
|
@ -9,8 +9,10 @@ from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorag
|
|||
from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
|
||||
from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
|
||||
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
|
||||
from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
|
||||
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
|
||||
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
||||
from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
|
||||
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
|
||||
from configs.middleware.vdb.chroma_config import ChromaConfig
|
||||
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
|
||||
|
@ -157,6 +159,21 @@ class CeleryConfig(DatabaseConfig):
|
|||
default=None,
|
||||
)
|
||||
|
||||
CELERY_USE_SENTINEL: Optional[bool] = Field(
|
||||
description="Whether to use Redis Sentinel mode",
|
||||
default=False,
|
||||
)
|
||||
|
||||
CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field(
|
||||
description="Redis Sentinel master name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
|
||||
description="Redis Sentinel socket timeout",
|
||||
default=0.1,
|
||||
)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def CELERY_RESULT_BACKEND(self) -> str | None:
|
||||
|
@ -184,6 +201,8 @@ class MiddlewareConfig(
|
|||
AzureBlobStorageConfig,
|
||||
GoogleCloudStorageConfig,
|
||||
TencentCloudCOSStorageConfig,
|
||||
HuaweiCloudOBSStorageConfig,
|
||||
VolcengineTOSStorageConfig,
|
||||
S3StorageConfig,
|
||||
OCIStorageConfig,
|
||||
# configs of vdb and vdb providers
|
||||
|
|
32
api/configs/middleware/cache/redis_config.py
vendored
32
api/configs/middleware/cache/redis_config.py
vendored
|
@ -1,6 +1,6 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import Field, NonNegativeInt, PositiveInt
|
||||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
|
@ -38,3 +38,33 @@ class RedisConfig(BaseSettings):
|
|||
description="whether to use SSL for Redis connection",
|
||||
default=False,
|
||||
)
|
||||
|
||||
REDIS_USE_SENTINEL: Optional[bool] = Field(
|
||||
description="Whether to use Redis Sentinel mode",
|
||||
default=False,
|
||||
)
|
||||
|
||||
REDIS_SENTINELS: Optional[str] = Field(
|
||||
description="Redis Sentinel nodes",
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field(
|
||||
description="Redis Sentinel service name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_SENTINEL_USERNAME: Optional[str] = Field(
|
||||
description="Redis Sentinel username",
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_SENTINEL_PASSWORD: Optional[str] = Field(
|
||||
description="Redis Sentinel password",
|
||||
default=None,
|
||||
)
|
||||
|
||||
REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
|
||||
description="Redis Sentinel socket timeout",
|
||||
default=0.1,
|
||||
)
|
||||
|
|
|
@ -38,3 +38,8 @@ class AliyunOSSStorageConfig(BaseSettings):
|
|||
description="Aliyun OSS authentication version",
|
||||
default=None,
|
||||
)
|
||||
|
||||
ALIYUN_OSS_PATH: Optional[str] = Field(
|
||||
description="Aliyun OSS path",
|
||||
default=None,
|
||||
)
|
||||
|
|
29
api/configs/middleware/storage/huawei_obs_storage_config.py
Normal file
29
api/configs/middleware/storage/huawei_obs_storage_config.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class HuaweiCloudOBSStorageConfig(BaseModel):
|
||||
"""
|
||||
Huawei Cloud OBS storage configs
|
||||
"""
|
||||
|
||||
HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field(
|
||||
description="Huawei Cloud OBS bucket name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field(
|
||||
description="Huawei Cloud OBS Access key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HUAWEI_OBS_SECRET_KEY: Optional[str] = Field(
|
||||
description="Huawei Cloud OBS Secret key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
HUAWEI_OBS_SERVER: Optional[str] = Field(
|
||||
description="Huawei Cloud OBS server URL",
|
||||
default=None,
|
||||
)
|
|
@ -0,0 +1,34 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VolcengineTOSStorageConfig(BaseModel):
|
||||
"""
|
||||
Volcengine tos storage configs
|
||||
"""
|
||||
|
||||
VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field(
|
||||
description="Volcengine TOS Bucket Name",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field(
|
||||
description="Volcengine TOS Access Key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field(
|
||||
description="Volcengine TOS Secret Key",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field(
|
||||
description="Volcengine TOS Endpoint URL",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VOLCENGINE_TOS_REGION: Optional[str] = Field(
|
||||
description="Volcengine TOS Region",
|
||||
default=None,
|
||||
)
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import Field, PositiveInt
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
|
@ -9,14 +9,14 @@ class MilvusConfig(BaseSettings):
|
|||
Milvus configs
|
||||
"""
|
||||
|
||||
MILVUS_HOST: Optional[str] = Field(
|
||||
description="Milvus host",
|
||||
default=None,
|
||||
MILVUS_URI: Optional[str] = Field(
|
||||
description="Milvus uri",
|
||||
default="http://127.0.0.1:19530",
|
||||
)
|
||||
|
||||
MILVUS_PORT: PositiveInt = Field(
|
||||
description="Milvus RestFul API port",
|
||||
default=9091,
|
||||
MILVUS_TOKEN: Optional[str] = Field(
|
||||
description="Milvus token",
|
||||
default=None,
|
||||
)
|
||||
|
||||
MILVUS_USER: Optional[str] = Field(
|
||||
|
@ -29,11 +29,6 @@ class MilvusConfig(BaseSettings):
|
|||
default=None,
|
||||
)
|
||||
|
||||
MILVUS_SECURE: bool = Field(
|
||||
description="whether to use SSL connection for Milvus",
|
||||
default=False,
|
||||
)
|
||||
|
||||
MILVUS_DATABASE: str = Field(
|
||||
description="Milvus database, default to `default`",
|
||||
default="default",
|
||||
|
|
|
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
|||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="0.7.2",
|
||||
default="0.8.0",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -174,6 +174,7 @@ class AppApi(Resource):
|
|||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("max_active_requests", type=int, location="json")
|
||||
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_service = AppService()
|
||||
|
|
|
@ -201,7 +201,11 @@ class ChatConversationApi(Resource):
|
|||
start_datetime_timezone = timezone.localize(start_datetime)
|
||||
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||
match args["sort_by"]:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at >= start_datetime_utc)
|
||||
case "created_at" | "-created_at" | _:
|
||||
query = query.where(Conversation.created_at >= start_datetime_utc)
|
||||
|
||||
if args["end"]:
|
||||
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
|
||||
|
@ -210,7 +214,11 @@ class ChatConversationApi(Resource):
|
|||
end_datetime_timezone = timezone.localize(end_datetime)
|
||||
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
|
||||
|
||||
query = query.where(Conversation.created_at < end_datetime_utc)
|
||||
match args["sort_by"]:
|
||||
case "updated_at" | "-updated_at":
|
||||
query = query.where(Conversation.updated_at <= end_datetime_utc)
|
||||
case "created_at" | "-created_at" | _:
|
||||
query = query.where(Conversation.created_at <= end_datetime_utc)
|
||||
|
||||
if args["annotation_status"] == "annotated":
|
||||
query = query.options(joinedload(Conversation.message_annotations)).join(
|
||||
|
|
|
@ -32,6 +32,8 @@ class ModelConfigResource(Resource):
|
|||
|
||||
new_app_model_config = AppModelConfig(
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
updated_by=current_user.id,
|
||||
)
|
||||
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from datetime import datetime, timezone
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
@ -32,6 +34,7 @@ def parse_app_site_args():
|
|||
)
|
||||
parser.add_argument("prompt_public", type=bool, required=False, location="json")
|
||||
parser.add_argument("show_workflow_steps", type=bool, required=False, location="json")
|
||||
parser.add_argument("use_icon_as_answer_icon", type=bool, required=False, location="json")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -66,11 +69,14 @@ class AppSite(Resource):
|
|||
"customize_token_strategy",
|
||||
"prompt_public",
|
||||
"show_workflow_steps",
|
||||
"use_icon_as_answer_icon",
|
||||
]:
|
||||
value = args.get(attr_name)
|
||||
if value is not None:
|
||||
setattr(site, attr_name, value)
|
||||
|
||||
site.updated_by = current_user.id
|
||||
site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
return site
|
||||
|
@ -93,6 +99,8 @@ class AppSiteAccessTokenReset(Resource):
|
|||
raise NotFound
|
||||
|
||||
site.code = Site.generate_code(16)
|
||||
site.updated_by = current_user.id
|
||||
site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
return site
|
||||
|
|
|
@ -18,7 +18,7 @@ from core.model_runtime.entities.model_entities import ModelType
|
|||
from core.provider_manager import ProviderManager
|
||||
from core.rag.datasource.vdb.vector_type import VectorType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.retrieval.retrival_methods import RetrievalMethod
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import related_app_list
|
||||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||
|
@ -122,6 +122,7 @@ class DatasetListApi(Resource):
|
|||
name=args["name"],
|
||||
indexing_technique=args["indexing_technique"],
|
||||
account=current_user,
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
|
|
@ -302,6 +302,8 @@ class DatasetInitApi(Resource):
|
|||
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
|
||||
)
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
|
||||
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
|
@ -309,6 +311,8 @@ class DatasetInitApi(Resource):
|
|||
raise Forbidden()
|
||||
|
||||
if args["indexing_technique"] == "high_quality":
|
||||
if args["embedding_model"] is None or args["embedding_model_provider"] is None:
|
||||
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
model_manager.get_default_model_instance(
|
||||
|
@ -599,6 +603,7 @@ class DocumentDetailApi(DocumentResource):
|
|||
"hit_count": document.hit_count,
|
||||
"display_status": document.display_status,
|
||||
"doc_form": document.doc_form,
|
||||
"doc_language": document.doc_language,
|
||||
}
|
||||
else:
|
||||
process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
|
@ -631,6 +636,7 @@ class DocumentDetailApi(DocumentResource):
|
|||
"hit_count": document.hit_count,
|
||||
"display_status": document.display_status,
|
||||
"doc_form": document.doc_form,
|
||||
"doc_language": document.doc_language,
|
||||
}
|
||||
|
||||
return response, 200
|
||||
|
|
|
@ -39,7 +39,7 @@ class FileApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(file_fields)
|
||||
@cloud_edition_billing_resource_check(resource="documents")
|
||||
@cloud_edition_billing_resource_check("documents")
|
||||
def post(self):
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
|
|
|
@ -35,6 +35,7 @@ class InstalledAppsListApi(Resource):
|
|||
"uninstallable": current_tenant_id == installed_app.app_owner_tenant_id,
|
||||
}
|
||||
for installed_app in installed_apps
|
||||
if installed_app.app is not None
|
||||
]
|
||||
installed_apps.sort(
|
||||
key=lambda app: (
|
||||
|
|
|
@ -13,7 +13,7 @@ from services.tag_service import TagService
|
|||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
if not name or len(name) < 1 or len(name) > 50:
|
||||
raise ValueError("Name must be between 1 to 50 characters.")
|
||||
return name
|
||||
|
||||
|
|
|
@ -46,9 +46,7 @@ def only_edition_self_hosted(view):
|
|||
return decorated
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(
|
||||
resource: str, error_msg: str = "You have reached the limit of your subscription."
|
||||
):
|
||||
def cloud_edition_billing_resource_check(resource: str):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
|
@ -60,22 +58,22 @@ def cloud_edition_billing_resource_check(
|
|||
documents_upload_quota = features.documents_upload_quota
|
||||
annotation_quota_limit = features.annotation_quota_limit
|
||||
if resource == "members" and 0 < members.limit <= members.size:
|
||||
abort(403, error_msg)
|
||||
abort(403, "The number of members has reached the limit of your subscription.")
|
||||
elif resource == "apps" and 0 < apps.limit <= apps.size:
|
||||
abort(403, error_msg)
|
||||
abort(403, "The number of apps has reached the limit of your subscription.")
|
||||
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
||||
abort(403, error_msg)
|
||||
abort(403, "The capacity of the vector space has reached the limit of your subscription.")
|
||||
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||
# The api of file upload is used in the multiple places, so we need to check the source of the request from datasets
|
||||
source = request.args.get("source")
|
||||
if source == "datasets":
|
||||
abort(403, error_msg)
|
||||
abort(403, "The number of documents has reached the limit of your subscription.")
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
elif resource == "workspace_custom" and not features.can_replace_logo:
|
||||
abort(403, error_msg)
|
||||
abort(403, "The workspace custom feature has reached the limit of your subscription.")
|
||||
elif resource == "annotation" and 0 < annotation_quota_limit.limit < annotation_quota_limit.size:
|
||||
abort(403, error_msg)
|
||||
abort(403, "The annotation quota has reached the limit of your subscription.")
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
@ -86,10 +84,7 @@ def cloud_edition_billing_resource_check(
|
|||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check(
|
||||
resource: str,
|
||||
error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
|
||||
):
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
|
@ -97,7 +92,10 @@ def cloud_edition_billing_knowledge_limit_check(
|
|||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == "sandbox":
|
||||
abort(403, error_msg)
|
||||
abort(
|
||||
403,
|
||||
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
|
||||
)
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
|
|
@ -36,6 +36,10 @@ class SegmentApi(DatasetApiResource):
|
|||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if document.indexing_status != "completed":
|
||||
raise NotFound("Document is not completed.")
|
||||
if not document.enabled:
|
||||
raise NotFound("Document is disabled.")
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
try:
|
||||
|
@ -63,7 +67,7 @@ class SegmentApi(DatasetApiResource):
|
|||
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
|
||||
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
|
||||
else:
|
||||
return {"error": "Segemtns is required"}, 400
|
||||
return {"error": "Segments is required"}, 400
|
||||
|
||||
def get(self, tenant_id, dataset_id, document_id):
|
||||
"""Create single segment."""
|
||||
|
|
|
@ -83,9 +83,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
|
|||
return decorator(view)
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(
|
||||
resource: str, api_token_type: str, error_msg: str = "You have reached the limit of your subscription."
|
||||
):
|
||||
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||
def interceptor(view):
|
||||
def decorated(*args, **kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
|
@ -98,13 +96,13 @@ def cloud_edition_billing_resource_check(
|
|||
documents_upload_quota = features.documents_upload_quota
|
||||
|
||||
if resource == "members" and 0 < members.limit <= members.size:
|
||||
raise Forbidden(error_msg)
|
||||
raise Forbidden("The number of members has reached the limit of your subscription.")
|
||||
elif resource == "apps" and 0 < apps.limit <= apps.size:
|
||||
raise Forbidden(error_msg)
|
||||
raise Forbidden("The number of apps has reached the limit of your subscription.")
|
||||
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
|
||||
raise Forbidden(error_msg)
|
||||
raise Forbidden("The capacity of the vector space has reached the limit of your subscription.")
|
||||
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
|
||||
raise Forbidden(error_msg)
|
||||
raise Forbidden("The number of documents has reached the limit of your subscription.")
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
@ -115,11 +113,7 @@ def cloud_edition_billing_resource_check(
|
|||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check(
|
||||
resource: str,
|
||||
api_token_type: str,
|
||||
error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.",
|
||||
):
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
|
||||
def interceptor(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
|
@ -128,7 +122,9 @@ def cloud_edition_billing_knowledge_limit_check(
|
|||
if features.billing.enabled:
|
||||
if resource == "add_segment":
|
||||
if features.billing.subscription.plan == "sandbox":
|
||||
raise Forbidden(error_msg)
|
||||
raise Forbidden(
|
||||
"To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."
|
||||
)
|
||||
else:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
|
|
@ -39,6 +39,7 @@ class AppSiteApi(WebApiResource):
|
|||
"default_language": fields.String,
|
||||
"prompt_public": fields.Boolean,
|
||||
"show_workflow_steps": fields.Boolean,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
}
|
||||
|
||||
app_fields = {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
|
@ -45,22 +46,25 @@ from models.tools import ToolConversationVariables
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseAgentRunner(AppRunner):
|
||||
def __init__(self, tenant_id: str,
|
||||
application_generate_entity: AgentChatAppGenerateEntity,
|
||||
conversation: Conversation,
|
||||
app_config: AgentChatAppConfig,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
config: AgentEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
message: Message,
|
||||
user_id: str,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
prompt_messages: Optional[list[PromptMessage]] = None,
|
||||
variables_pool: Optional[ToolRuntimeVariablePool] = None,
|
||||
db_variables: Optional[ToolConversationVariables] = None,
|
||||
model_instance: ModelInstance = None
|
||||
) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
application_generate_entity: AgentChatAppGenerateEntity,
|
||||
conversation: Conversation,
|
||||
app_config: AgentChatAppConfig,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
config: AgentEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
message: Message,
|
||||
user_id: str,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
prompt_messages: Optional[list[PromptMessage]] = None,
|
||||
variables_pool: Optional[ToolRuntimeVariablePool] = None,
|
||||
db_variables: Optional[ToolConversationVariables] = None,
|
||||
model_instance: ModelInstance = None,
|
||||
) -> None:
|
||||
"""
|
||||
Agent runner
|
||||
:param tenant_id: tenant id
|
||||
|
@ -88,9 +92,7 @@ class BaseAgentRunner(AppRunner):
|
|||
self.message = message
|
||||
self.user_id = user_id
|
||||
self.memory = memory
|
||||
self.history_prompt_messages = self.organize_agent_history(
|
||||
prompt_messages=prompt_messages or []
|
||||
)
|
||||
self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
|
||||
self.variables_pool = variables_pool
|
||||
self.db_variables_pool = db_variables
|
||||
self.model_instance = model_instance
|
||||
|
@ -111,12 +113,16 @@ class BaseAgentRunner(AppRunner):
|
|||
retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
|
||||
return_resource=app_config.additional_features.show_retrieve_source,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
hit_callback=hit_callback
|
||||
hit_callback=hit_callback,
|
||||
)
|
||||
# get how many agent thoughts have been created
|
||||
self.agent_thought_count = db.session.query(MessageAgentThought).filter(
|
||||
MessageAgentThought.message_id == self.message.id,
|
||||
).count()
|
||||
self.agent_thought_count = (
|
||||
db.session.query(MessageAgentThought)
|
||||
.filter(
|
||||
MessageAgentThought.message_id == self.message.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
db.session.close()
|
||||
|
||||
# check if model supports stream tool call
|
||||
|
@ -135,25 +141,26 @@ class BaseAgentRunner(AppRunner):
|
|||
self.query = None
|
||||
self._current_thoughts: list[PromptMessage] = []
|
||||
|
||||
def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
|
||||
-> AgentChatAppGenerateEntity:
|
||||
def _repack_app_generate_entity(
|
||||
self, app_generate_entity: AgentChatAppGenerateEntity
|
||||
) -> AgentChatAppGenerateEntity:
|
||||
"""
|
||||
Repack app generate entity
|
||||
"""
|
||||
if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
|
||||
app_generate_entity.app_config.prompt_template.simple_prompt_template = ''
|
||||
app_generate_entity.app_config.prompt_template.simple_prompt_template = ""
|
||||
|
||||
return app_generate_entity
|
||||
|
||||
|
||||
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
|
||||
"""
|
||||
convert tool to prompt message tool
|
||||
convert tool to prompt message tool
|
||||
"""
|
||||
tool_entity = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_config.app_id,
|
||||
agent_tool=tool,
|
||||
invoke_from=self.application_generate_entity.invoke_from
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
)
|
||||
tool_entity.load_variables(self.variables_pool)
|
||||
|
||||
|
@ -164,7 +171,7 @@ class BaseAgentRunner(AppRunner):
|
|||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
|
@ -177,19 +184,19 @@ class BaseAgentRunner(AppRunner):
|
|||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
enum = [option.value for option in parameter.options]
|
||||
|
||||
message_tool.parameters['properties'][parameter.name] = {
|
||||
message_tool.parameters["properties"][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or '',
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
|
||||
if len(enum) > 0:
|
||||
message_tool.parameters['properties'][parameter.name]['enum'] = enum
|
||||
message_tool.parameters["properties"][parameter.name]["enum"] = enum
|
||||
|
||||
if parameter.required:
|
||||
message_tool.parameters['required'].append(parameter.name)
|
||||
message_tool.parameters["required"].append(parameter.name)
|
||||
|
||||
return message_tool, tool_entity
|
||||
|
||||
|
||||
def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
|
||||
"""
|
||||
convert dataset retriever tool to prompt message tool
|
||||
|
@ -201,24 +208,24 @@ class BaseAgentRunner(AppRunner):
|
|||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
for parameter in tool.get_runtime_parameters():
|
||||
parameter_type = 'string'
|
||||
|
||||
prompt_tool.parameters['properties'][parameter.name] = {
|
||||
parameter_type = "string"
|
||||
|
||||
prompt_tool.parameters["properties"][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or '',
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
|
||||
if parameter.required:
|
||||
if parameter.name not in prompt_tool.parameters['required']:
|
||||
prompt_tool.parameters['required'].append(parameter.name)
|
||||
if parameter.name not in prompt_tool.parameters["required"]:
|
||||
prompt_tool.parameters["required"].append(parameter.name)
|
||||
|
||||
return prompt_tool
|
||||
|
||||
def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
|
||||
|
||||
def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
|
||||
"""
|
||||
Init tools
|
||||
"""
|
||||
|
@ -261,51 +268,51 @@ class BaseAgentRunner(AppRunner):
|
|||
enum = []
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
enum = [option.value for option in parameter.options]
|
||||
|
||||
prompt_tool.parameters['properties'][parameter.name] = {
|
||||
|
||||
prompt_tool.parameters["properties"][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or '',
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
|
||||
if len(enum) > 0:
|
||||
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
|
||||
prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
|
||||
|
||||
if parameter.required:
|
||||
if parameter.name not in prompt_tool.parameters['required']:
|
||||
prompt_tool.parameters['required'].append(parameter.name)
|
||||
if parameter.name not in prompt_tool.parameters["required"]:
|
||||
prompt_tool.parameters["required"].append(parameter.name)
|
||||
|
||||
return prompt_tool
|
||||
|
||||
def create_agent_thought(self, message_id: str, message: str,
|
||||
tool_name: str, tool_input: str, messages_ids: list[str]
|
||||
) -> MessageAgentThought:
|
||||
|
||||
def create_agent_thought(
|
||||
self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
|
||||
) -> MessageAgentThought:
|
||||
"""
|
||||
Create agent thought
|
||||
"""
|
||||
thought = MessageAgentThought(
|
||||
message_id=message_id,
|
||||
message_chain_id=None,
|
||||
thought='',
|
||||
thought="",
|
||||
tool=tool_name,
|
||||
tool_labels_str='{}',
|
||||
tool_meta_str='{}',
|
||||
tool_labels_str="{}",
|
||||
tool_meta_str="{}",
|
||||
tool_input=tool_input,
|
||||
message=message,
|
||||
message_token=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0,
|
||||
message_files=json.dumps(messages_ids) if messages_ids else '',
|
||||
answer='',
|
||||
observation='',
|
||||
message_files=json.dumps(messages_ids) if messages_ids else "",
|
||||
answer="",
|
||||
observation="",
|
||||
answer_token=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
tokens=0,
|
||||
total_price=0,
|
||||
position=self.agent_thought_count + 1,
|
||||
currency='USD',
|
||||
currency="USD",
|
||||
latency=0,
|
||||
created_by_role='account',
|
||||
created_by_role="account",
|
||||
created_by=self.user_id,
|
||||
)
|
||||
|
||||
|
@ -318,22 +325,22 @@ class BaseAgentRunner(AppRunner):
|
|||
|
||||
return thought
|
||||
|
||||
def save_agent_thought(self,
|
||||
agent_thought: MessageAgentThought,
|
||||
tool_name: str,
|
||||
tool_input: Union[str, dict],
|
||||
thought: str,
|
||||
observation: Union[str, dict],
|
||||
tool_invoke_meta: Union[str, dict],
|
||||
answer: str,
|
||||
messages_ids: list[str],
|
||||
llm_usage: LLMUsage = None) -> MessageAgentThought:
|
||||
def save_agent_thought(
|
||||
self,
|
||||
agent_thought: MessageAgentThought,
|
||||
tool_name: str,
|
||||
tool_input: Union[str, dict],
|
||||
thought: str,
|
||||
observation: Union[str, dict],
|
||||
tool_invoke_meta: Union[str, dict],
|
||||
answer: str,
|
||||
messages_ids: list[str],
|
||||
llm_usage: LLMUsage = None,
|
||||
) -> MessageAgentThought:
|
||||
"""
|
||||
Save agent thought
|
||||
"""
|
||||
agent_thought = db.session.query(MessageAgentThought).filter(
|
||||
MessageAgentThought.id == agent_thought.id
|
||||
).first()
|
||||
agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
|
||||
|
||||
if thought is not None:
|
||||
agent_thought.thought = thought
|
||||
|
@ -356,7 +363,7 @@ class BaseAgentRunner(AppRunner):
|
|||
observation = json.dumps(observation, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
observation = json.dumps(observation)
|
||||
|
||||
|
||||
agent_thought.observation = observation
|
||||
|
||||
if answer is not None:
|
||||
|
@ -364,7 +371,7 @@ class BaseAgentRunner(AppRunner):
|
|||
|
||||
if messages_ids is not None and len(messages_ids) > 0:
|
||||
agent_thought.message_files = json.dumps(messages_ids)
|
||||
|
||||
|
||||
if llm_usage:
|
||||
agent_thought.message_token = llm_usage.prompt_tokens
|
||||
agent_thought.message_price_unit = llm_usage.prompt_price_unit
|
||||
|
@ -377,7 +384,7 @@ class BaseAgentRunner(AppRunner):
|
|||
|
||||
# check if tool labels is not empty
|
||||
labels = agent_thought.tool_labels or {}
|
||||
tools = agent_thought.tool.split(';') if agent_thought.tool else []
|
||||
tools = agent_thought.tool.split(";") if agent_thought.tool else []
|
||||
for tool in tools:
|
||||
if not tool:
|
||||
continue
|
||||
|
@ -386,7 +393,7 @@ class BaseAgentRunner(AppRunner):
|
|||
if tool_label:
|
||||
labels[tool] = tool_label.to_dict()
|
||||
else:
|
||||
labels[tool] = {'en_US': tool, 'zh_Hans': tool}
|
||||
labels[tool] = {"en_US": tool, "zh_Hans": tool}
|
||||
|
||||
agent_thought.tool_labels_str = json.dumps(labels)
|
||||
|
||||
|
@ -401,14 +408,18 @@ class BaseAgentRunner(AppRunner):
|
|||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
|
||||
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
|
||||
"""
|
||||
convert tool variables to db variables
|
||||
"""
|
||||
db_variables = db.session.query(ToolConversationVariables).filter(
|
||||
ToolConversationVariables.conversation_id == self.message.conversation_id,
|
||||
).first()
|
||||
db_variables = (
|
||||
db.session.query(ToolConversationVariables)
|
||||
.filter(
|
||||
ToolConversationVariables.conversation_id == self.message.conversation_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
|
||||
|
@ -425,9 +436,14 @@ class BaseAgentRunner(AppRunner):
|
|||
if isinstance(prompt_message, SystemPromptMessage):
|
||||
result.append(prompt_message)
|
||||
|
||||
messages: list[Message] = db.session.query(Message).filter(
|
||||
Message.conversation_id == self.message.conversation_id,
|
||||
).order_by(Message.created_at.asc()).all()
|
||||
messages: list[Message] = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
Message.conversation_id == self.message.conversation_id,
|
||||
)
|
||||
.order_by(Message.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
for message in messages:
|
||||
if message.id == self.message.id:
|
||||
|
@ -439,13 +455,13 @@ class BaseAgentRunner(AppRunner):
|
|||
for agent_thought in agent_thoughts:
|
||||
tools = agent_thought.tool
|
||||
if tools:
|
||||
tools = tools.split(';')
|
||||
tools = tools.split(";")
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
tool_call_response: list[ToolPromptMessage] = []
|
||||
try:
|
||||
tool_inputs = json.loads(agent_thought.tool_input)
|
||||
except Exception as e:
|
||||
tool_inputs = { tool: {} for tool in tools }
|
||||
tool_inputs = {tool: {} for tool in tools}
|
||||
try:
|
||||
tool_responses = json.loads(agent_thought.observation)
|
||||
except Exception as e:
|
||||
|
@ -454,27 +470,33 @@ class BaseAgentRunner(AppRunner):
|
|||
for tool in tools:
|
||||
# generate a uuid for tool call
|
||||
tool_call_id = str(uuid.uuid4())
|
||||
tool_calls.append(AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type='function',
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool,
|
||||
arguments=json.dumps(tool_inputs.get(tool, {})),
|
||||
tool_calls.append(
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool,
|
||||
arguments=json.dumps(tool_inputs.get(tool, {})),
|
||||
),
|
||||
)
|
||||
))
|
||||
tool_call_response.append(ToolPromptMessage(
|
||||
content=tool_responses.get(tool, agent_thought.observation),
|
||||
name=tool,
|
||||
tool_call_id=tool_call_id,
|
||||
))
|
||||
)
|
||||
tool_call_response.append(
|
||||
ToolPromptMessage(
|
||||
content=tool_responses.get(tool, agent_thought.observation),
|
||||
name=tool,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
)
|
||||
|
||||
result.extend([
|
||||
AssistantPromptMessage(
|
||||
content=agent_thought.thought,
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
*tool_call_response
|
||||
])
|
||||
result.extend(
|
||||
[
|
||||
AssistantPromptMessage(
|
||||
content=agent_thought.thought,
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
*tool_call_response,
|
||||
]
|
||||
)
|
||||
if not tools:
|
||||
result.append(AssistantPromptMessage(content=agent_thought.thought))
|
||||
else:
|
||||
|
@ -496,10 +518,7 @@ class BaseAgentRunner(AppRunner):
|
|||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.transform_message_files(
|
||||
files,
|
||||
file_extra_config
|
||||
)
|
||||
file_objs = message_file_parser.transform_message_files(files, file_extra_config)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
|
|
|
@ -93,7 +93,7 @@ class DatasetConfigManager:
|
|||
reranking_model=dataset_configs.get('reranking_model'),
|
||||
weights=dataset_configs.get('weights'),
|
||||
reranking_enabled=dataset_configs.get('reranking_enabled', True),
|
||||
rerank_mode=dataset_configs.get('rerank_mode', 'reranking_model'),
|
||||
rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -4,12 +4,10 @@ import os
|
|||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import contexts
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
|
@ -20,33 +18,49 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
|
|||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import ConversationVariable, Workflow
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
):
|
||||
stream: Literal[True] = True,
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
) -> dict: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> dict[str, Any] | Generator[str, Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
@ -134,7 +148,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
node_id: str,
|
||||
user: Account,
|
||||
args: dict,
|
||||
stream: bool = True):
|
||||
stream: bool = True) \
|
||||
-> dict[str, Any] | Generator[str, Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
@ -151,16 +166,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
if args.get('inputs') is None:
|
||||
raise ValueError('inputs is required')
|
||||
|
||||
extras = {
|
||||
"auto_generate_conversation_name": False
|
||||
}
|
||||
|
||||
# get conversation
|
||||
conversation = None
|
||||
conversation_id = args.get('conversation_id')
|
||||
if conversation_id:
|
||||
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
|
||||
|
||||
# convert to app config
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
|
@ -171,14 +176,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
application_generate_entity = AdvancedChatAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
conversation_id=None,
|
||||
inputs={},
|
||||
query='',
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras=extras,
|
||||
extras={
|
||||
"auto_generate_conversation_name": False
|
||||
},
|
||||
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id,
|
||||
inputs=args['inputs']
|
||||
|
@ -191,17 +198,28 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation,
|
||||
conversation=None,
|
||||
stream=stream
|
||||
)
|
||||
|
||||
def _generate(self, *,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
conversation: Conversation | None = None,
|
||||
stream: bool = True):
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
conversation: Optional[Conversation] = None,
|
||||
stream: bool = True) \
|
||||
-> dict[str, Any] | Generator[str, Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param invoke_from: invoke from source
|
||||
:param application_generate_entity: application generate entity
|
||||
:param conversation: conversation
|
||||
:param stream: is stream
|
||||
"""
|
||||
is_first_conversation = False
|
||||
if not conversation:
|
||||
is_first_conversation = True
|
||||
|
@ -216,7 +234,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
# update conversation features
|
||||
conversation.override_model_configs = workflow.features
|
||||
db.session.commit()
|
||||
# db.session.refresh(conversation)
|
||||
db.session.refresh(conversation)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
|
@ -228,67 +246,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
message_id=message.id
|
||||
)
|
||||
|
||||
# Init conversation variables
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
conversation_variables = session.scalars(stmt).all()
|
||||
if not conversation_variables:
|
||||
# Create conversation variables if they don't exist.
|
||||
conversation_variables = [
|
||||
ConversationVariable.from_variable(
|
||||
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
|
||||
)
|
||||
for variable in workflow.conversation_variables
|
||||
]
|
||||
session.add_all(conversation_variables)
|
||||
# Convert database entities to variables.
|
||||
conversation_variables = [item.to_variable() for item in conversation_variables]
|
||||
|
||||
session.commit()
|
||||
|
||||
# Increment dialogue count.
|
||||
conversation.dialogue_count += 1
|
||||
|
||||
conversation_id = conversation.id
|
||||
conversation_dialogue_count = conversation.dialogue_count
|
||||
db.session.commit()
|
||||
db.session.refresh(conversation)
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
query = application_generate_entity.query
|
||||
files = application_generate_entity.files
|
||||
|
||||
user_id = None
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = application_generate_entity.user_id
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariableKey.QUERY: query,
|
||||
SystemVariableKey.FILES: files,
|
||||
SystemVariableKey.CONVERSATION_ID: conversation_id,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
|
||||
}
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
contexts.workflow_variable_pool.set(variable_pool)
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'flask_app': current_app._get_current_object(), # type: ignore
|
||||
'application_generate_entity': application_generate_entity,
|
||||
'queue_manager': queue_manager,
|
||||
'conversation_id': conversation.id,
|
||||
'message_id': message.id,
|
||||
'context': contextvars.copy_context(),
|
||||
})
|
||||
|
@ -314,6 +277,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
def _generate_worker(self, flask_app: Flask,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
context: contextvars.Context) -> None:
|
||||
"""
|
||||
|
@ -329,28 +293,19 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
var.set(val)
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
runner = AdvancedChatAppRunner()
|
||||
if application_generate_entity.single_iteration_run:
|
||||
single_iteration_run = application_generate_entity.single_iteration_run
|
||||
runner.single_iteration_run(
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
workflow_id=application_generate_entity.app_config.workflow_id,
|
||||
queue_manager=queue_manager,
|
||||
inputs=single_iteration_run.inputs,
|
||||
node_id=single_iteration_run.node_id,
|
||||
user_id=application_generate_entity.user_id
|
||||
)
|
||||
else:
|
||||
# get message
|
||||
message = self._get_message(message_id)
|
||||
# get conversation and message
|
||||
conversation = self._get_conversation(conversation_id)
|
||||
message = self._get_message(message_id)
|
||||
|
||||
# chatbot app
|
||||
runner = AdvancedChatAppRunner()
|
||||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
message=message
|
||||
)
|
||||
# chatbot app
|
||||
runner = AdvancedChatAppRunner(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message
|
||||
)
|
||||
|
||||
runner.run()
|
||||
except GenerateTaskStoppedException:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
|
|
|
@ -1,49 +1,67 @@
|
|||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
)
|
||||
from core.moderation.base import ModerationException
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models import App, Message, Workflow
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import ConversationVariable, WorkflowType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdvancedChatAppRunner(AppRunner):
|
||||
class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
"""
|
||||
AdvancedChat Application Runner
|
||||
"""
|
||||
|
||||
def run(
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
message: Message,
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message
|
||||
) -> None:
|
||||
"""
|
||||
Run application
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:param conversation: conversation
|
||||
:param message: message
|
||||
"""
|
||||
super().__init__(queue_manager)
|
||||
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.conversation = conversation
|
||||
self.message = message
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Run application
|
||||
:return:
|
||||
"""
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||
|
||||
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
|
||||
|
@ -54,101 +72,133 @@ class AdvancedChatAppRunner(AppRunner):
|
|||
if not workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
query = application_generate_entity.query
|
||||
user_id = None
|
||||
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = self.application_generate_entity.user_id
|
||||
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
queue_manager=queue_manager,
|
||||
app_record=app_record,
|
||||
app_generate_entity=application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=message.id,
|
||||
):
|
||||
return
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# annotation reply
|
||||
if self.handle_annotation_reply(
|
||||
app_record=app_record,
|
||||
message=message,
|
||||
query=query,
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
):
|
||||
return
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
query = self.application_generate_entity.query
|
||||
files = self.application_generate_entity.files
|
||||
|
||||
# moderation
|
||||
if self.handle_input_moderation(
|
||||
app_record=app_record,
|
||||
app_generate_entity=self.application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=self.message.id
|
||||
):
|
||||
return
|
||||
|
||||
# annotation reply
|
||||
if self.handle_annotation_reply(
|
||||
app_record=app_record,
|
||||
message=self.message,
|
||||
query=query,
|
||||
app_generate_entity=self.application_generate_entity
|
||||
):
|
||||
return
|
||||
|
||||
# Init conversation variables
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.app_id == self.conversation.app_id, ConversationVariable.conversation_id == self.conversation.id
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
conversation_variables = session.scalars(stmt).all()
|
||||
if not conversation_variables:
|
||||
# Create conversation variables if they don't exist.
|
||||
conversation_variables = [
|
||||
ConversationVariable.from_variable(
|
||||
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
|
||||
)
|
||||
for variable in workflow.conversation_variables
|
||||
]
|
||||
session.add_all(conversation_variables)
|
||||
# Convert database entities to variables.
|
||||
conversation_variables = [item.to_variable() for item in conversation_variables]
|
||||
|
||||
session.commit()
|
||||
|
||||
# Increment dialogue count.
|
||||
self.conversation.dialogue_count += 1
|
||||
|
||||
conversation_dialogue_count = self.conversation.dialogue_count
|
||||
db.session.commit()
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariableKey.QUERY: query,
|
||||
SystemVariableKey.FILES: files,
|
||||
SystemVariableKey.CONVERSATION_ID: self.conversation.id,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
|
||||
}
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(graph_config=workflow.graph_dict)
|
||||
|
||||
db.session.close()
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = [
|
||||
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
|
||||
]
|
||||
|
||||
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.run_workflow(
|
||||
workflow=workflow,
|
||||
user_id=application_generate_entity.user_id,
|
||||
user_from=UserFrom.ACCOUNT
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
||||
else UserFrom.END_USER,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType.value_of(workflow.type),
|
||||
graph=graph,
|
||||
graph_config=workflow.graph_dict,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=(
|
||||
UserFrom.ACCOUNT
|
||||
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
||||
else UserFrom.END_USER
|
||||
),
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
generator = workflow_entry.run(
|
||||
callbacks=workflow_callbacks,
|
||||
call_depth=application_generate_entity.call_depth,
|
||||
)
|
||||
|
||||
def single_iteration_run(
|
||||
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Single iteration run
|
||||
"""
|
||||
app_record = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError('App not found')
|
||||
|
||||
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
|
||||
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
|
||||
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.single_step_run_iteration_workflow_node(
|
||||
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
|
||||
)
|
||||
|
||||
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||
"""
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# return workflow
|
||||
return workflow
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
||||
def handle_input_moderation(
|
||||
self,
|
||||
queue_manager: AppQueueManager,
|
||||
app_record: App,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
message_id: str,
|
||||
self,
|
||||
app_record: App,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
message_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
Handle input moderation
|
||||
:param queue_manager: application queue manager
|
||||
:param app_record: app record
|
||||
:param app_generate_entity: application generate entity
|
||||
:param inputs: inputs
|
||||
|
@ -167,30 +217,23 @@ class AdvancedChatAppRunner(AppRunner):
|
|||
message_id=message_id,
|
||||
)
|
||||
except ModerationException as e:
|
||||
self._stream_output(
|
||||
queue_manager=queue_manager,
|
||||
self._complete_with_stream_output(
|
||||
text=str(e),
|
||||
stream=app_generate_entity.stream,
|
||||
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION,
|
||||
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def handle_annotation_reply(
|
||||
self,
|
||||
app_record: App,
|
||||
message: Message,
|
||||
query: str,
|
||||
queue_manager: AppQueueManager,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
) -> bool:
|
||||
def handle_annotation_reply(self, app_record: App,
|
||||
message: Message,
|
||||
query: str,
|
||||
app_generate_entity: AdvancedChatAppGenerateEntity) -> bool:
|
||||
"""
|
||||
Handle annotation reply
|
||||
:param app_record: app record
|
||||
:param message: message
|
||||
:param query: query
|
||||
:param queue_manager: application queue manager
|
||||
:param app_generate_entity: application generate entity
|
||||
"""
|
||||
# annotation reply
|
||||
|
@ -203,37 +246,32 @@ class AdvancedChatAppRunner(AppRunner):
|
|||
)
|
||||
|
||||
if annotation_reply:
|
||||
queue_manager.publish(
|
||||
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER
|
||||
self._publish_event(
|
||||
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id)
|
||||
)
|
||||
|
||||
self._stream_output(
|
||||
queue_manager=queue_manager,
|
||||
self._complete_with_stream_output(
|
||||
text=annotation_reply.content,
|
||||
stream=app_generate_entity.stream,
|
||||
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY,
|
||||
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _stream_output(
|
||||
self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy
|
||||
) -> None:
|
||||
def _complete_with_stream_output(self,
|
||||
text: str,
|
||||
stopped_by: QueueStopEvent.StopBy) -> None:
|
||||
"""
|
||||
Direct output
|
||||
:param queue_manager: application queue manager
|
||||
:param text: text
|
||||
:param stream: stream
|
||||
:return:
|
||||
"""
|
||||
if stream:
|
||||
index = 0
|
||||
for token in text:
|
||||
queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER)
|
||||
index += 1
|
||||
time.sleep(0.01)
|
||||
else:
|
||||
queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER)
|
||||
self._publish_event(
|
||||
QueueTextChunkEvent(
|
||||
text=text
|
||||
)
|
||||
)
|
||||
|
||||
queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER)
|
||||
self._publish_event(
|
||||
QueueStopEvent(stopped_by=stopped_by)
|
||||
)
|
||||
|
|
|
@ -2,9 +2,8 @@ import json
|
|||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, Union, cast
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import contexts
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
|
@ -22,6 +21,9 @@ from core.app.entities.queue_entities import (
|
|||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
|
@ -31,34 +33,28 @@ from core.app.entities.queue_entities import (
|
|||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatTaskState,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
ChatflowStreamGenerateRoute,
|
||||
ErrorStreamResponse,
|
||||
MessageAudioEndStreamResponse,
|
||||
MessageAudioStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
StreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||
from core.file.file_obj import FileVar
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import Conversation, EndUser, Message
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
|
@ -69,16 +65,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
"""
|
||||
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
_task_state: AdvancedChatTaskState
|
||||
_task_state: WorkflowTaskState
|
||||
_application_generate_entity: AdvancedChatAppGenerateEntity
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
# Deprecated
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_iteration_nested_relations: dict[str, list[str]]
|
||||
|
||||
def __init__(
|
||||
self, application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
self,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
|
@ -106,7 +101,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
self._workflow = workflow
|
||||
self._conversation = conversation
|
||||
self._message = message
|
||||
# Deprecated
|
||||
self._workflow_system_variables = {
|
||||
SystemVariableKey.QUERY: message.query,
|
||||
SystemVariableKey.FILES: application_generate_entity.files,
|
||||
|
@ -114,12 +108,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
SystemVariableKey.USER_ID: user_id,
|
||||
}
|
||||
|
||||
self._task_state = AdvancedChatTaskState(
|
||||
usage=LLMUsage.empty_usage()
|
||||
)
|
||||
self._task_state = WorkflowTaskState()
|
||||
|
||||
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
|
||||
self._stream_generate_routes = self._get_stream_generate_routes()
|
||||
self._conversation_name_generate_thread = None
|
||||
|
||||
def process(self):
|
||||
|
@ -140,6 +130,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
generator = self._wrapper_process_stream_response(
|
||||
trace_manager=self._application_generate_entity.trace_manager
|
||||
)
|
||||
|
||||
if self._stream:
|
||||
return self._to_stream_response(generator)
|
||||
else:
|
||||
|
@ -199,17 +190,18 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
|
||||
Generator[StreamResponse, None, None]:
|
||||
|
||||
publisher = None
|
||||
tts_publisher = None
|
||||
task_id = self._application_generate_entity.task_id
|
||||
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
features_dict = self._workflow.features_dict
|
||||
|
||||
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
|
||||
'text_to_speech'].get('autoPlay') == 'enabled':
|
||||
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
|
||||
|
||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
|
||||
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
|
@ -220,9 +212,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
# timeout
|
||||
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
|
||||
try:
|
||||
if not publisher:
|
||||
if not tts_publisher:
|
||||
break
|
||||
audio_trunk = publisher.checkAndGetAudio()
|
||||
audio_trunk = tts_publisher.checkAndGetAudio()
|
||||
if audio_trunk is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
|
@ -240,34 +232,34 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
|
||||
def _process_stream_response(
|
||||
self,
|
||||
publisher: AppGeneratorTTSPublisher,
|
||||
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""
|
||||
Process stream response.
|
||||
:return:
|
||||
"""
|
||||
for message in self._queue_manager.listen():
|
||||
if (message.event
|
||||
and getattr(message.event, 'metadata', None)
|
||||
and message.event.metadata.get('is_answer_previous_node', False)
|
||||
and publisher):
|
||||
publisher.publish(message=message)
|
||||
elif (hasattr(message.event, 'execution_metadata')
|
||||
and message.event.execution_metadata
|
||||
and message.event.execution_metadata.get('is_answer_previous_node', False)
|
||||
and publisher):
|
||||
publisher.publish(message=message)
|
||||
event = message.event
|
||||
# init fake graph runtime state
|
||||
graph_runtime_state = None
|
||||
workflow_run = None
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
for queue_message in self._queue_manager.listen():
|
||||
event = queue_message.event
|
||||
|
||||
if isinstance(event, QueuePingEvent):
|
||||
yield self._ping_stream_response()
|
||||
elif isinstance(event, QueueErrorEvent):
|
||||
err = self._handle_error(event, self._message)
|
||||
yield self._error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||
workflow_run = self._handle_workflow_start()
|
||||
# override graph runtime state
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
|
||||
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
|
||||
# init workflow run
|
||||
workflow_run = self._handle_workflow_run_start()
|
||||
|
||||
self._refetch_message()
|
||||
self._message.workflow_run_id = workflow_run.id
|
||||
|
||||
db.session.commit()
|
||||
|
@ -279,133 +271,242 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
workflow_run=workflow_run
|
||||
)
|
||||
elif isinstance(event, QueueNodeStartedEvent):
|
||||
workflow_node_execution = self._handle_node_start(event)
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
# search stream_generate_routes if node id is answer start at node
|
||||
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
|
||||
self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
|
||||
# reset current route position to 0
|
||||
self._task_state.current_stream_generate_state.current_route_position = 0
|
||||
workflow_node_execution = self._handle_node_execution_start(
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
|
||||
# generate stream outputs when node started
|
||||
yield from self._generate_stream_outputs_when_node_started()
|
||||
|
||||
yield self._workflow_node_start_to_stream_response(
|
||||
response = self._workflow_node_start_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
)
|
||||
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
|
||||
workflow_node_execution = self._handle_node_finished(event)
|
||||
|
||||
# stream outputs when node finished
|
||||
generator = self._generate_stream_outputs_when_node_finished()
|
||||
if generator:
|
||||
yield from generator
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeSucceededEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_success(event)
|
||||
|
||||
yield self._workflow_node_finish_to_stream_response(
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
)
|
||||
|
||||
if isinstance(event, QueueNodeFailedEvent):
|
||||
yield from self._handle_iteration_exception(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
error=f'Child node failed: {event.error}'
|
||||
)
|
||||
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
|
||||
if isinstance(event, QueueIterationNextEvent):
|
||||
# clear ran node execution infos of current iteration
|
||||
iteration_relations = self._iteration_nested_relations.get(event.node_id)
|
||||
if iteration_relations:
|
||||
for node_id in iteration_relations:
|
||||
self._task_state.ran_node_execution_infos.pop(node_id, None)
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeFailedEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
|
||||
|
||||
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
|
||||
self._handle_iteration_operation(event)
|
||||
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
|
||||
workflow_run = self._handle_workflow_finished(
|
||||
event, conversation_id=self._conversation.id, trace_manager=trace_manager
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
)
|
||||
if workflow_run:
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_parallel_branch_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_parallel_branch_finished_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationStartEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_iteration_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_iteration_next_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_iteration_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowSucceededEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=json.dumps(event.outputs) if event.outputs else None,
|
||||
conversation_id=self._conversation.id,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
)
|
||||
|
||||
self._queue_manager.publish(
|
||||
QueueAdvancedChatMessageEndEvent(),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowFailedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
status=WorkflowRunStatus.FAILED,
|
||||
error=event.error,
|
||||
conversation_id=self._conversation.id,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
)
|
||||
|
||||
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
|
||||
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
|
||||
break
|
||||
elif isinstance(event, QueueStopEvent):
|
||||
if workflow_run and graph_runtime_state:
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
status=WorkflowRunStatus.STOPPED,
|
||||
error=event.get_stop_reason(),
|
||||
conversation_id=self._conversation.id,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
)
|
||||
|
||||
if workflow_run.status == WorkflowRunStatus.FAILED.value:
|
||||
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
|
||||
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
|
||||
break
|
||||
|
||||
if isinstance(event, QueueStopEvent):
|
||||
# Save message
|
||||
self._save_message()
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
break
|
||||
else:
|
||||
self._queue_manager.publish(
|
||||
QueueAdvancedChatMessageEndEvent(),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
||||
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
|
||||
if output_moderation_answer:
|
||||
self._task_state.answer = output_moderation_answer
|
||||
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
|
||||
|
||||
# Save message
|
||||
self._save_message()
|
||||
self._save_message(graph_runtime_state=graph_runtime_state)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
break
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
self._handle_retriever_resources(event)
|
||||
|
||||
self._refetch_message()
|
||||
|
||||
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
|
||||
if self._task_state.metadata else None
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(self._message)
|
||||
db.session.close()
|
||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||
self._handle_annotation_reply(event)
|
||||
|
||||
self._refetch_message()
|
||||
|
||||
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
|
||||
if self._task_state.metadata else None
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(self._message)
|
||||
db.session.close()
|
||||
elif isinstance(event, QueueTextChunkEvent):
|
||||
delta_text = event.text
|
||||
if delta_text is None:
|
||||
continue
|
||||
|
||||
if not self._is_stream_out_support(
|
||||
event=event
|
||||
):
|
||||
continue
|
||||
|
||||
# handle output moderation chunk
|
||||
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
|
||||
if should_direct_answer:
|
||||
continue
|
||||
|
||||
# only publish tts message at text chunk streaming
|
||||
if tts_publisher:
|
||||
tts_publisher.publish(message=queue_message)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
yield self._message_to_stream_response(delta_text, self._message.id)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
# published by moderation
|
||||
yield self._message_replace_to_stream_response(answer=event.text)
|
||||
elif isinstance(event, QueuePingEvent):
|
||||
yield self._ping_stream_response()
|
||||
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
|
||||
if output_moderation_answer:
|
||||
self._task_state.answer = output_moderation_answer
|
||||
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
|
||||
|
||||
# Save message
|
||||
self._save_message(graph_runtime_state=graph_runtime_state)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
else:
|
||||
continue
|
||||
if publisher:
|
||||
publisher.publish(None)
|
||||
|
||||
# publish None when task finished
|
||||
if tts_publisher:
|
||||
tts_publisher.publish(None)
|
||||
|
||||
if self._conversation_name_generate_thread:
|
||||
self._conversation_name_generate_thread.join()
|
||||
|
||||
def _save_message(self) -> None:
|
||||
def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
||||
"""
|
||||
Save message.
|
||||
:return:
|
||||
"""
|
||||
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
|
||||
self._refetch_message()
|
||||
|
||||
self._message.answer = self._task_state.answer
|
||||
self._message.provider_response_latency = time.perf_counter() - self._start_at
|
||||
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
|
||||
if self._task_state.metadata else None
|
||||
|
||||
if self._task_state.metadata and self._task_state.metadata.get('usage'):
|
||||
usage = LLMUsage(**self._task_state.metadata['usage'])
|
||||
|
||||
if graph_runtime_state and graph_runtime_state.llm_usage:
|
||||
usage = graph_runtime_state.llm_usage
|
||||
self._message.message_tokens = usage.prompt_tokens
|
||||
self._message.message_unit_price = usage.prompt_unit_price
|
||||
self._message.message_price_unit = usage.prompt_price_unit
|
||||
|
@ -432,7 +533,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
"""
|
||||
extras = {}
|
||||
if self._task_state.metadata:
|
||||
extras['metadata'] = self._task_state.metadata
|
||||
extras['metadata'] = self._task_state.metadata.copy()
|
||||
|
||||
if 'annotation_reply' in extras['metadata']:
|
||||
del extras['metadata']['annotation_reply']
|
||||
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
|
@ -440,323 +544,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
**extras
|
||||
)
|
||||
|
||||
def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
|
||||
"""
|
||||
Get stream generate routes.
|
||||
:return:
|
||||
"""
|
||||
# find all answer nodes
|
||||
graph = self._workflow.graph_dict
|
||||
answer_node_configs = [
|
||||
node for node in graph['nodes']
|
||||
if node.get('data', {}).get('type') == NodeType.ANSWER.value
|
||||
]
|
||||
|
||||
# parse stream output node value selectors of answer nodes
|
||||
stream_generate_routes = {}
|
||||
for node_config in answer_node_configs:
|
||||
# get generate route for stream output
|
||||
answer_node_id = node_config['id']
|
||||
generate_route = AnswerNode.extract_generate_route_selectors(node_config)
|
||||
start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id)
|
||||
if not start_node_ids:
|
||||
continue
|
||||
|
||||
for start_node_id in start_node_ids:
|
||||
stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute(
|
||||
answer_node_id=answer_node_id,
|
||||
generate_route=generate_route
|
||||
)
|
||||
|
||||
return stream_generate_routes
|
||||
|
||||
def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
|
||||
-> list[str]:
|
||||
"""
|
||||
Get answer start at node id.
|
||||
:param graph: graph
|
||||
:param target_node_id: target node ID
|
||||
:return:
|
||||
"""
|
||||
nodes = graph.get('nodes')
|
||||
edges = graph.get('edges')
|
||||
|
||||
# fetch all ingoing edges from source node
|
||||
ingoing_edges = []
|
||||
for edge in edges:
|
||||
if edge.get('target') == target_node_id:
|
||||
ingoing_edges.append(edge)
|
||||
|
||||
if not ingoing_edges:
|
||||
# check if it's the first node in the iteration
|
||||
target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
|
||||
if not target_node:
|
||||
return []
|
||||
|
||||
node_iteration_id = target_node.get('data', {}).get('iteration_id')
|
||||
# get iteration start node id
|
||||
for node in nodes:
|
||||
if node.get('id') == node_iteration_id:
|
||||
if node.get('data', {}).get('start_node_id') == target_node_id:
|
||||
return [target_node_id]
|
||||
|
||||
return []
|
||||
|
||||
start_node_ids = []
|
||||
for ingoing_edge in ingoing_edges:
|
||||
source_node_id = ingoing_edge.get('source')
|
||||
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
|
||||
if not source_node:
|
||||
continue
|
||||
|
||||
node_type = source_node.get('data', {}).get('type')
|
||||
node_iteration_id = source_node.get('data', {}).get('iteration_id')
|
||||
iteration_start_node_id = None
|
||||
if node_iteration_id:
|
||||
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
|
||||
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
|
||||
|
||||
if node_type in [
|
||||
NodeType.ANSWER.value,
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER.value,
|
||||
NodeType.ITERATION.value,
|
||||
NodeType.LOOP.value
|
||||
]:
|
||||
start_node_id = target_node_id
|
||||
start_node_ids.append(start_node_id)
|
||||
elif node_type == NodeType.START.value or \
|
||||
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
|
||||
start_node_id = source_node_id
|
||||
start_node_ids.append(start_node_id)
|
||||
else:
|
||||
sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id)
|
||||
if sub_start_node_ids:
|
||||
start_node_ids.extend(sub_start_node_ids)
|
||||
|
||||
return start_node_ids
|
||||
|
||||
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
|
||||
"""
|
||||
Get iteration nested relations.
|
||||
:param graph: graph
|
||||
:return:
|
||||
"""
|
||||
nodes = graph.get('nodes')
|
||||
|
||||
iteration_ids = [node.get('id') for node in nodes
|
||||
if node.get('data', {}).get('type') in [
|
||||
NodeType.ITERATION.value,
|
||||
NodeType.LOOP.value,
|
||||
]]
|
||||
|
||||
return {
|
||||
iteration_id: [
|
||||
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
|
||||
] for iteration_id in iteration_ids
|
||||
}
|
||||
|
||||
def _generate_stream_outputs_when_node_started(self) -> Generator:
|
||||
"""
|
||||
Generate stream outputs.
|
||||
:return:
|
||||
"""
|
||||
if self._task_state.current_stream_generate_state:
|
||||
route_chunks = self._task_state.current_stream_generate_state.generate_route[
|
||||
self._task_state.current_stream_generate_state.current_route_position:
|
||||
]
|
||||
|
||||
for route_chunk in route_chunks:
|
||||
if route_chunk.type == 'text':
|
||||
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
|
||||
|
||||
# handle output moderation chunk
|
||||
should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text)
|
||||
if should_direct_answer:
|
||||
continue
|
||||
|
||||
self._task_state.answer += route_chunk.text
|
||||
yield self._message_to_stream_response(route_chunk.text, self._message.id)
|
||||
else:
|
||||
break
|
||||
|
||||
self._task_state.current_stream_generate_state.current_route_position += 1
|
||||
|
||||
# all route chunks are generated
|
||||
if self._task_state.current_stream_generate_state.current_route_position == len(
|
||||
self._task_state.current_stream_generate_state.generate_route
|
||||
):
|
||||
self._task_state.current_stream_generate_state = None
|
||||
|
||||
def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]:
|
||||
"""
|
||||
Generate stream outputs.
|
||||
:return:
|
||||
"""
|
||||
if not self._task_state.current_stream_generate_state:
|
||||
return
|
||||
|
||||
route_chunks = self._task_state.current_stream_generate_state.generate_route[
|
||||
self._task_state.current_stream_generate_state.current_route_position:]
|
||||
|
||||
for route_chunk in route_chunks:
|
||||
if route_chunk.type == 'text':
|
||||
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
|
||||
self._task_state.answer += route_chunk.text
|
||||
yield self._message_to_stream_response(route_chunk.text, self._message.id)
|
||||
else:
|
||||
value = None
|
||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
||||
value_selector = route_chunk.value_selector
|
||||
if not value_selector:
|
||||
self._task_state.current_stream_generate_state.current_route_position += 1
|
||||
continue
|
||||
|
||||
route_chunk_node_id = value_selector[0]
|
||||
|
||||
if route_chunk_node_id == 'sys':
|
||||
# system variable
|
||||
value = contexts.workflow_variable_pool.get().get(value_selector)
|
||||
if value:
|
||||
value = value.text
|
||||
elif route_chunk_node_id in self._iteration_nested_relations:
|
||||
# it's a iteration variable
|
||||
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
|
||||
continue
|
||||
iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
|
||||
iterator = iteration_state.inputs
|
||||
if not iterator:
|
||||
continue
|
||||
iterator_selector = iterator.get('iterator_selector', [])
|
||||
if value_selector[1] == 'index':
|
||||
value = iteration_state.current_index
|
||||
elif value_selector[1] == 'item':
|
||||
value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
|
||||
iterator_selector
|
||||
) else None
|
||||
else:
|
||||
# check chunk node id is before current node id or equal to current node id
|
||||
if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
|
||||
break
|
||||
|
||||
latest_node_execution_info = self._task_state.latest_node_execution_info
|
||||
|
||||
# get route chunk node execution info
|
||||
route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
|
||||
if (route_chunk_node_execution_info.node_type == NodeType.LLM
|
||||
and latest_node_execution_info.node_type == NodeType.LLM):
|
||||
# only LLM support chunk stream output
|
||||
self._task_state.current_stream_generate_state.current_route_position += 1
|
||||
continue
|
||||
|
||||
# get route chunk node execution
|
||||
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id
|
||||
).first()
|
||||
|
||||
outputs = route_chunk_node_execution.outputs_dict
|
||||
|
||||
# get value from outputs
|
||||
value = None
|
||||
for key in value_selector[1:]:
|
||||
if not value:
|
||||
value = outputs.get(key) if outputs else None
|
||||
else:
|
||||
value = value.get(key)
|
||||
|
||||
if value is not None:
|
||||
text = ''
|
||||
if isinstance(value, str | int | float):
|
||||
text = str(value)
|
||||
elif isinstance(value, FileVar):
|
||||
# convert file to markdown
|
||||
text = value.to_markdown()
|
||||
elif isinstance(value, dict):
|
||||
# handle files
|
||||
file_vars = self._fetch_files_from_variable_value(value)
|
||||
if file_vars:
|
||||
file_var = file_vars[0]
|
||||
try:
|
||||
file_var_obj = FileVar(**file_var)
|
||||
|
||||
# convert file to markdown
|
||||
text = file_var_obj.to_markdown()
|
||||
except Exception as e:
|
||||
logger.error(f'Error creating file var: {e}')
|
||||
|
||||
if not text:
|
||||
# other types
|
||||
text = json.dumps(value, ensure_ascii=False)
|
||||
elif isinstance(value, list):
|
||||
# handle files
|
||||
file_vars = self._fetch_files_from_variable_value(value)
|
||||
for file_var in file_vars:
|
||||
try:
|
||||
file_var_obj = FileVar(**file_var)
|
||||
except Exception as e:
|
||||
logger.error(f'Error creating file var: {e}')
|
||||
continue
|
||||
|
||||
# convert file to markdown
|
||||
text = file_var_obj.to_markdown() + ' '
|
||||
|
||||
text = text.strip()
|
||||
|
||||
if not text and value:
|
||||
# other types
|
||||
text = json.dumps(value, ensure_ascii=False)
|
||||
|
||||
if text:
|
||||
self._task_state.answer += text
|
||||
yield self._message_to_stream_response(text, self._message.id)
|
||||
|
||||
self._task_state.current_stream_generate_state.current_route_position += 1
|
||||
|
||||
# all route chunks are generated
|
||||
if self._task_state.current_stream_generate_state.current_route_position == len(
|
||||
self._task_state.current_stream_generate_state.generate_route
|
||||
):
|
||||
self._task_state.current_stream_generate_state = None
|
||||
|
||||
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
|
||||
"""
|
||||
Is stream out support
|
||||
:param event: queue text chunk event
|
||||
:return:
|
||||
"""
|
||||
if not event.metadata:
|
||||
return True
|
||||
|
||||
if 'node_id' not in event.metadata:
|
||||
return True
|
||||
|
||||
node_type = event.metadata.get('node_type')
|
||||
stream_output_value_selector = event.metadata.get('value_selector')
|
||||
if not stream_output_value_selector:
|
||||
return False
|
||||
|
||||
if not self._task_state.current_stream_generate_state:
|
||||
return False
|
||||
|
||||
route_chunk = self._task_state.current_stream_generate_state.generate_route[
|
||||
self._task_state.current_stream_generate_state.current_route_position]
|
||||
|
||||
if route_chunk.type != 'var':
|
||||
return False
|
||||
|
||||
if node_type != NodeType.LLM:
|
||||
# only LLM support chunk stream output
|
||||
return False
|
||||
|
||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
||||
value_selector = route_chunk.value_selector
|
||||
|
||||
# check chunk node id is before current node id or equal to current node id
|
||||
if value_selector != stream_output_value_selector:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _handle_output_moderation_chunk(self, text: str) -> bool:
|
||||
"""
|
||||
Handle output moderation chunk.
|
||||
|
@ -782,3 +569,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
self._output_moderation_handler.append_new_token(text)
|
||||
|
||||
return False
|
||||
|
||||
def _refetch_message(self) -> None:
|
||||
"""
|
||||
Refetch message.
|
||||
:return:
|
||||
"""
|
||||
message = db.session.query(Message).filter(Message.id == self._message.id).first()
|
||||
if message:
|
||||
self._message = message
|
||||
|
|
|
@ -1,203 +0,0 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowEventTriggerCallback(WorkflowCallback):
|
||||
|
||||
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
||||
self._queue_manager = queue_manager
|
||||
|
||||
def on_workflow_run_started(self) -> None:
|
||||
"""
|
||||
Workflow run started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowStartedEvent(),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_run_succeeded(self) -> None:
|
||||
"""
|
||||
Workflow run succeeded
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowSucceededEvent(),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_run_failed(self, error: str) -> None:
|
||||
"""
|
||||
Workflow run failed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowFailedEvent(
|
||||
error=error
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_started(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
node_run_index: int = 1,
|
||||
predecessor_node_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeStartedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
node_run_index=node_run_index,
|
||||
predecessor_node_id=predecessor_node_id
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_succeeded(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute succeeded
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeSucceededEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
execution_metadata=execution_metadata
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_failed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
error: str,
|
||||
inputs: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute failed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeFailedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
process_data=process_data,
|
||||
error=error
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=text,
|
||||
metadata={
|
||||
"node_id": node_id,
|
||||
**metadata
|
||||
}
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_started(self,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int = 1,
|
||||
node_data: Optional[BaseNodeData] = None,
|
||||
inputs: dict = None,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationStartEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_run_index=node_run_index,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
metadata=metadata
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_next(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
index: int,
|
||||
node_run_index: int,
|
||||
output: Optional[Any]) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
self._queue_manager._publish(
|
||||
QueueIterationNextEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
index=index,
|
||||
node_run_index=node_run_index,
|
||||
output=output
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_completed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int,
|
||||
outputs: dict) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
self._queue_manager._publish(
|
||||
QueueIterationCompletedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_run_index=node_run_index,
|
||||
outputs=outputs
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
event,
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
|
@ -3,7 +3,7 @@ import os
|
|||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
from typing import Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
@ -28,6 +28,24 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[True] = True,
|
||||
) -> Generator[dict, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
) -> dict: ...
|
||||
|
||||
def generate(self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
|
|
|
@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC):
|
|||
def convert(cls, response: Union[
|
||||
AppBlockingResponse,
|
||||
Generator[AppStreamResponse, Any, None]
|
||||
], invoke_from: InvokeFrom):
|
||||
], invoke_from: InvokeFrom) -> dict[str, Any] | Generator[str, Any, None]:
|
||||
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_full_response(response)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
|
@ -347,7 +347,7 @@ class AppRunner:
|
|||
self, app_id: str,
|
||||
tenant_id: str,
|
||||
app_generate_entity: AppGenerateEntity,
|
||||
inputs: dict,
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
message_id: str,
|
||||
) -> tuple[bool, dict, str]:
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
from typing import Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
@ -28,13 +28,31 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[True] = True,
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
) -> dict: ...
|
||||
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> Union[dict, Generator[dict, None, None]]:
|
||||
) -> Union[dict, Generator[str, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
from typing import Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
@ -30,12 +30,30 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[True] = True,
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
) -> dict: ...
|
||||
|
||||
def generate(self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True) \
|
||||
-> Union[dict, Generator[dict, None, None]]:
|
||||
-> Union[dict, Generator[str, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
@ -203,7 +221,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
user: Union[Account, EndUser],
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True) \
|
||||
-> Union[dict, Generator[dict, None, None]]:
|
||||
-> Union[dict, Generator[str, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import os
|
|||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
@ -32,14 +32,40 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class WorkflowAppGenerator(BaseAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[True] = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None
|
||||
) -> dict: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Generate App response.
|
||||
|
@ -51,6 +77,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
:param call_depth: call depth
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
inputs = args['inputs']
|
||||
|
||||
|
@ -98,16 +125,19 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=invoke_from,
|
||||
stream=stream,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id
|
||||
)
|
||||
|
||||
def _generate(
|
||||
self, app_model: App,
|
||||
self, *,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> Union[dict, Generator[dict, None, None]]:
|
||||
workflow_thread_pool_id: Optional[str] = None
|
||||
) -> dict[str, Any] | Generator[str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
@ -117,6 +147,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
:param application_generate_entity: application generate entity
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
|
@ -128,10 +159,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'flask_app': current_app._get_current_object(), # type: ignore
|
||||
'application_generate_entity': application_generate_entity,
|
||||
'queue_manager': queue_manager,
|
||||
'context': contextvars.copy_context()
|
||||
'context': contextvars.copy_context(),
|
||||
'workflow_thread_pool_id': workflow_thread_pool_id
|
||||
})
|
||||
|
||||
worker_thread.start()
|
||||
|
@ -155,7 +187,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
node_id: str,
|
||||
user: Account,
|
||||
args: dict,
|
||||
stream: bool = True):
|
||||
stream: bool = True) -> dict[str, Any] | Generator[str, Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
@ -172,10 +204,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
if args.get('inputs') is None:
|
||||
raise ValueError('inputs is required')
|
||||
|
||||
extras = {
|
||||
"auto_generate_conversation_name": False
|
||||
}
|
||||
|
||||
# convert to app config
|
||||
app_config = WorkflowAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
|
@ -191,7 +219,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras=extras,
|
||||
extras={
|
||||
"auto_generate_conversation_name": False
|
||||
},
|
||||
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
|
||||
node_id=node_id,
|
||||
inputs=args['inputs']
|
||||
|
@ -211,12 +241,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
def _generate_worker(self, flask_app: Flask,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
context: contextvars.Context) -> None:
|
||||
context: contextvars.Context,
|
||||
workflow_thread_pool_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
:return:
|
||||
"""
|
||||
for var, val in context.items():
|
||||
|
@ -224,22 +256,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
with flask_app.app_context():
|
||||
try:
|
||||
# workflow app
|
||||
runner = WorkflowAppRunner()
|
||||
if application_generate_entity.single_iteration_run:
|
||||
single_iteration_run = application_generate_entity.single_iteration_run
|
||||
runner.single_iteration_run(
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
workflow_id=application_generate_entity.app_config.workflow_id,
|
||||
queue_manager=queue_manager,
|
||||
inputs=single_iteration_run.inputs,
|
||||
node_id=single_iteration_run.node_id,
|
||||
user_id=application_generate_entity.user_id
|
||||
)
|
||||
else:
|
||||
runner.run(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager
|
||||
)
|
||||
runner = WorkflowAppRunner(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id
|
||||
)
|
||||
|
||||
runner.run()
|
||||
except GenerateTaskStoppedException:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
|
@ -251,14 +274,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except (ValueError, InvokeError) as e:
|
||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
|
||||
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == 'true':
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
finally:
|
||||
db.session.remove()
|
||||
db.session.close()
|
||||
|
||||
def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
|
|
|
@ -4,46 +4,61 @@ from typing import Optional, cast
|
|||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||
from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import Workflow
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowAppRunner:
|
||||
class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
"""
|
||||
Workflow Application Runner
|
||||
"""
|
||||
|
||||
def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
workflow_thread_pool_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.queue_manager = queue_manager
|
||||
self.workflow_thread_pool_id = workflow_thread_pool_id
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Run application
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:return:
|
||||
"""
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
user_id = None
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
|
||||
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = application_generate_entity.user_id
|
||||
user_id = self.application_generate_entity.user_id
|
||||
|
||||
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
|
||||
if not app_record:
|
||||
|
@ -53,80 +68,64 @@ class WorkflowAppRunner:
|
|||
if not workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
|
||||
inputs = application_generate_entity.inputs
|
||||
files = application_generate_entity.files
|
||||
|
||||
db.session.close()
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = [
|
||||
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
|
||||
]
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariableKey.FILES: files,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
}
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=[],
|
||||
)
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs
|
||||
)
|
||||
else:
|
||||
|
||||
inputs = self.application_generate_entity.inputs
|
||||
files = self.application_generate_entity.files
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariableKey.FILES: files,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
}
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(graph_config=workflow.graph_dict)
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.run_workflow(
|
||||
workflow=workflow,
|
||||
user_id=application_generate_entity.user_id,
|
||||
user_from=UserFrom.ACCOUNT
|
||||
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
||||
else UserFrom.END_USER,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
callbacks=workflow_callbacks,
|
||||
call_depth=application_generate_entity.call_depth,
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType.value_of(workflow.type),
|
||||
graph=graph,
|
||||
graph_config=workflow.graph_dict,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=(
|
||||
UserFrom.ACCOUNT
|
||||
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
|
||||
else UserFrom.END_USER
|
||||
),
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
thread_pool_id=self.workflow_thread_pool_id
|
||||
)
|
||||
|
||||
def single_iteration_run(
|
||||
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Single iteration run
|
||||
"""
|
||||
app_record = db.session.query(App).filter(App.id == app_id).first()
|
||||
if not app_record:
|
||||
raise ValueError('App not found')
|
||||
|
||||
if not app_record.workflow_id:
|
||||
raise ValueError('Workflow not initialized')
|
||||
|
||||
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
|
||||
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
|
||||
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
workflow_engine_manager.single_step_run_iteration_workflow_node(
|
||||
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
|
||||
generator = workflow_entry.run(
|
||||
callbacks=workflow_callbacks
|
||||
)
|
||||
|
||||
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||
"""
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# return workflow
|
||||
return workflow
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
@ -15,10 +16,12 @@ from core.app.entities.queue_entities import (
|
|||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueuePingEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
|
@ -32,19 +35,16 @@ from core.app.entities.task_entities import (
|
|||
MessageAudioStreamResponse,
|
||||
StreamResponse,
|
||||
TextChunkStreamResponse,
|
||||
TextReplaceStreamResponse,
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStreamGenerateNodes,
|
||||
WorkflowStartStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
|
@ -52,8 +52,8 @@ from models.workflow import (
|
|||
Workflow,
|
||||
WorkflowAppLog,
|
||||
WorkflowAppLogCreatedFrom,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -68,7 +68,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
_task_state: WorkflowTaskState
|
||||
_application_generate_entity: WorkflowAppGenerateEntity
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_iteration_nested_relations: dict[str, list[str]]
|
||||
|
||||
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
|
@ -96,11 +95,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
SystemVariableKey.USER_ID: user_id
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState(
|
||||
iteration_nested_node_ids=[]
|
||||
)
|
||||
self._stream_generate_nodes = self._get_stream_generate_nodes()
|
||||
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
|
||||
self._task_state = WorkflowTaskState()
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
|
@ -129,23 +124,20 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, WorkflowFinishStreamResponse):
|
||||
workflow_run = db.session.query(WorkflowRun).filter(
|
||||
WorkflowRun.id == self._task_state.workflow_run_id).first()
|
||||
|
||||
response = WorkflowAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
workflow_run_id=stream_response.data.id,
|
||||
data=WorkflowAppBlockingResponse.Data(
|
||||
id=workflow_run.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
status=workflow_run.status,
|
||||
outputs=workflow_run.outputs_dict,
|
||||
error=workflow_run.error,
|
||||
elapsed_time=workflow_run.elapsed_time,
|
||||
total_tokens=workflow_run.total_tokens,
|
||||
total_steps=workflow_run.total_steps,
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
finished_at=int(workflow_run.finished_at.timestamp())
|
||||
id=stream_response.data.id,
|
||||
workflow_id=stream_response.data.workflow_id,
|
||||
status=stream_response.data.status,
|
||||
outputs=stream_response.data.outputs,
|
||||
error=stream_response.data.error,
|
||||
elapsed_time=stream_response.data.elapsed_time,
|
||||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
created_at=int(stream_response.data.created_at),
|
||||
finished_at=int(stream_response.data.finished_at)
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -161,9 +153,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
To stream response.
|
||||
:return:
|
||||
"""
|
||||
workflow_run_id = None
|
||||
for stream_response in generator:
|
||||
if isinstance(stream_response, WorkflowStartStreamResponse):
|
||||
workflow_run_id = stream_response.workflow_run_id
|
||||
|
||||
yield WorkflowAppStreamResponse(
|
||||
workflow_run_id=self._task_state.workflow_run_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
stream_response=stream_response
|
||||
)
|
||||
|
||||
|
@ -178,17 +174,18 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
|
||||
Generator[StreamResponse, None, None]:
|
||||
|
||||
publisher = None
|
||||
tts_publisher = None
|
||||
task_id = self._application_generate_entity.task_id
|
||||
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
features_dict = self._workflow.features_dict
|
||||
|
||||
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
|
||||
'text_to_speech'].get('autoPlay') == 'enabled':
|
||||
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
|
||||
|
||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
|
||||
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
|
@ -198,9 +195,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
start_listener_time = time.time()
|
||||
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
|
||||
try:
|
||||
if not publisher:
|
||||
if not tts_publisher:
|
||||
break
|
||||
audio_trunk = publisher.checkAndGetAudio()
|
||||
audio_trunk = tts_publisher.checkAndGetAudio()
|
||||
if audio_trunk is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
|
@ -218,69 +215,159 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
|
||||
def _process_stream_response(
|
||||
self,
|
||||
publisher: AppGeneratorTTSPublisher,
|
||||
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""
|
||||
Process stream response.
|
||||
:return:
|
||||
"""
|
||||
for message in self._queue_manager.listen():
|
||||
if publisher:
|
||||
publisher.publish(message=message)
|
||||
event = message.event
|
||||
graph_runtime_state = None
|
||||
workflow_run = None
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
for queue_message in self._queue_manager.listen():
|
||||
event = queue_message.event
|
||||
|
||||
if isinstance(event, QueuePingEvent):
|
||||
yield self._ping_stream_response()
|
||||
elif isinstance(event, QueueErrorEvent):
|
||||
err = self._handle_error(event)
|
||||
yield self._error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueWorkflowStartedEvent):
|
||||
workflow_run = self._handle_workflow_start()
|
||||
# override graph runtime state
|
||||
graph_runtime_state = event.graph_runtime_state
|
||||
|
||||
# init workflow run
|
||||
workflow_run = self._handle_workflow_run_start()
|
||||
yield self._workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
)
|
||||
elif isinstance(event, QueueNodeStartedEvent):
|
||||
workflow_node_execution = self._handle_node_start(event)
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
# search stream_generate_routes if node id is answer start at node
|
||||
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes:
|
||||
self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id]
|
||||
workflow_node_execution = self._handle_node_execution_start(
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
|
||||
# generate stream outputs when node started
|
||||
yield from self._generate_stream_outputs_when_node_started()
|
||||
|
||||
yield self._workflow_node_start_to_stream_response(
|
||||
response = self._workflow_node_start_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
)
|
||||
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
|
||||
workflow_node_execution = self._handle_node_finished(event)
|
||||
|
||||
yield self._workflow_node_finish_to_stream_response(
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeSucceededEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_success(event)
|
||||
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
)
|
||||
|
||||
if isinstance(event, QueueNodeFailedEvent):
|
||||
yield from self._handle_iteration_exception(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
error=f'Child node failed: {event.error}'
|
||||
)
|
||||
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
|
||||
if isinstance(event, QueueIterationNextEvent):
|
||||
# clear ran node execution infos of current iteration
|
||||
iteration_relations = self._iteration_nested_relations.get(event.node_id)
|
||||
if iteration_relations:
|
||||
for node_id in iteration_relations:
|
||||
self._task_state.ran_node_execution_infos.pop(node_id, None)
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeFailedEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
|
||||
|
||||
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
|
||||
self._handle_iteration_operation(event)
|
||||
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
|
||||
workflow_run = self._handle_workflow_finished(
|
||||
event, trace_manager=trace_manager
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_node_execution=workflow_node_execution
|
||||
)
|
||||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueParallelBranchRunStartedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_parallel_branch_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_parallel_branch_finished_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationStartEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_iteration_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_iteration_next_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
yield self._workflow_iteration_completed_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowSucceededEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
workflow_run = self._handle_workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=json.dumps(event.outputs) if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs else None,
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
# save workflow app log
|
||||
self._save_workflow_app_log(workflow_run)
|
||||
|
||||
yield self._workflow_finish_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
|
||||
if not workflow_run:
|
||||
raise Exception('Workflow run not initialized.')
|
||||
|
||||
if not graph_runtime_state:
|
||||
raise Exception('Graph runtime state not initialized.')
|
||||
|
||||
workflow_run = self._handle_workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
status=WorkflowRunStatus.FAILED if isinstance(event, QueueWorkflowFailedEvent) else WorkflowRunStatus.STOPPED,
|
||||
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
# save workflow app log
|
||||
|
@ -295,22 +382,17 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
if delta_text is None:
|
||||
continue
|
||||
|
||||
if not self._is_stream_out_support(
|
||||
event=event
|
||||
):
|
||||
continue
|
||||
# only publish tts message at text chunk streaming
|
||||
if tts_publisher:
|
||||
tts_publisher.publish(message=queue_message)
|
||||
|
||||
self._task_state.answer += delta_text
|
||||
yield self._text_chunk_to_stream_response(delta_text)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
yield self._text_replace_to_stream_response(event.text)
|
||||
elif isinstance(event, QueuePingEvent):
|
||||
yield self._ping_stream_response()
|
||||
else:
|
||||
continue
|
||||
|
||||
if publisher:
|
||||
publisher.publish(None)
|
||||
if tts_publisher:
|
||||
tts_publisher.publish(None)
|
||||
|
||||
|
||||
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
|
||||
|
@ -329,15 +411,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
# not save log for debugging
|
||||
return
|
||||
|
||||
workflow_app_log = WorkflowAppLog(
|
||||
tenant_id=workflow_run.tenant_id,
|
||||
app_id=workflow_run.app_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
created_from=created_from.value,
|
||||
created_by_role=('account' if isinstance(self._user, Account) else 'end_user'),
|
||||
created_by=self._user.id,
|
||||
)
|
||||
workflow_app_log = WorkflowAppLog()
|
||||
workflow_app_log.tenant_id = workflow_run.tenant_id
|
||||
workflow_app_log.app_id = workflow_run.app_id
|
||||
workflow_app_log.workflow_id = workflow_run.workflow_id
|
||||
workflow_app_log.workflow_run_id = workflow_run.id
|
||||
workflow_app_log.created_from = created_from.value
|
||||
workflow_app_log.created_by_role = 'account' if isinstance(self._user, Account) else 'end_user'
|
||||
workflow_app_log.created_by = self._user.id
|
||||
|
||||
db.session.add(workflow_app_log)
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
@ -354,180 +436,3 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
)
|
||||
|
||||
return response
|
||||
|
||||
def _text_replace_to_stream_response(self, text: str) -> TextReplaceStreamResponse:
|
||||
"""
|
||||
Text replace to stream response.
|
||||
:param text: text
|
||||
:return:
|
||||
"""
|
||||
return TextReplaceStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
text=TextReplaceStreamResponse.Data(text=text)
|
||||
)
|
||||
|
||||
def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]:
|
||||
"""
|
||||
Get stream generate nodes.
|
||||
:return:
|
||||
"""
|
||||
# find all answer nodes
|
||||
graph = self._workflow.graph_dict
|
||||
end_node_configs = [
|
||||
node for node in graph['nodes']
|
||||
if node.get('data', {}).get('type') == NodeType.END.value
|
||||
]
|
||||
|
||||
# parse stream output node value selectors of end nodes
|
||||
stream_generate_routes = {}
|
||||
for node_config in end_node_configs:
|
||||
# get generate route for stream output
|
||||
end_node_id = node_config['id']
|
||||
generate_nodes = EndNode.extract_generate_nodes(graph, node_config)
|
||||
start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id)
|
||||
if not start_node_ids:
|
||||
continue
|
||||
|
||||
for start_node_id in start_node_ids:
|
||||
stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes(
|
||||
end_node_id=end_node_id,
|
||||
stream_node_ids=generate_nodes
|
||||
)
|
||||
|
||||
return stream_generate_routes
|
||||
|
||||
def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \
|
||||
-> list[str]:
|
||||
"""
|
||||
Get end start at node id.
|
||||
:param graph: graph
|
||||
:param target_node_id: target node ID
|
||||
:return:
|
||||
"""
|
||||
nodes = graph.get('nodes')
|
||||
edges = graph.get('edges')
|
||||
|
||||
# fetch all ingoing edges from source node
|
||||
ingoing_edges = []
|
||||
for edge in edges:
|
||||
if edge.get('target') == target_node_id:
|
||||
ingoing_edges.append(edge)
|
||||
|
||||
if not ingoing_edges:
|
||||
return []
|
||||
|
||||
start_node_ids = []
|
||||
for ingoing_edge in ingoing_edges:
|
||||
source_node_id = ingoing_edge.get('source')
|
||||
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
|
||||
if not source_node:
|
||||
continue
|
||||
|
||||
node_type = source_node.get('data', {}).get('type')
|
||||
node_iteration_id = source_node.get('data', {}).get('iteration_id')
|
||||
iteration_start_node_id = None
|
||||
if node_iteration_id:
|
||||
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
|
||||
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
|
||||
|
||||
if node_type in [
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER.value
|
||||
]:
|
||||
start_node_id = target_node_id
|
||||
start_node_ids.append(start_node_id)
|
||||
elif node_type == NodeType.START.value or \
|
||||
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
|
||||
start_node_id = source_node_id
|
||||
start_node_ids.append(start_node_id)
|
||||
else:
|
||||
sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id)
|
||||
if sub_start_node_ids:
|
||||
start_node_ids.extend(sub_start_node_ids)
|
||||
|
||||
return start_node_ids
|
||||
|
||||
def _generate_stream_outputs_when_node_started(self) -> Generator:
|
||||
"""
|
||||
Generate stream outputs.
|
||||
:return:
|
||||
"""
|
||||
if self._task_state.current_stream_generate_state:
|
||||
stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids
|
||||
|
||||
for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items():
|
||||
if node_id not in stream_node_ids:
|
||||
continue
|
||||
|
||||
node_execution_info = self._task_state.ran_node_execution_infos[node_id]
|
||||
|
||||
# get chunk node execution
|
||||
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first()
|
||||
|
||||
if not route_chunk_node_execution:
|
||||
continue
|
||||
|
||||
outputs = route_chunk_node_execution.outputs_dict
|
||||
|
||||
if not outputs:
|
||||
continue
|
||||
|
||||
# get value from outputs
|
||||
text = outputs.get('text')
|
||||
|
||||
if text:
|
||||
self._task_state.answer += text
|
||||
yield self._text_chunk_to_stream_response(text)
|
||||
|
||||
db.session.close()
|
||||
|
||||
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
|
||||
"""
|
||||
Is stream out support
|
||||
:param event: queue text chunk event
|
||||
:return:
|
||||
"""
|
||||
if not event.metadata:
|
||||
return False
|
||||
|
||||
if 'node_id' not in event.metadata:
|
||||
return False
|
||||
|
||||
node_id = event.metadata.get('node_id')
|
||||
node_type = event.metadata.get('node_type')
|
||||
stream_output_value_selector = event.metadata.get('value_selector')
|
||||
if not stream_output_value_selector:
|
||||
return False
|
||||
|
||||
if not self._task_state.current_stream_generate_state:
|
||||
return False
|
||||
|
||||
if node_id not in self._task_state.current_stream_generate_state.stream_node_ids:
|
||||
return False
|
||||
|
||||
if node_type != NodeType.LLM:
|
||||
# only LLM support chunk stream output
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
|
||||
"""
|
||||
Get iteration nested relations.
|
||||
:param graph: graph
|
||||
:return:
|
||||
"""
|
||||
nodes = graph.get('nodes')
|
||||
|
||||
iteration_ids = [node.get('id') for node in nodes
|
||||
if node.get('data', {}).get('type') in [
|
||||
NodeType.ITERATION.value,
|
||||
NodeType.LOOP.value,
|
||||
]]
|
||||
|
||||
return {
|
||||
iteration_id: [
|
||||
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
|
||||
] for iteration_id in iteration_ids
|
||||
}
|
||||
|
|
|
@ -1,200 +0,0 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowEventTriggerCallback(WorkflowCallback):
|
||||
|
||||
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
||||
self._queue_manager = queue_manager
|
||||
|
||||
def on_workflow_run_started(self) -> None:
|
||||
"""
|
||||
Workflow run started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowStartedEvent(),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_run_succeeded(self) -> None:
|
||||
"""
|
||||
Workflow run succeeded
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowSucceededEvent(),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_run_failed(self, error: str) -> None:
|
||||
"""
|
||||
Workflow run failed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowFailedEvent(
|
||||
error=error
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_started(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
node_run_index: int = 1,
|
||||
predecessor_node_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeStartedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
node_run_index=node_run_index,
|
||||
predecessor_node_id=predecessor_node_id
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_succeeded(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute succeeded
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeSucceededEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
execution_metadata=execution_metadata
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_failed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
error: str,
|
||||
inputs: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute failed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeFailedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
process_data=process_data,
|
||||
error=error
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=text,
|
||||
metadata={
|
||||
"node_id": node_id,
|
||||
**metadata
|
||||
}
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_started(self,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int = 1,
|
||||
node_data: Optional[BaseNodeData] = None,
|
||||
inputs: dict = None,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationStartEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_run_index=node_run_index,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
metadata=metadata
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_next(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
index: int,
|
||||
node_run_index: int,
|
||||
output: Optional[Any]) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationNextEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
index=index,
|
||||
node_run_index=node_run_index,
|
||||
output=output
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_completed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int,
|
||||
outputs: dict) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationCompletedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_run_index=node_run_index,
|
||||
outputs=outputs
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event
|
||||
"""
|
||||
pass
|
379
api/core/app/apps/workflow_app_runner.py
Normal file
379
api/core/app/apps/workflow_app_runner.py
Normal file
|
@ -0,0 +1,379 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
IterationRunFailedEvent,
|
||||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ParallelBranchRunFailedEvent,
|
||||
ParallelBranchRunStartedEvent,
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData
|
||||
from core.workflow.nodes.node_mapping import node_classes
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowBasedAppRunner(AppRunner):
|
||||
def __init__(self, queue_manager: AppQueueManager):
|
||||
self.queue_manager = queue_manager
|
||||
|
||||
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
|
||||
"""
|
||||
Init graph
|
||||
"""
|
||||
if 'nodes' not in graph_config or 'edges' not in graph_config:
|
||||
raise ValueError('nodes or edges not found in workflow graph')
|
||||
|
||||
if not isinstance(graph_config.get('nodes'), list):
|
||||
raise ValueError('nodes in workflow graph must be a list')
|
||||
|
||||
if not isinstance(graph_config.get('edges'), list):
|
||||
raise ValueError('edges in workflow graph must be a list')
|
||||
# init graph
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config
|
||||
)
|
||||
|
||||
if not graph:
|
||||
raise ValueError('graph not found in workflow')
|
||||
|
||||
return graph
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_iteration(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph_config = workflow.graph_dict
|
||||
if not graph_config:
|
||||
raise ValueError('workflow graph not found')
|
||||
|
||||
graph_config = cast(dict[str, Any], graph_config)
|
||||
|
||||
if 'nodes' not in graph_config or 'edges' not in graph_config:
|
||||
raise ValueError('nodes or edges not found in workflow graph')
|
||||
|
||||
if not isinstance(graph_config.get('nodes'), list):
|
||||
raise ValueError('nodes in workflow graph must be a list')
|
||||
|
||||
if not isinstance(graph_config.get('edges'), list):
|
||||
raise ValueError('edges in workflow graph must be a list')
|
||||
|
||||
# filter nodes only in iteration
|
||||
node_configs = [
|
||||
node for node in graph_config.get('nodes', [])
|
||||
if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id
|
||||
]
|
||||
|
||||
graph_config['nodes'] = node_configs
|
||||
|
||||
node_ids = [node.get('id') for node in node_configs]
|
||||
|
||||
# filter edges only in iteration
|
||||
edge_configs = [
|
||||
edge for edge in graph_config.get('edges', [])
|
||||
if (edge.get('source') is None or edge.get('source') in node_ids)
|
||||
and (edge.get('target') is None or edge.get('target') in node_ids)
|
||||
]
|
||||
|
||||
graph_config['edges'] = edge_configs
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
root_node_id=node_id
|
||||
)
|
||||
|
||||
if not graph:
|
||||
raise ValueError('graph not found in workflow')
|
||||
|
||||
# fetch node config from node id
|
||||
iteration_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get('id') == node_id:
|
||||
iteration_node_config = node
|
||||
break
|
||||
|
||||
if not iteration_node_config:
|
||||
raise ValueError('iteration node id not found in workflow graph')
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type'))
|
||||
node_cls = node_classes.get(node_type)
|
||||
node_cls = cast(type[BaseNode], node_cls)
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict,
|
||||
config=iteration_node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
node_type=node_type,
|
||||
node_data=IterationNodeData(**iteration_node_config.get('data', {}))
|
||||
)
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Handle event
|
||||
:param workflow_entry: workflow entry
|
||||
:param event: event
|
||||
"""
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueWorkflowStartedEvent(
|
||||
graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state
|
||||
)
|
||||
)
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueWorkflowSucceededEvent(outputs=event.outputs)
|
||||
)
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueWorkflowFailedEvent(error=event.error)
|
||||
)
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeStartedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
node_run_index=event.route_node_state.index,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueNodeSucceededEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result else {},
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result else {},
|
||||
in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result else {},
|
||||
error=event.route_node_state.node_run_result.error
|
||||
if event.route_node_state.node_run_result
|
||||
and event.route_node_state.node_run_result.error
|
||||
else "Unknown error",
|
||||
in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
self._publish_event(
|
||||
QueueTextChunkEvent(
|
||||
text=event.chunk_content,
|
||||
from_variable_selector=event.from_variable_selector,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunRetrieverResourceEvent):
|
||||
self._publish_event(
|
||||
QueueRetrieverResourcesEvent(
|
||||
retriever_resources=event.retriever_resources,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunStartedEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunSucceededEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunFailedEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
error=event.error
|
||||
)
|
||||
)
|
||||
elif isinstance(event, IterationRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueIterationStartEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
metadata=event.metadata
|
||||
)
|
||||
)
|
||||
elif isinstance(event, IterationRunNextEvent):
|
||||
self._publish_event(
|
||||
QueueIterationNextEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
index=event.index,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
output=event.pre_iteration_output,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
|
||||
self._publish_event(
|
||||
QueueIterationCompletedEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error if isinstance(event, IterationRunFailedEvent) else None
|
||||
)
|
||||
)
|
||||
|
||||
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||
"""
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# return workflow
|
||||
return workflow
|
||||
|
||||
def _publish_event(self, event: AppQueueEvent) -> None:
|
||||
self.queue_manager.publish(
|
||||
event,
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
|
@ -1,10 +1,24 @@
|
|||
from typing import Optional
|
||||
|
||||
from core.app.entities.queue_entities import AppQueueEvent
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
IterationRunFailedEvent,
|
||||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ParallelBranchRunFailedEvent,
|
||||
ParallelBranchRunStartedEvent,
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
|
||||
_TEXT_COLOR_MAPPING = {
|
||||
"blue": "36;1",
|
||||
|
@ -20,127 +34,203 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
|||
def __init__(self) -> None:
|
||||
self.current_node_id = None
|
||||
|
||||
def on_workflow_run_started(self) -> None:
|
||||
"""
|
||||
Workflow run started
|
||||
"""
|
||||
self.print_text("\n[on_workflow_run_started]", color='pink')
|
||||
def on_event(
|
||||
self,
|
||||
event: GraphEngineEvent
|
||||
) -> None:
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self.print_text("\n[GraphRunStartedEvent]", color='pink')
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self.print_text("\n[GraphRunSucceededEvent]", color='green')
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color='red')
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
self.on_workflow_node_execute_started(
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
self.on_workflow_node_execute_succeeded(
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self.on_workflow_node_execute_failed(
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
self.on_node_text_chunk(
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
self.on_workflow_parallel_started(
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
|
||||
self.on_workflow_parallel_completed(
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, IterationRunStartedEvent):
|
||||
self.on_workflow_iteration_started(
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, IterationRunNextEvent):
|
||||
self.on_workflow_iteration_next(
|
||||
event=event
|
||||
)
|
||||
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
|
||||
self.on_workflow_iteration_completed(
|
||||
event=event
|
||||
)
|
||||
else:
|
||||
self.print_text(f"\n[{event.__class__.__name__}]", color='blue')
|
||||
|
||||
def on_workflow_run_succeeded(self) -> None:
|
||||
"""
|
||||
Workflow run succeeded
|
||||
"""
|
||||
self.print_text("\n[on_workflow_run_succeeded]", color='green')
|
||||
|
||||
def on_workflow_run_failed(self, error: str) -> None:
|
||||
"""
|
||||
Workflow run failed
|
||||
"""
|
||||
self.print_text("\n[on_workflow_run_failed]", color='red')
|
||||
|
||||
def on_workflow_node_execute_started(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
node_run_index: int = 1,
|
||||
predecessor_node_id: Optional[str] = None) -> None:
|
||||
def on_workflow_node_execute_started(
|
||||
self,
|
||||
event: NodeRunStartedEvent
|
||||
) -> None:
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
self.print_text("\n[on_workflow_node_execute_started]", color='yellow')
|
||||
self.print_text(f"Node ID: {node_id}", color='yellow')
|
||||
self.print_text(f"Type: {node_type.value}", color='yellow')
|
||||
self.print_text(f"Index: {node_run_index}", color='yellow')
|
||||
if predecessor_node_id:
|
||||
self.print_text(f"Predecessor Node ID: {predecessor_node_id}", color='yellow')
|
||||
self.print_text("\n[NodeRunStartedEvent]", color='yellow')
|
||||
self.print_text(f"Node ID: {event.node_id}", color='yellow')
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color='yellow')
|
||||
self.print_text(f"Type: {event.node_type.value}", color='yellow')
|
||||
|
||||
def on_workflow_node_execute_succeeded(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None) -> None:
|
||||
def on_workflow_node_execute_succeeded(
|
||||
self,
|
||||
event: NodeRunSucceededEvent
|
||||
) -> None:
|
||||
"""
|
||||
Workflow node execute succeeded
|
||||
"""
|
||||
self.print_text("\n[on_workflow_node_execute_succeeded]", color='green')
|
||||
self.print_text(f"Node ID: {node_id}", color='green')
|
||||
self.print_text(f"Type: {node_type.value}", color='green')
|
||||
self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='green')
|
||||
self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='green')
|
||||
self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='green')
|
||||
self.print_text(f"Metadata: {jsonable_encoder(execution_metadata) if execution_metadata else ''}",
|
||||
color='green')
|
||||
route_node_state = event.route_node_state
|
||||
|
||||
def on_workflow_node_execute_failed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
error: str,
|
||||
inputs: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None) -> None:
|
||||
self.print_text("\n[NodeRunSucceededEvent]", color='green')
|
||||
self.print_text(f"Node ID: {event.node_id}", color='green')
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color='green')
|
||||
self.print_text(f"Type: {event.node_type.value}", color='green')
|
||||
|
||||
if route_node_state.node_run_result:
|
||||
node_run_result = route_node_state.node_run_result
|
||||
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
|
||||
color='green')
|
||||
self.print_text(
|
||||
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
|
||||
color='green')
|
||||
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
|
||||
color='green')
|
||||
self.print_text(
|
||||
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
|
||||
color='green')
|
||||
|
||||
def on_workflow_node_execute_failed(
|
||||
self,
|
||||
event: NodeRunFailedEvent
|
||||
) -> None:
|
||||
"""
|
||||
Workflow node execute failed
|
||||
"""
|
||||
self.print_text("\n[on_workflow_node_execute_failed]", color='red')
|
||||
self.print_text(f"Node ID: {node_id}", color='red')
|
||||
self.print_text(f"Type: {node_type.value}", color='red')
|
||||
self.print_text(f"Error: {error}", color='red')
|
||||
self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='red')
|
||||
self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='red')
|
||||
self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='red')
|
||||
route_node_state = event.route_node_state
|
||||
|
||||
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
|
||||
self.print_text("\n[NodeRunFailedEvent]", color='red')
|
||||
self.print_text(f"Node ID: {event.node_id}", color='red')
|
||||
self.print_text(f"Node Title: {event.node_data.title}", color='red')
|
||||
self.print_text(f"Type: {event.node_type.value}", color='red')
|
||||
|
||||
if route_node_state.node_run_result:
|
||||
node_run_result = route_node_state.node_run_result
|
||||
self.print_text(f"Error: {node_run_result.error}", color='red')
|
||||
self.print_text(f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}",
|
||||
color='red')
|
||||
self.print_text(
|
||||
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
|
||||
color='red')
|
||||
self.print_text(f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
|
||||
color='red')
|
||||
|
||||
def on_node_text_chunk(
|
||||
self,
|
||||
event: NodeRunStreamChunkEvent
|
||||
) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
if not self.current_node_id or self.current_node_id != node_id:
|
||||
self.current_node_id = node_id
|
||||
self.print_text('\n[on_node_text_chunk]')
|
||||
self.print_text(f"Node ID: {node_id}")
|
||||
self.print_text(f"Metadata: {jsonable_encoder(metadata) if metadata else ''}")
|
||||
route_node_state = event.route_node_state
|
||||
if not self.current_node_id or self.current_node_id != route_node_state.node_id:
|
||||
self.current_node_id = route_node_state.node_id
|
||||
self.print_text('\n[NodeRunStreamChunkEvent]')
|
||||
self.print_text(f"Node ID: {route_node_state.node_id}")
|
||||
|
||||
self.print_text(text, color="pink", end="")
|
||||
node_run_result = route_node_state.node_run_result
|
||||
if node_run_result:
|
||||
self.print_text(
|
||||
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}")
|
||||
|
||||
def on_workflow_iteration_started(self,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int = 1,
|
||||
node_data: Optional[BaseNodeData] = None,
|
||||
inputs: dict = None,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None) -> None:
|
||||
self.print_text(event.chunk_content, color="pink", end="")
|
||||
|
||||
def on_workflow_parallel_started(
|
||||
self,
|
||||
event: ParallelBranchRunStartedEvent
|
||||
) -> None:
|
||||
"""
|
||||
Publish parallel started
|
||||
"""
|
||||
self.print_text("\n[ParallelBranchRunStartedEvent]", color='blue')
|
||||
self.print_text(f"Parallel ID: {event.parallel_id}", color='blue')
|
||||
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color='blue')
|
||||
if event.in_iteration_id:
|
||||
self.print_text(f"Iteration ID: {event.in_iteration_id}", color='blue')
|
||||
|
||||
def on_workflow_parallel_completed(
|
||||
self,
|
||||
event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
|
||||
) -> None:
|
||||
"""
|
||||
Publish parallel completed
|
||||
"""
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
color = 'blue'
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
color = 'red'
|
||||
|
||||
self.print_text("\n[ParallelBranchRunSucceededEvent]" if isinstance(event, ParallelBranchRunSucceededEvent) else "\n[ParallelBranchRunFailedEvent]", color=color)
|
||||
self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
|
||||
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
|
||||
if event.in_iteration_id:
|
||||
self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
|
||||
|
||||
if isinstance(event, ParallelBranchRunFailedEvent):
|
||||
self.print_text(f"Error: {event.error}", color=color)
|
||||
|
||||
def on_workflow_iteration_started(
|
||||
self,
|
||||
event: IterationRunStartedEvent
|
||||
) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
self.print_text("\n[on_workflow_iteration_started]", color='blue')
|
||||
self.print_text(f"Node ID: {node_id}", color='blue')
|
||||
self.print_text("\n[IterationRunStartedEvent]", color='blue')
|
||||
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
|
||||
|
||||
def on_workflow_iteration_next(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
index: int,
|
||||
node_run_index: int,
|
||||
output: Optional[dict]) -> None:
|
||||
def on_workflow_iteration_next(
|
||||
self,
|
||||
event: IterationRunNextEvent
|
||||
) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
self.print_text("\n[on_workflow_iteration_next]", color='blue')
|
||||
self.print_text("\n[IterationRunNextEvent]", color='blue')
|
||||
self.print_text(f"Iteration Node ID: {event.iteration_id}", color='blue')
|
||||
self.print_text(f"Iteration Index: {event.index}", color='blue')
|
||||
|
||||
def on_workflow_iteration_completed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int,
|
||||
outputs: dict) -> None:
|
||||
def on_workflow_iteration_completed(
|
||||
self,
|
||||
event: IterationRunSucceededEvent | IterationRunFailedEvent
|
||||
) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
self.print_text("\n[on_workflow_iteration_completed]", color='blue')
|
||||
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event
|
||||
"""
|
||||
self.print_text("\n[on_workflow_event]", color='blue')
|
||||
self.print_text(f"Event: {jsonable_encoder(event)}", color='blue')
|
||||
self.print_text("\n[IterationRunSucceededEvent]" if isinstance(event, IterationRunSucceededEvent) else "\n[IterationRunFailedEvent]", color='blue')
|
||||
self.print_text(f"Node ID: {event.iteration_id}", color='blue')
|
||||
|
||||
def print_text(
|
||||
self, text: str, color: Optional[str] = None, end: str = "\n"
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
|
@ -5,7 +6,8 @@ from pydantic import BaseModel, field_validator
|
|||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
|
||||
|
||||
class QueueEvent(str, Enum):
|
||||
|
@ -31,6 +33,9 @@ class QueueEvent(str, Enum):
|
|||
ANNOTATION_REPLY = "annotation_reply"
|
||||
AGENT_THOUGHT = "agent_thought"
|
||||
MESSAGE_FILE = "message_file"
|
||||
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
|
||||
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
|
||||
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
|
||||
ERROR = "error"
|
||||
PING = "ping"
|
||||
STOP = "stop"
|
||||
|
@ -38,7 +43,7 @@ class QueueEvent(str, Enum):
|
|||
|
||||
class AppQueueEvent(BaseModel):
|
||||
"""
|
||||
QueueEvent entity
|
||||
QueueEvent abstract entity
|
||||
"""
|
||||
event: QueueEvent
|
||||
|
||||
|
@ -46,6 +51,7 @@ class AppQueueEvent(BaseModel):
|
|||
class QueueLLMChunkEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueLLMChunkEvent entity
|
||||
Only for basic mode apps
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.LLM_CHUNK
|
||||
chunk: LLMResultChunk
|
||||
|
@ -55,14 +61,24 @@ class QueueIterationStartEvent(AppQueueEvent):
|
|||
QueueIterationStartEvent entity
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.ITERATION_START
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
inputs: dict = None
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
predecessor_node_id: Optional[str] = None
|
||||
metadata: Optional[dict] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
class QueueIterationNextEvent(AppQueueEvent):
|
||||
"""
|
||||
|
@ -71,8 +87,18 @@ class QueueIterationNextEvent(AppQueueEvent):
|
|||
event: QueueEvent = QueueEvent.ITERATION_NEXT
|
||||
|
||||
index: int
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
|
||||
node_run_index: int
|
||||
output: Optional[Any] = None # output for the current iteration
|
||||
|
@ -93,13 +119,30 @@ class QueueIterationCompletedEvent(AppQueueEvent):
|
|||
"""
|
||||
QueueIterationCompletedEvent entity
|
||||
"""
|
||||
event:QueueEvent = QueueEvent.ITERATION_COMPLETED
|
||||
event: QueueEvent = QueueEvent.ITERATION_COMPLETED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
outputs: dict
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
steps: int = 0
|
||||
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class QueueTextChunkEvent(AppQueueEvent):
|
||||
"""
|
||||
|
@ -107,7 +150,10 @@ class QueueTextChunkEvent(AppQueueEvent):
|
|||
"""
|
||||
event: QueueEvent = QueueEvent.TEXT_CHUNK
|
||||
text: str
|
||||
metadata: Optional[dict] = None
|
||||
from_variable_selector: Optional[list[str]] = None
|
||||
"""from variable selector"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
|
||||
|
||||
class QueueAgentMessageEvent(AppQueueEvent):
|
||||
|
@ -132,6 +178,8 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
|
|||
"""
|
||||
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
|
||||
retriever_resources: list[dict]
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
|
||||
|
||||
class QueueAnnotationReplyEvent(AppQueueEvent):
|
||||
|
@ -162,6 +210,7 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
|
|||
QueueWorkflowStartedEvent entity
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
|
||||
graph_runtime_state: GraphRuntimeState
|
||||
|
||||
|
||||
class QueueWorkflowSucceededEvent(AppQueueEvent):
|
||||
|
@ -169,6 +218,7 @@ class QueueWorkflowSucceededEvent(AppQueueEvent):
|
|||
QueueWorkflowSucceededEvent entity
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class QueueWorkflowFailedEvent(AppQueueEvent):
|
||||
|
@ -185,11 +235,23 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
|||
"""
|
||||
event: QueueEvent = QueueEvent.NODE_STARTED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
node_run_index: int = 1
|
||||
predecessor_node_id: Optional[str] = None
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
start_at: datetime
|
||||
|
||||
|
||||
class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
|
@ -198,14 +260,26 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
|||
"""
|
||||
event: QueueEvent = QueueEvent.NODE_SUCCEEDED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[dict] = None
|
||||
process_data: Optional[dict] = None
|
||||
outputs: Optional[dict] = None
|
||||
execution_metadata: Optional[dict] = None
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
process_data: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
|
||||
|
||||
error: Optional[str] = None
|
||||
|
||||
|
@ -216,13 +290,25 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
|||
"""
|
||||
event: QueueEvent = QueueEvent.NODE_FAILED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[dict] = None
|
||||
outputs: Optional[dict] = None
|
||||
process_data: Optional[dict] = None
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
process_data: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
@ -274,10 +360,23 @@ class QueueStopEvent(AppQueueEvent):
|
|||
event: QueueEvent = QueueEvent.STOP
|
||||
stopped_by: StopBy
|
||||
|
||||
def get_stop_reason(self) -> str:
|
||||
"""
|
||||
To stop reason
|
||||
"""
|
||||
reason_mapping = {
|
||||
QueueStopEvent.StopBy.USER_MANUAL: 'Stopped by user.',
|
||||
QueueStopEvent.StopBy.ANNOTATION_REPLY: 'Stopped by annotation reply.',
|
||||
QueueStopEvent.StopBy.OUTPUT_MODERATION: 'Stopped by output moderation.',
|
||||
QueueStopEvent.StopBy.INPUT_MODERATION: 'Stopped by input moderation.'
|
||||
}
|
||||
|
||||
return reason_mapping.get(self.stopped_by, 'Stopped by unknown reason.')
|
||||
|
||||
|
||||
class QueueMessage(BaseModel):
|
||||
"""
|
||||
QueueMessage entity
|
||||
QueueMessage abstract entity
|
||||
"""
|
||||
task_id: str
|
||||
app_mode: str
|
||||
|
@ -297,3 +396,52 @@ class WorkflowQueueMessage(QueueMessage):
|
|||
WorkflowQueueMessage entity
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class QueueParallelBranchRunStartedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueParallelBranchRunStartedEvent entity
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
|
||||
|
||||
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueParallelBranchRunSucceededEvent entity
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
|
||||
|
||||
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueParallelBranchRunFailedEvent entity
|
||||
"""
|
||||
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
|
||||
|
||||
parallel_id: str
|
||||
parallel_start_node_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
error: str
|
||||
|
|
|
@ -3,40 +3,11 @@ from typing import Any, Optional
|
|||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.answer.entities import GenerateRouteChunk
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class WorkflowStreamGenerateNodes(BaseModel):
|
||||
"""
|
||||
WorkflowStreamGenerateNodes entity
|
||||
"""
|
||||
end_node_id: str
|
||||
stream_node_ids: list[str]
|
||||
|
||||
|
||||
class ChatflowStreamGenerateRoute(BaseModel):
|
||||
"""
|
||||
ChatflowStreamGenerateRoute entity
|
||||
"""
|
||||
answer_node_id: str
|
||||
generate_route: list[GenerateRouteChunk]
|
||||
current_route_position: int = 0
|
||||
|
||||
|
||||
class NodeExecutionInfo(BaseModel):
|
||||
"""
|
||||
NodeExecutionInfo entity
|
||||
"""
|
||||
workflow_node_execution_id: str
|
||||
node_type: NodeType
|
||||
start_at: float
|
||||
|
||||
|
||||
class TaskState(BaseModel):
|
||||
"""
|
||||
TaskState entity
|
||||
|
@ -57,27 +28,6 @@ class WorkflowTaskState(TaskState):
|
|||
"""
|
||||
answer: str = ""
|
||||
|
||||
workflow_run_id: Optional[str] = None
|
||||
start_at: Optional[float] = None
|
||||
total_tokens: int = 0
|
||||
total_steps: int = 0
|
||||
|
||||
ran_node_execution_infos: dict[str, NodeExecutionInfo] = {}
|
||||
latest_node_execution_info: Optional[NodeExecutionInfo] = None
|
||||
|
||||
current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None
|
||||
|
||||
iteration_nested_node_ids: list[str] = None
|
||||
|
||||
|
||||
class AdvancedChatTaskState(WorkflowTaskState):
|
||||
"""
|
||||
AdvancedChatTaskState entity
|
||||
"""
|
||||
usage: LLMUsage
|
||||
|
||||
current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None
|
||||
|
||||
|
||||
class StreamEvent(Enum):
|
||||
"""
|
||||
|
@ -97,6 +47,8 @@ class StreamEvent(Enum):
|
|||
WORKFLOW_FINISHED = "workflow_finished"
|
||||
NODE_STARTED = "node_started"
|
||||
NODE_FINISHED = "node_finished"
|
||||
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
|
||||
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
|
||||
ITERATION_STARTED = "iteration_started"
|
||||
ITERATION_NEXT = "iteration_next"
|
||||
ITERATION_COMPLETED = "iteration_completed"
|
||||
|
@ -267,6 +219,11 @@ class NodeStartStreamResponse(StreamResponse):
|
|||
inputs: Optional[dict] = None
|
||||
created_at: int
|
||||
extras: dict = {}
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_STARTED
|
||||
workflow_run_id: str
|
||||
|
@ -286,7 +243,12 @@ class NodeStartStreamResponse(StreamResponse):
|
|||
"predecessor_node_id": self.data.predecessor_node_id,
|
||||
"inputs": None,
|
||||
"created_at": self.data.created_at,
|
||||
"extras": {}
|
||||
"extras": {},
|
||||
"parallel_id": self.data.parallel_id,
|
||||
"parallel_start_node_id": self.data.parallel_start_node_id,
|
||||
"parent_parallel_id": self.data.parent_parallel_id,
|
||||
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
|
||||
"iteration_id": self.data.iteration_id,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -316,6 +278,11 @@ class NodeFinishStreamResponse(StreamResponse):
|
|||
created_at: int
|
||||
finished_at: int
|
||||
files: Optional[list[dict]] = []
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_FINISHED
|
||||
workflow_run_id: str
|
||||
|
@ -342,9 +309,58 @@ class NodeFinishStreamResponse(StreamResponse):
|
|||
"execution_metadata": None,
|
||||
"created_at": self.data.created_at,
|
||||
"finished_at": self.data.finished_at,
|
||||
"files": []
|
||||
"files": [],
|
||||
"parallel_id": self.data.parallel_id,
|
||||
"parallel_start_node_id": self.data.parallel_start_node_id,
|
||||
"parent_parallel_id": self.data.parent_parallel_id,
|
||||
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
|
||||
"iteration_id": self.data.iteration_id,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ParallelBranchStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
ParallelBranchStartStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
parallel_id: str
|
||||
parallel_branch_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
created_at: int
|
||||
|
||||
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class ParallelBranchFinishedStreamResponse(StreamResponse):
|
||||
"""
|
||||
ParallelBranchFinishedStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
parallel_id: str
|
||||
parallel_branch_id: str
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
status: str
|
||||
error: Optional[str] = None
|
||||
created_at: int
|
||||
|
||||
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class IterationNodeStartStreamResponse(StreamResponse):
|
||||
|
@ -364,6 +380,8 @@ class IterationNodeStartStreamResponse(StreamResponse):
|
|||
extras: dict = {}
|
||||
metadata: dict = {}
|
||||
inputs: dict = {}
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_STARTED
|
||||
workflow_run_id: str
|
||||
|
@ -387,6 +405,8 @@ class IterationNodeNextStreamResponse(StreamResponse):
|
|||
created_at: int
|
||||
pre_iteration_output: Optional[Any] = None
|
||||
extras: dict = {}
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_NEXT
|
||||
workflow_run_id: str
|
||||
|
@ -408,8 +428,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
|
|||
title: str
|
||||
outputs: Optional[dict] = None
|
||||
created_at: int
|
||||
extras: dict = None
|
||||
inputs: dict = None
|
||||
extras: Optional[dict] = None
|
||||
inputs: Optional[dict] = None
|
||||
status: WorkflowNodeExecutionStatus
|
||||
error: Optional[str] = None
|
||||
elapsed_time: float
|
||||
|
@ -417,6 +437,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
|
|||
execution_metadata: Optional[dict] = None
|
||||
finished_at: int
|
||||
steps: int
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
|
||||
workflow_run_id: str
|
||||
|
@ -488,7 +510,7 @@ class WorkflowAppStreamResponse(AppStreamResponse):
|
|||
"""
|
||||
WorkflowAppStreamResponse entity
|
||||
"""
|
||||
workflow_run_id: str
|
||||
workflow_run_id: Optional[str] = None
|
||||
|
||||
|
||||
class AppBlockingResponse(BaseModel):
|
||||
|
@ -562,25 +584,3 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
|||
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class WorkflowIterationState(BaseModel):
|
||||
"""
|
||||
WorkflowIterationState entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
parent_iteration_id: Optional[str] = None
|
||||
iteration_id: str
|
||||
current_index: int
|
||||
iteration_steps_boundary: list[int] = None
|
||||
node_execution_id: str
|
||||
started_at: float
|
||||
inputs: dict = None
|
||||
total_tokens: int = 0
|
||||
node_data: BaseNodeData
|
||||
|
||||
current_iterations: dict[str, Data] = None
|
||||
|
|
|
@ -68,16 +68,18 @@ class BasedGenerateTaskPipeline:
|
|||
err = Exception(e.description if getattr(e, 'description', None) is not None else str(e))
|
||||
|
||||
if message:
|
||||
message = db.session.query(Message).filter(Message.id == message.id).first()
|
||||
err_desc = self._error_to_desc(err)
|
||||
message.status = 'error'
|
||||
message.error = err_desc
|
||||
refetch_message = db.session.query(Message).filter(Message.id == message.id).first()
|
||||
|
||||
db.session.commit()
|
||||
if refetch_message:
|
||||
err_desc = self._error_to_desc(err)
|
||||
refetch_message.status = 'error'
|
||||
refetch_message.error = err_desc
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return err
|
||||
|
||||
def _error_to_desc(cls, e: Exception) -> str:
|
||||
def _error_to_desc(self, e: Exception) -> str:
|
||||
"""
|
||||
Error to desc.
|
||||
:param e: exception
|
||||
|
|
|
@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import (
|
|||
AgentChatAppGenerateEntity,
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAnnotationReplyEvent,
|
||||
|
@ -16,11 +15,11 @@ from core.app.entities.queue_entities import (
|
|||
QueueRetrieverResourcesEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatTaskState,
|
||||
EasyUITaskState,
|
||||
MessageFileStreamResponse,
|
||||
MessageReplaceStreamResponse,
|
||||
MessageStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
@ -36,7 +35,7 @@ class MessageCycleManage:
|
|||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity
|
||||
]
|
||||
_task_state: Union[EasyUITaskState, AdvancedChatTaskState]
|
||||
_task_state: Union[EasyUITaskState, WorkflowTaskState]
|
||||
|
||||
def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]:
|
||||
"""
|
||||
|
@ -45,6 +44,9 @@ class MessageCycleManage:
|
|||
:param query: query
|
||||
:return: thread
|
||||
"""
|
||||
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
|
||||
return None
|
||||
|
||||
is_first_message = self._application_generate_entity.conversation_id is None
|
||||
extras = self._application_generate_entity.extras
|
||||
auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True)
|
||||
|
@ -52,7 +54,7 @@ class MessageCycleManage:
|
|||
if auto_generate_conversation_name and is_first_message:
|
||||
# start generate thread
|
||||
thread = Thread(target=self._generate_conversation_name_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'flask_app': current_app._get_current_object(), # type: ignore
|
||||
'conversation_id': conversation.id,
|
||||
'query': query
|
||||
})
|
||||
|
@ -75,6 +77,9 @@ class MessageCycleManage:
|
|||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
return
|
||||
|
||||
if conversation.mode != AppMode.COMPLETION.value:
|
||||
app_model = conversation.app
|
||||
if not app_model:
|
||||
|
@ -121,34 +126,13 @@ class MessageCycleManage:
|
|||
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
|
||||
self._task_state.metadata['retriever_resources'] = event.retriever_resources
|
||||
|
||||
def _get_response_metadata(self) -> dict:
|
||||
"""
|
||||
Get response metadata by invoke from.
|
||||
:return:
|
||||
"""
|
||||
metadata = {}
|
||||
|
||||
# show_retrieve_source
|
||||
if 'retriever_resources' in self._task_state.metadata:
|
||||
metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
|
||||
|
||||
# show annotation reply
|
||||
if 'annotation_reply' in self._task_state.metadata:
|
||||
metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']
|
||||
|
||||
# show usage
|
||||
if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
|
||||
metadata['usage'] = self._task_state.metadata['usage']
|
||||
|
||||
return metadata
|
||||
|
||||
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
|
||||
"""
|
||||
Message file to stream response.
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
message_file: MessageFile = (
|
||||
message_file = (
|
||||
db.session.query(MessageFile)
|
||||
.filter(MessageFile.id == event.message_file_id)
|
||||
.first()
|
||||
|
|
|
@ -1,33 +1,41 @@
|
|||
import json
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union, cast
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueStopEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
NodeExecutionInfo,
|
||||
IterationNodeCompletedStreamResponse,
|
||||
IterationNodeNextStreamResponse,
|
||||
IterationNodeStartStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
ParallelBranchFinishedStreamResponse,
|
||||
ParallelBranchStartStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage
|
||||
from core.file.file_obj import FileVar
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
|
@ -41,54 +49,56 @@ from models.workflow import (
|
|||
WorkflowRunStatus,
|
||||
WorkflowRunTriggeredFrom,
|
||||
)
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class WorkflowCycleManage(WorkflowIterationCycleManage):
|
||||
def _init_workflow_run(self, workflow: Workflow,
|
||||
triggered_from: WorkflowRunTriggeredFrom,
|
||||
user: Union[Account, EndUser],
|
||||
user_inputs: dict,
|
||||
system_inputs: Optional[dict] = None) -> WorkflowRun:
|
||||
"""
|
||||
Init workflow run
|
||||
:param workflow: Workflow instance
|
||||
:param triggered_from: triggered from
|
||||
:param user: account or end user
|
||||
:param user_inputs: user variables inputs
|
||||
:param system_inputs: system inputs, like: query, files
|
||||
:return:
|
||||
"""
|
||||
max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \
|
||||
.filter(WorkflowRun.tenant_id == workflow.tenant_id) \
|
||||
.filter(WorkflowRun.app_id == workflow.app_id) \
|
||||
.scalar() or 0
|
||||
class WorkflowCycleManage:
|
||||
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
_task_state: WorkflowTaskState
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
|
||||
def _handle_workflow_run_start(self) -> WorkflowRun:
|
||||
max_sequence = (
|
||||
db.session.query(db.func.max(WorkflowRun.sequence_number))
|
||||
.filter(WorkflowRun.tenant_id == self._workflow.tenant_id)
|
||||
.filter(WorkflowRun.app_id == self._workflow.app_id)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
new_sequence_number = max_sequence + 1
|
||||
|
||||
inputs = {**user_inputs}
|
||||
for key, value in (system_inputs or {}).items():
|
||||
inputs = {**self._application_generate_entity.inputs}
|
||||
for key, value in (self._workflow_system_variables or {}).items():
|
||||
if key.value == 'conversation':
|
||||
continue
|
||||
|
||||
inputs[f'sys.{key.value}'] = value
|
||||
inputs = WorkflowEngineManager.handle_special_values(inputs)
|
||||
|
||||
inputs = WorkflowEntry.handle_special_values(inputs)
|
||||
|
||||
triggered_from= (
|
||||
WorkflowRunTriggeredFrom.DEBUGGING
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
|
||||
else WorkflowRunTriggeredFrom.APP_RUN
|
||||
)
|
||||
|
||||
# init workflow run
|
||||
workflow_run = WorkflowRun(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
sequence_number=new_sequence_number,
|
||||
workflow_id=workflow.id,
|
||||
type=workflow.type,
|
||||
triggered_from=triggered_from.value,
|
||||
version=workflow.version,
|
||||
graph=workflow.graph,
|
||||
inputs=json.dumps(inputs),
|
||||
status=WorkflowRunStatus.RUNNING.value,
|
||||
created_by_role=(CreatedByRole.ACCOUNT.value
|
||||
if isinstance(user, Account) else CreatedByRole.END_USER.value),
|
||||
created_by=user.id
|
||||
workflow_run = WorkflowRun()
|
||||
workflow_run.tenant_id = self._workflow.tenant_id
|
||||
workflow_run.app_id = self._workflow.app_id
|
||||
workflow_run.sequence_number = new_sequence_number
|
||||
workflow_run.workflow_id = self._workflow.id
|
||||
workflow_run.type = self._workflow.type
|
||||
workflow_run.triggered_from = triggered_from.value
|
||||
workflow_run.version = self._workflow.version
|
||||
workflow_run.graph = self._workflow.graph
|
||||
workflow_run.inputs = json.dumps(inputs)
|
||||
workflow_run.status = WorkflowRunStatus.RUNNING.value
|
||||
workflow_run.created_by_role = (
|
||||
CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value
|
||||
)
|
||||
workflow_run.created_by = self._user.id
|
||||
|
||||
db.session.add(workflow_run)
|
||||
db.session.commit()
|
||||
|
@ -97,33 +107,37 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
|||
|
||||
return workflow_run
|
||||
|
||||
def _workflow_run_success(
|
||||
self, workflow_run: WorkflowRun,
|
||||
def _handle_workflow_run_success(
|
||||
self,
|
||||
workflow_run: WorkflowRun,
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> WorkflowRun:
|
||||
"""
|
||||
Workflow run success
|
||||
:param workflow_run: workflow run
|
||||
:param start_at: start time
|
||||
:param total_tokens: total tokens
|
||||
:param total_steps: total steps
|
||||
:param outputs: outputs
|
||||
:param conversation_id: conversation id
|
||||
:return:
|
||||
"""
|
||||
workflow_run = self._refetch_workflow_run(workflow_run.id)
|
||||
|
||||
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
|
||||
workflow_run.outputs = outputs
|
||||
workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id)
|
||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_run)
|
||||
db.session.close()
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
|
@ -135,34 +149,58 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
|||
)
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _workflow_run_failed(
|
||||
self, workflow_run: WorkflowRun,
|
||||
def _handle_workflow_run_failed(
|
||||
self,
|
||||
workflow_run: WorkflowRun,
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
status: WorkflowRunStatus,
|
||||
error: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> WorkflowRun:
|
||||
"""
|
||||
Workflow run failed
|
||||
:param workflow_run: workflow run
|
||||
:param start_at: start time
|
||||
:param total_tokens: total tokens
|
||||
:param total_steps: total steps
|
||||
:param status: status
|
||||
:param error: error message
|
||||
:return:
|
||||
"""
|
||||
workflow_run = self._refetch_workflow_run(workflow_run.id)
|
||||
|
||||
workflow_run.status = status.value
|
||||
workflow_run.error = error
|
||||
workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id)
|
||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
running_workflow_node_executions = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
|
||||
WorkflowNodeExecution.app_id == workflow_run.app_id,
|
||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
|
||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value
|
||||
).all()
|
||||
|
||||
for workflow_node_execution in running_workflow_node_executions:
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - workflow_node_execution.created_at).total_seconds()
|
||||
db.session.commit()
|
||||
|
||||
db.session.refresh(workflow_run)
|
||||
db.session.close()
|
||||
|
||||
|
@ -178,39 +216,24 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
|||
|
||||
return workflow_run
|
||||
|
||||
def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_title: str,
|
||||
node_run_index: int = 1,
|
||||
predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Init workflow node execution from workflow run
|
||||
:param workflow_run: workflow run
|
||||
:param node_id: node id
|
||||
:param node_type: node type
|
||||
:param node_title: node title
|
||||
:param node_run_index: run index
|
||||
:param predecessor_node_id: predecessor node id if exists
|
||||
:return:
|
||||
"""
|
||||
def _handle_node_execution_start(self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
|
||||
# init workflow node execution
|
||||
workflow_node_execution = WorkflowNodeExecution(
|
||||
tenant_id=workflow_run.tenant_id,
|
||||
app_id=workflow_run.app_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
workflow_run_id=workflow_run.id,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
index=node_run_index,
|
||||
node_id=node_id,
|
||||
node_type=node_type.value,
|
||||
title=node_title,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
created_by_role=workflow_run.created_by_role,
|
||||
created_by=workflow_run.created_by,
|
||||
created_at=datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
)
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
||||
workflow_node_execution.app_id = workflow_run.app_id
|
||||
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
workflow_node_execution.workflow_run_id = workflow_run.id
|
||||
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
|
||||
workflow_node_execution.index = event.node_run_index
|
||||
workflow_node_execution.node_execution_id = event.node_execution_id
|
||||
workflow_node_execution.node_id = event.node_id
|
||||
workflow_node_execution.node_type = event.node_type.value
|
||||
workflow_node_execution.title = event.node_data.title
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
|
||||
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
||||
workflow_node_execution.created_by = workflow_run.created_by
|
||||
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
|
@ -219,33 +242,26 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
|||
|
||||
return workflow_node_execution
|
||||
|
||||
def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution,
|
||||
start_at: float,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution:
|
||||
def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution success
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:param start_at: start time
|
||||
:param inputs: inputs
|
||||
:param process_data: process data
|
||||
:param outputs: outputs
|
||||
:param execution_metadata: execution metadata
|
||||
:param event: queue node succeeded event
|
||||
:return:
|
||||
"""
|
||||
inputs = WorkflowEngineManager.handle_special_values(inputs)
|
||||
outputs = WorkflowEngineManager.handle_special_values(outputs)
|
||||
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
|
||||
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
|
||||
if execution_metadata else None
|
||||
workflow_node_execution.execution_metadata = (
|
||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
||||
)
|
||||
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
|
@ -253,33 +269,24 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
|||
|
||||
return workflow_node_execution
|
||||
|
||||
def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution,
|
||||
start_at: float,
|
||||
error: str,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None
|
||||
) -> WorkflowNodeExecution:
|
||||
def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:param start_at: start time
|
||||
:param error: error message
|
||||
:param event: queue node failed event
|
||||
:return:
|
||||
"""
|
||||
inputs = WorkflowEngineManager.handle_special_values(inputs)
|
||||
outputs = WorkflowEngineManager.handle_special_values(outputs)
|
||||
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
|
||||
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_node_execution.error = event.error
|
||||
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
|
||||
if execution_metadata else None
|
||||
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
|
@ -287,8 +294,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
|||
|
||||
return workflow_node_execution
|
||||
|
||||
def _workflow_start_to_stream_response(self, task_id: str,
|
||||
workflow_run: WorkflowRun) -> WorkflowStartStreamResponse:
|
||||
#################################################
|
||||
# to stream responses #
|
||||
#################################################
|
||||
|
||||
def _workflow_start_to_stream_response(
|
||||
self, task_id: str, workflow_run: WorkflowRun
|
||||
) -> WorkflowStartStreamResponse:
|
||||
"""
|
||||
Workflow start to stream response.
|
||||
:param task_id: task id
|
||||
|
@ -302,13 +314,14 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
|||
id=workflow_run.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
sequence_number=workflow_run.sequence_number,
|
||||
inputs=workflow_run.inputs_dict,
|
||||
created_at=int(workflow_run.created_at.timestamp())
|
||||
)
|
||||
inputs=workflow_run.inputs_dict or {},
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_finish_to_stream_response(self, task_id: str,
|
||||
workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse:
|
||||
def _workflow_finish_to_stream_response(
|
||||
self, task_id: str, workflow_run: WorkflowRun
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
"""
|
||||
Workflow finish to stream response.
|
||||
:param task_id: task id
|
||||
|
@ -320,16 +333,16 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
|||
created_by_account = workflow_run.created_by_account
|
||||
if created_by_account:
|
||||
created_by = {
|
||||
"id": created_by_account.id,
|
||||
"name": created_by_account.name,
|
||||
"email": created_by_account.email,
|
||||
'id': created_by_account.id,
|
||||
'name': created_by_account.name,
|
||||
'email': created_by_account.email,
|
||||
}
|
||||
else:
|
||||
created_by_end_user = workflow_run.created_by_end_user
|
||||
if created_by_end_user:
|
||||
created_by = {
|
||||
"id": created_by_end_user.id,
|
||||
"user": created_by_end_user.session_id,
|
||||
'id': created_by_end_user.id,
|
||||
'user': created_by_end_user.session_id,
|
||||
}
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
|
@ -348,14 +361,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
|||
created_by=created_by,
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
finished_at=int(workflow_run.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict)
|
||||
)
|
||||
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}),
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution) \
|
||||
-> NodeStartStreamResponse:
|
||||
def _workflow_node_start_to_stream_response(
|
||||
self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution
|
||||
) -> Optional[NodeStartStreamResponse]:
|
||||
"""
|
||||
Workflow node start to stream response.
|
||||
:param event: queue node started event
|
||||
|
@ -363,6 +375,9 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
|||
:param workflow_node_execution: workflow node execution
|
||||
:return:
|
||||
"""
|
||||
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
|
||||
return None
|
||||
|
||||
response = NodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
|
@ -374,8 +389,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
|||
index=workflow_node_execution.index,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs_dict,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp())
|
||||
)
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
),
|
||||
)
|
||||
|
||||
# extras logic
|
||||
|
@ -384,19 +404,27 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
|||
response.data.extras['icon'] = ToolManager.get_tool_icon(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
provider_type=node_data.provider_type,
|
||||
provider_id=node_data.provider_id
|
||||
provider_id=node_data.provider_id,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \
|
||||
-> NodeFinishStreamResponse:
|
||||
def _workflow_node_finish_to_stream_response(
|
||||
self,
|
||||
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution
|
||||
) -> Optional[NodeFinishStreamResponse]:
|
||||
"""
|
||||
Workflow node finish to stream response.
|
||||
:param event: queue node succeeded or failed event
|
||||
:param task_id: task id
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:return:
|
||||
"""
|
||||
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
|
||||
return None
|
||||
|
||||
return NodeFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
|
@ -416,181 +444,155 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
|||
execution_metadata=workflow_node_execution.execution_metadata_dict,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict)
|
||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_parallel_branch_start_to_stream_response(
|
||||
self,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
event: QueueParallelBranchRunStartedEvent
|
||||
) -> ParallelBranchStartStreamResponse:
|
||||
"""
|
||||
Workflow parallel branch start to stream response
|
||||
:param task_id: task id
|
||||
:param workflow_run: workflow run
|
||||
:param event: parallel branch run started event
|
||||
:return:
|
||||
"""
|
||||
return ParallelBranchStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=ParallelBranchStartStreamResponse.Data(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_branch_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
created_at=int(time.time()),
|
||||
)
|
||||
)
|
||||
|
||||
def _workflow_parallel_branch_finished_to_stream_response(
|
||||
self,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent
|
||||
) -> ParallelBranchFinishedStreamResponse:
|
||||
"""
|
||||
Workflow parallel branch finished to stream response
|
||||
:param task_id: task id
|
||||
:param workflow_run: workflow run
|
||||
:param event: parallel branch run succeeded or failed event
|
||||
:return:
|
||||
"""
|
||||
return ParallelBranchFinishedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=ParallelBranchFinishedStreamResponse.Data(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_branch_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
status='succeeded' if isinstance(event, QueueParallelBranchRunSucceededEvent) else 'failed',
|
||||
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
|
||||
created_at=int(time.time()),
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_workflow_start(self) -> WorkflowRun:
|
||||
self._task_state.start_at = time.perf_counter()
|
||||
|
||||
workflow_run = self._init_workflow_run(
|
||||
workflow=self._workflow,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
|
||||
else WorkflowRunTriggeredFrom.APP_RUN,
|
||||
user=self._user,
|
||||
user_inputs=self._application_generate_entity.inputs,
|
||||
system_inputs=self._workflow_system_variables
|
||||
def _workflow_iteration_start_to_stream_response(
|
||||
self,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
event: QueueIterationStartEvent
|
||||
) -> IterationNodeStartStreamResponse:
|
||||
"""
|
||||
Workflow iteration start to stream response
|
||||
:param task_id: task id
|
||||
:param workflow_run: workflow run
|
||||
:param event: iteration start event
|
||||
:return:
|
||||
"""
|
||||
return IterationNodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=IterationNodeStartStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
metadata=event.metadata or {},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
)
|
||||
)
|
||||
|
||||
self._task_state.workflow_run_id = workflow_run.id
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
|
||||
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
|
||||
workflow_node_execution = self._init_node_execution_from_workflow_run(
|
||||
workflow_run=workflow_run,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_data.title,
|
||||
node_run_index=event.node_run_index,
|
||||
predecessor_node_id=event.predecessor_node_id
|
||||
def _workflow_iteration_next_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent) -> IterationNodeNextStreamResponse:
|
||||
"""
|
||||
Workflow iteration next to stream response
|
||||
:param task_id: task id
|
||||
:param workflow_run: workflow run
|
||||
:param event: iteration next event
|
||||
:return:
|
||||
"""
|
||||
return IterationNodeNextStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=IterationNodeNextStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
index=event.index,
|
||||
pre_iteration_output=event.output,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
)
|
||||
)
|
||||
|
||||
latest_node_execution_info = NodeExecutionInfo(
|
||||
workflow_node_execution_id=workflow_node_execution.id,
|
||||
node_type=event.node_type,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
|
||||
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
|
||||
self._task_state.latest_node_execution_info = latest_node_execution_info
|
||||
|
||||
self._task_state.total_steps += 1
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
|
||||
current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
|
||||
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
|
||||
|
||||
execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None
|
||||
|
||||
if self._iteration_state and self._iteration_state.current_iterations:
|
||||
if not execution_metadata:
|
||||
execution_metadata = {}
|
||||
current_iteration_data = None
|
||||
for iteration_node_id in self._iteration_state.current_iterations:
|
||||
data = self._iteration_state.current_iterations[iteration_node_id]
|
||||
if data.parent_iteration_id == None:
|
||||
current_iteration_data = data
|
||||
break
|
||||
|
||||
if current_iteration_data:
|
||||
execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id
|
||||
execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index
|
||||
|
||||
if isinstance(event, QueueNodeSucceededEvent):
|
||||
workflow_node_execution = self._workflow_node_execution_success(
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
start_at=current_node_execution.start_at,
|
||||
inputs=event.inputs,
|
||||
process_data=event.process_data,
|
||||
def _workflow_iteration_completed_to_stream_response(self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent) -> IterationNodeCompletedStreamResponse:
|
||||
"""
|
||||
Workflow iteration completed to stream response
|
||||
:param task_id: task id
|
||||
:param workflow_run: workflow run
|
||||
:param event: iteration completed event
|
||||
:return:
|
||||
"""
|
||||
return IterationNodeCompletedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=IterationNodeCompletedStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
outputs=event.outputs,
|
||||
execution_metadata=execution_metadata
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
error=None,
|
||||
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
|
||||
total_tokens=event.metadata.get('total_tokens', 0) if event.metadata else 0,
|
||||
execution_metadata=event.metadata,
|
||||
finished_at=int(time.time()),
|
||||
steps=event.steps,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
)
|
||||
|
||||
if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
self._task_state.total_tokens += (
|
||||
int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
|
||||
|
||||
if self._iteration_state:
|
||||
for iteration_node_id in self._iteration_state.current_iterations:
|
||||
data = self._iteration_state.current_iterations[iteration_node_id]
|
||||
if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))
|
||||
|
||||
if workflow_node_execution.node_type == NodeType.LLM.value:
|
||||
outputs = workflow_node_execution.outputs_dict
|
||||
usage_dict = outputs.get('usage', {})
|
||||
self._task_state.metadata['usage'] = usage_dict
|
||||
else:
|
||||
workflow_node_execution = self._workflow_node_execution_failed(
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
start_at=current_node_execution.start_at,
|
||||
error=event.error,
|
||||
inputs=event.inputs,
|
||||
process_data=event.process_data,
|
||||
outputs=event.outputs,
|
||||
execution_metadata=execution_metadata
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_finished(
|
||||
self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
) -> Optional[WorkflowRun]:
|
||||
workflow_run = db.session.query(WorkflowRun).filter(
|
||||
WorkflowRun.id == self._task_state.workflow_run_id).first()
|
||||
if not workflow_run:
|
||||
return None
|
||||
|
||||
if conversation_id is None:
|
||||
conversation_id = self._application_generate_entity.inputs.get('sys.conversation_id')
|
||||
if isinstance(event, QueueStopEvent):
|
||||
workflow_run = self._workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
total_tokens=self._task_state.total_tokens,
|
||||
total_steps=self._task_state.total_steps,
|
||||
status=WorkflowRunStatus.STOPPED,
|
||||
error='Workflow stopped.',
|
||||
conversation_id=conversation_id,
|
||||
trace_manager=trace_manager
|
||||
)
|
||||
|
||||
latest_node_execution_info = self._task_state.latest_node_execution_info
|
||||
if latest_node_execution_info:
|
||||
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == latest_node_execution_info.workflow_node_execution_id).first()
|
||||
if (workflow_node_execution
|
||||
and workflow_node_execution.status == WorkflowNodeExecutionStatus.RUNNING.value):
|
||||
self._workflow_node_execution_failed(
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
start_at=latest_node_execution_info.start_at,
|
||||
error='Workflow stopped.'
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowFailedEvent):
|
||||
workflow_run = self._workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
total_tokens=self._task_state.total_tokens,
|
||||
total_steps=self._task_state.total_steps,
|
||||
status=WorkflowRunStatus.FAILED,
|
||||
error=event.error,
|
||||
conversation_id=conversation_id,
|
||||
trace_manager=trace_manager
|
||||
)
|
||||
else:
|
||||
if self._task_state.latest_node_execution_info:
|
||||
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first()
|
||||
outputs = workflow_node_execution.outputs
|
||||
else:
|
||||
outputs = None
|
||||
|
||||
workflow_run = self._workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
total_tokens=self._task_state.total_tokens,
|
||||
total_steps=self._task_state.total_steps,
|
||||
outputs=outputs,
|
||||
conversation_id=conversation_id,
|
||||
trace_manager=trace_manager
|
||||
)
|
||||
|
||||
self._task_state.workflow_run_id = workflow_run.id
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_run
|
||||
)
|
||||
|
||||
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
|
||||
"""
|
||||
|
@ -647,3 +649,40 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
|||
return value.to_dict()
|
||||
|
||||
return None
|
||||
|
||||
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
|
||||
"""
|
||||
Refetch workflow run
|
||||
:param workflow_run_id: workflow run id
|
||||
:return:
|
||||
"""
|
||||
workflow_run = db.session.query(WorkflowRun).filter(
|
||||
WorkflowRun.id == workflow_run_id).first()
|
||||
|
||||
if not workflow_run:
|
||||
raise Exception(f'Workflow run not found: {workflow_run_id}')
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Refetch workflow node execution
|
||||
:param node_execution_id: workflow node execution id
|
||||
:return:
|
||||
"""
|
||||
workflow_node_execution = (
|
||||
db.session.query(WorkflowNodeExecution)
|
||||
.filter(
|
||||
WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id,
|
||||
WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id,
|
||||
WorkflowNodeExecution.workflow_id == self._workflow.id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowNodeExecution.node_execution_id == node_execution_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not workflow_node_execution:
|
||||
raise Exception(f'Workflow node execution not found: {node_execution_id}')
|
||||
|
||||
return workflow_node_execution
|
|
@ -1,16 +0,0 @@
|
|||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowCycleStateManager:
|
||||
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
|
@ -1,290 +0,0 @@
|
|||
import json
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
IterationNodeCompletedStreamResponse,
|
||||
IterationNodeNextStreamResponse,
|
||||
IterationNodeStartStreamResponse,
|
||||
NodeExecutionInfo,
|
||||
WorkflowIterationState,
|
||||
)
|
||||
from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
WorkflowRun,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowIterationCycleManage(WorkflowCycleStateManager):
|
||||
_iteration_state: WorkflowIterationState = None
|
||||
|
||||
def _init_iteration_state(self) -> WorkflowIterationState:
|
||||
if not self._iteration_state:
|
||||
self._iteration_state = WorkflowIterationState(
|
||||
current_iterations={}
|
||||
)
|
||||
|
||||
def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \
|
||||
-> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]:
|
||||
"""
|
||||
Handle iteration to stream response
|
||||
:param task_id: task id
|
||||
:param event: iteration event
|
||||
:return:
|
||||
"""
|
||||
if isinstance(event, QueueIterationStartEvent):
|
||||
return IterationNodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=self._task_state.workflow_run_id,
|
||||
data=IterationNodeStartStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata
|
||||
)
|
||||
)
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
current_iteration = self._iteration_state.current_iterations[event.node_id]
|
||||
|
||||
return IterationNodeNextStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=self._task_state.workflow_run_id,
|
||||
data=IterationNodeNextStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=current_iteration.node_data.title,
|
||||
index=event.index,
|
||||
pre_iteration_output=event.output,
|
||||
created_at=int(time.time()),
|
||||
extras={}
|
||||
)
|
||||
)
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
current_iteration = self._iteration_state.current_iterations[event.node_id]
|
||||
|
||||
return IterationNodeCompletedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=self._task_state.workflow_run_id,
|
||||
data=IterationNodeCompletedStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=current_iteration.node_data.title,
|
||||
outputs=event.outputs,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=current_iteration.inputs,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
error=None,
|
||||
elapsed_time=time.perf_counter() - current_iteration.started_at,
|
||||
total_tokens=current_iteration.total_tokens,
|
||||
execution_metadata={
|
||||
'total_tokens': current_iteration.total_tokens,
|
||||
},
|
||||
finished_at=int(time.time()),
|
||||
steps=current_iteration.current_index
|
||||
)
|
||||
)
|
||||
|
||||
def _init_iteration_execution_from_workflow_run(self,
|
||||
workflow_run: WorkflowRun,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_title: str,
|
||||
node_run_index: int = 1,
|
||||
inputs: Optional[dict] = None,
|
||||
predecessor_node_id: Optional[str] = None
|
||||
) -> WorkflowNodeExecution:
|
||||
workflow_node_execution = WorkflowNodeExecution(
|
||||
tenant_id=workflow_run.tenant_id,
|
||||
app_id=workflow_run.app_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
workflow_run_id=workflow_run.id,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
index=node_run_index,
|
||||
node_id=node_id,
|
||||
node_type=node_type.value,
|
||||
inputs=json.dumps(inputs) if inputs else None,
|
||||
title=node_title,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
created_by_role=workflow_run.created_by_role,
|
||||
created_by=workflow_run.created_by,
|
||||
execution_metadata=json.dumps({
|
||||
'started_run_index': node_run_index + 1,
|
||||
'current_index': 0,
|
||||
'steps_boundary': [],
|
||||
}),
|
||||
created_at=datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
)
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution:
|
||||
if isinstance(event, QueueIterationStartEvent):
|
||||
return self._handle_iteration_started(event)
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
return self._handle_iteration_next(event)
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
return self._handle_iteration_completed(event)
|
||||
|
||||
def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution:
|
||||
self._init_iteration_state()
|
||||
|
||||
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
|
||||
workflow_node_execution = self._init_iteration_execution_from_workflow_run(
|
||||
workflow_run=workflow_run,
|
||||
node_id=event.node_id,
|
||||
node_type=NodeType.ITERATION,
|
||||
node_title=event.node_data.title,
|
||||
node_run_index=event.node_run_index,
|
||||
inputs=event.inputs,
|
||||
predecessor_node_id=event.predecessor_node_id
|
||||
)
|
||||
|
||||
latest_node_execution_info = NodeExecutionInfo(
|
||||
workflow_node_execution_id=workflow_node_execution.id,
|
||||
node_type=NodeType.ITERATION,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
|
||||
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
|
||||
self._task_state.latest_node_execution_info = latest_node_execution_info
|
||||
|
||||
self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data(
|
||||
parent_iteration_id=None,
|
||||
iteration_id=event.node_id,
|
||||
current_index=0,
|
||||
iteration_steps_boundary=[],
|
||||
node_execution_id=workflow_node_execution.id,
|
||||
started_at=time.perf_counter(),
|
||||
inputs=event.inputs,
|
||||
total_tokens=0,
|
||||
node_data=event.node_data
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution:
|
||||
if event.node_id not in self._iteration_state.current_iterations:
|
||||
return
|
||||
current_iteration = self._iteration_state.current_iterations[event.node_id]
|
||||
current_iteration.current_index = event.index
|
||||
current_iteration.iteration_steps_boundary.append(event.node_run_index)
|
||||
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == current_iteration.node_execution_id
|
||||
).first()
|
||||
|
||||
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
|
||||
if original_node_execution_metadata:
|
||||
original_node_execution_metadata['current_index'] = event.index
|
||||
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
|
||||
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
|
||||
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
db.session.close()
|
||||
|
||||
def _handle_iteration_completed(self, event: QueueIterationCompletedEvent):
|
||||
if event.node_id not in self._iteration_state.current_iterations:
|
||||
return
|
||||
|
||||
current_iteration = self._iteration_state.current_iterations[event.node_id]
|
||||
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == current_iteration.node_execution_id
|
||||
).first()
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
workflow_node_execution.outputs = json.dumps(WorkflowEngineManager.handle_special_values(event.outputs)) if event.outputs else None
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
|
||||
|
||||
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
|
||||
if original_node_execution_metadata:
|
||||
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
|
||||
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
|
||||
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# remove current iteration
|
||||
self._iteration_state.current_iterations.pop(event.node_id, None)
|
||||
|
||||
# set latest node execution info
|
||||
latest_node_execution_info = NodeExecutionInfo(
|
||||
workflow_node_execution_id=workflow_node_execution.id,
|
||||
node_type=NodeType.ITERATION,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
|
||||
self._task_state.latest_node_execution_info = latest_node_execution_info
|
||||
|
||||
db.session.close()
|
||||
|
||||
def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]:
|
||||
"""
|
||||
Handle iteration exception
|
||||
"""
|
||||
if not self._iteration_state or not self._iteration_state.current_iterations:
|
||||
return
|
||||
|
||||
for node_id, current_iteration in self._iteration_state.current_iterations.items():
|
||||
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == current_iteration.node_execution_id
|
||||
).first()
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
yield IterationNodeCompletedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=self._task_state.workflow_run_id,
|
||||
data=IterationNodeCompletedStreamResponse.Data(
|
||||
id=node_id,
|
||||
node_id=node_id,
|
||||
node_type=NodeType.ITERATION.value,
|
||||
title=current_iteration.node_data.title,
|
||||
outputs={},
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=current_iteration.inputs,
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error,
|
||||
elapsed_time=time.perf_counter() - current_iteration.started_at,
|
||||
total_tokens=current_iteration.total_tokens,
|
||||
execution_metadata={
|
||||
'total_tokens': current_iteration.total_tokens,
|
||||
},
|
||||
finished_at=int(time.time()),
|
||||
steps=current_iteration.current_index
|
||||
)
|
||||
)
|
|
@ -65,7 +65,7 @@ class Extensible:
|
|||
if os.path.exists(builtin_file_path):
|
||||
with open(builtin_file_path, encoding='utf-8') as f:
|
||||
position = int(f.read().strip())
|
||||
position_map[extension_name] = position
|
||||
position_map[extension_name] = position
|
||||
|
||||
if (extension_name + '.py') not in file_names:
|
||||
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
|
||||
|
|
|
@ -15,12 +15,6 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Code Executor
|
||||
CODE_EXECUTION_ENDPOINT = dify_config.CODE_EXECUTION_ENDPOINT
|
||||
CODE_EXECUTION_API_KEY = dify_config.CODE_EXECUTION_API_KEY
|
||||
|
||||
CODE_EXECUTION_TIMEOUT = Timeout(connect=10, write=10, read=60, pool=None)
|
||||
|
||||
class CodeExecutionException(Exception):
|
||||
pass
|
||||
|
||||
|
@ -71,10 +65,10 @@ class CodeExecutor:
|
|||
:param code: code
|
||||
:return:
|
||||
"""
|
||||
url = URL(CODE_EXECUTION_ENDPOINT) / 'v1' / 'sandbox' / 'run'
|
||||
url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / 'v1' / 'sandbox' / 'run'
|
||||
|
||||
headers = {
|
||||
'X-Api-Key': CODE_EXECUTION_API_KEY
|
||||
'X-Api-Key': dify_config.CODE_EXECUTION_API_KEY
|
||||
}
|
||||
|
||||
data = {
|
||||
|
@ -85,7 +79,12 @@ class CodeExecutor:
|
|||
}
|
||||
|
||||
try:
|
||||
response = post(str(url), json=data, headers=headers, timeout=CODE_EXECUTION_TIMEOUT)
|
||||
response = post(str(url), json=data, headers=headers,
|
||||
timeout=Timeout(
|
||||
connect=dify_config.CODE_EXECUTION_CONNECT_TIMEOUT,
|
||||
read=dify_config.CODE_EXECUTION_READ_TIMEOUT,
|
||||
write=dify_config.CODE_EXECUTION_WRITE_TIMEOUT,
|
||||
pool=None))
|
||||
if response.status_code == 503:
|
||||
raise CodeExecutionException('Code execution service is unavailable')
|
||||
elif response.status_code != 200:
|
||||
|
@ -96,7 +95,7 @@ class CodeExecutor:
|
|||
raise CodeExecutionException('Failed to execute code, which is likely a network issue,'
|
||||
' please check if the sandbox service is running.'
|
||||
f' ( Error: {str(e)} )')
|
||||
|
||||
|
||||
try:
|
||||
response = response.json()
|
||||
except:
|
||||
|
@ -104,12 +103,12 @@ class CodeExecutor:
|
|||
|
||||
if (code := response.get('code')) != 0:
|
||||
raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}")
|
||||
|
||||
|
||||
response = CodeExecutionResponse(**response)
|
||||
|
||||
|
||||
if response.data.error:
|
||||
raise CodeExecutionException(response.data.error)
|
||||
|
||||
|
||||
return response.data.stdout or ''
|
||||
|
||||
@classmethod
|
||||
|
@ -133,4 +132,3 @@ class CodeExecutor:
|
|||
raise e
|
||||
|
||||
return template_transformer.transform_response(response)
|
||||
|
|
@ -79,7 +79,7 @@ def is_filtered(
|
|||
name_func: Callable[[Any], str],
|
||||
) -> bool:
|
||||
"""
|
||||
Chcek if the object should be filtered out.
|
||||
Check if the object should be filtered out.
|
||||
Overall logic: exclude > include > pin
|
||||
:param include_set: the set of names to be included
|
||||
:param exclude_set: the set of names to be excluded
|
||||
|
|
|
@ -16,9 +16,7 @@ from configs import dify_config
|
|||
from core.errors.error import ProviderTokenNotInitError
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType, PriceType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
|
@ -255,11 +253,8 @@ class IndexingRunner:
|
|||
tenant_id=tenant_id,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
)
|
||||
tokens = 0
|
||||
preview_texts = []
|
||||
total_segments = 0
|
||||
total_price = 0
|
||||
currency = 'USD'
|
||||
index_type = doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
all_text_docs = []
|
||||
|
@ -286,54 +281,22 @@ class IndexingRunner:
|
|||
for document in documents:
|
||||
if len(preview_texts) < 5:
|
||||
preview_texts.append(document.page_content)
|
||||
if indexing_technique == 'high_quality' or embedding_model_instance:
|
||||
tokens += embedding_model_instance.get_text_embedding_num_tokens(
|
||||
texts=[self.filter_string(document.page_content)]
|
||||
)
|
||||
|
||||
if doc_form and doc_form == 'qa_model':
|
||||
model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
model_type_instance = model_instance.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
if len(preview_texts) > 0:
|
||||
# qa model document
|
||||
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
|
||||
doc_language)
|
||||
document_qa_list = self.format_split_text(response)
|
||||
price_info = model_type_instance.get_price(
|
||||
model=model_instance.model,
|
||||
credentials=model_instance.credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=total_segments * 2000,
|
||||
)
|
||||
|
||||
return {
|
||||
"total_segments": total_segments * 20,
|
||||
"tokens": total_segments * 2000,
|
||||
"total_price": '{:f}'.format(price_info.total_amount),
|
||||
"currency": price_info.currency,
|
||||
"qa_preview": document_qa_list,
|
||||
"preview": preview_texts
|
||||
}
|
||||
if embedding_model_instance:
|
||||
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
|
||||
embedding_price_info = embedding_model_type_instance.get_price(
|
||||
model=embedding_model_instance.model,
|
||||
credentials=embedding_model_instance.credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
)
|
||||
total_price = '{:f}'.format(embedding_price_info.total_amount)
|
||||
currency = embedding_price_info.currency
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
"total_price": total_price,
|
||||
"currency": currency,
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
|
@ -531,7 +494,7 @@ class IndexingRunner:
|
|||
hash = helper.generate_text_hash(document_node.page_content)
|
||||
document_node.metadata['doc_id'] = doc_id
|
||||
document_node.metadata['doc_hash'] = hash
|
||||
# delete Spliter character
|
||||
# delete Splitter character
|
||||
page_content = document_node.page_content
|
||||
if page_content.startswith(".") or page_content.startswith("。"):
|
||||
page_content = page_content[1:]
|
||||
|
|
|
@ -87,7 +87,7 @@ Here is a task description for which I would like you to create a high-quality p
|
|||
{{TASK_DESCRIPTION}}
|
||||
</task_description>
|
||||
Based on task description, please create a well-structured prompt template that another AI could use to consistently complete the task. The prompt template should include:
|
||||
- Do not inlcude <input> or <output> section and variables in the prompt, assume user will add them at their own will.
|
||||
- Do not include <input> or <output> section and variables in the prompt, assume user will add them at their own will.
|
||||
- Clear instructions for the AI that will be using this prompt, demarcated with <instructions> tags. The instructions should provide step-by-step directions on how to complete the task using the input variables. Also Specifies in the instructions that the output should not contain any xml tag.
|
||||
- Relevant examples if needed to clarify the task further, demarcated with <example> tags. Do not include variables in the prompt. Give three pairs of input and output examples.
|
||||
- Include other relevant sections demarcated with appropriate XML tags like <examples>, <instructions>.
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import logging
|
||||
import os
|
||||
from collections.abc import Callable, Generator
|
||||
from collections.abc import Callable, Generator, Sequence
|
||||
from typing import IO, Optional, Union, cast
|
||||
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
|
@ -41,7 +41,7 @@ class ModelInstance:
|
|||
configuration=provider_model_bundle.configuration,
|
||||
model_type=provider_model_bundle.model_type_instance.model_type,
|
||||
model=model,
|
||||
credentials=self.credentials
|
||||
credentials=self.credentials,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -54,10 +54,7 @@ class ModelInstance:
|
|||
"""
|
||||
configuration = provider_model_bundle.configuration
|
||||
model_type = provider_model_bundle.model_type_instance.model_type
|
||||
credentials = configuration.get_current_credentials(
|
||||
model_type=model_type,
|
||||
model=model
|
||||
)
|
||||
credentials = configuration.get_current_credentials(model_type=model_type, model=model)
|
||||
|
||||
if credentials is None:
|
||||
raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")
|
||||
|
@ -65,10 +62,9 @@ class ModelInstance:
|
|||
return credentials
|
||||
|
||||
@staticmethod
|
||||
def _get_load_balancing_manager(configuration: ProviderConfiguration,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict) -> Optional["LBModelManager"]:
|
||||
def _get_load_balancing_manager(
|
||||
configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict
|
||||
) -> Optional["LBModelManager"]:
|
||||
"""
|
||||
Get load balancing model credentials
|
||||
:param configuration: provider configuration
|
||||
|
@ -81,8 +77,7 @@ class ModelInstance:
|
|||
current_model_setting = None
|
||||
# check if model is disabled by admin
|
||||
for model_setting in configuration.model_settings:
|
||||
if (model_setting.model_type == model_type
|
||||
and model_setting.model == model):
|
||||
if model_setting.model_type == model_type and model_setting.model == model:
|
||||
current_model_setting = model_setting
|
||||
break
|
||||
|
||||
|
@ -95,17 +90,23 @@ class ModelInstance:
|
|||
model_type=model_type,
|
||||
model=model,
|
||||
load_balancing_configs=current_model_setting.load_balancing_configs,
|
||||
managed_credentials=credentials if configuration.custom_configuration.provider else None
|
||||
managed_credentials=credentials if configuration.custom_configuration.provider else None,
|
||||
)
|
||||
|
||||
return lb_model_manager
|
||||
|
||||
return None
|
||||
|
||||
def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None, callbacks: Optional[list[Callback]] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
|
@ -132,11 +133,12 @@ class ModelInstance:
|
|||
stop=stop,
|
||||
stream=stream,
|
||||
user=user,
|
||||
callbacks=callbacks
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
def get_llm_num_tokens(self, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
def get_llm_num_tokens(
|
||||
self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for llm
|
||||
|
||||
|
@ -153,11 +155,10 @@ class ModelInstance:
|
|||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
|
@ -174,7 +175,7 @@ class ModelInstance:
|
|||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
user=user
|
||||
user=user,
|
||||
)
|
||||
|
||||
def get_text_embedding_num_tokens(self, texts: list[str]) -> int:
|
||||
|
@ -192,13 +193,17 @@ class ModelInstance:
|
|||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
texts=texts
|
||||
texts=texts,
|
||||
)
|
||||
|
||||
def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None) \
|
||||
-> RerankResult:
|
||||
def invoke_rerank(
|
||||
self,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
|
@ -221,11 +226,10 @@ class ModelInstance:
|
|||
docs=docs,
|
||||
score_threshold=score_threshold,
|
||||
top_n=top_n,
|
||||
user=user
|
||||
user=user,
|
||||
)
|
||||
|
||||
def invoke_moderation(self, text: str, user: Optional[str] = None) \
|
||||
-> bool:
|
||||
def invoke_moderation(self, text: str, user: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Invoke moderation model
|
||||
|
||||
|
@ -242,11 +246,10 @@ class ModelInstance:
|
|||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
user=user
|
||||
user=user,
|
||||
)
|
||||
|
||||
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
|
||||
-> str:
|
||||
def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) -> str:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
|
@ -263,11 +266,10 @@ class ModelInstance:
|
|||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
user=user
|
||||
user=user,
|
||||
)
|
||||
|
||||
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) \
|
||||
-> str:
|
||||
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str:
|
||||
"""
|
||||
Invoke large language tts model
|
||||
|
||||
|
@ -288,7 +290,7 @@ class ModelInstance:
|
|||
content_text=content_text,
|
||||
user=user,
|
||||
tenant_id=tenant_id,
|
||||
voice=voice
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
def _round_robin_invoke(self, function: Callable, *args, **kwargs):
|
||||
|
@ -312,8 +314,8 @@ class ModelInstance:
|
|||
raise last_exception
|
||||
|
||||
try:
|
||||
if 'credentials' in kwargs:
|
||||
del kwargs['credentials']
|
||||
if "credentials" in kwargs:
|
||||
del kwargs["credentials"]
|
||||
return function(*args, **kwargs, credentials=lb_config.credentials)
|
||||
except InvokeRateLimitError as e:
|
||||
# expire in 60 seconds
|
||||
|
@ -340,9 +342,7 @@ class ModelInstance:
|
|||
|
||||
self.model_type_instance = cast(TTSModel, self.model_type_instance)
|
||||
return self.model_type_instance.get_tts_model_voices(
|
||||
model=self.model,
|
||||
credentials=self.credentials,
|
||||
language=language
|
||||
model=self.model, credentials=self.credentials, language=language
|
||||
)
|
||||
|
||||
|
||||
|
@ -363,9 +363,7 @@ class ModelManager:
|
|||
return self.get_default_model_instance(tenant_id, model_type)
|
||||
|
||||
provider_model_bundle = self._provider_manager.get_provider_model_bundle(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model_type=model_type
|
||||
tenant_id=tenant_id, provider=provider, model_type=model_type
|
||||
)
|
||||
|
||||
return ModelInstance(provider_model_bundle, model)
|
||||
|
@ -386,10 +384,7 @@ class ModelManager:
|
|||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
default_model_entity = self._provider_manager.get_default_model(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type
|
||||
)
|
||||
default_model_entity = self._provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type)
|
||||
|
||||
if not default_model_entity:
|
||||
raise ProviderTokenNotInitError(f"Default model not found for {model_type}")
|
||||
|
@ -398,17 +393,20 @@ class ModelManager:
|
|||
tenant_id=tenant_id,
|
||||
provider=default_model_entity.provider.provider,
|
||||
model_type=model_type,
|
||||
model=default_model_entity.model
|
||||
model=default_model_entity.model,
|
||||
)
|
||||
|
||||
|
||||
class LBModelManager:
|
||||
def __init__(self, tenant_id: str,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
load_balancing_configs: list[ModelLoadBalancingConfiguration],
|
||||
managed_credentials: Optional[dict] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
load_balancing_configs: list[ModelLoadBalancingConfiguration],
|
||||
managed_credentials: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Load balancing model manager
|
||||
:param tenant_id: tenant_id
|
||||
|
@ -439,10 +437,7 @@ class LBModelManager:
|
|||
:return:
|
||||
"""
|
||||
cache_key = "model_lb_index:{}:{}:{}:{}".format(
|
||||
self._tenant_id,
|
||||
self._provider,
|
||||
self._model_type.value,
|
||||
self._model
|
||||
self._tenant_id, self._provider, self._model_type.value, self._model
|
||||
)
|
||||
|
||||
cooldown_load_balancing_configs = []
|
||||
|
@ -473,10 +468,12 @@ class LBModelManager:
|
|||
|
||||
continue
|
||||
|
||||
if bool(os.environ.get("DEBUG", 'False').lower() == 'true'):
|
||||
logger.info(f"Model LB\nid: {config.id}\nname:{config.name}\n"
|
||||
f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
|
||||
f"model_type: {self._model_type.value}\nmodel: {self._model}")
|
||||
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
|
||||
logger.info(
|
||||
f"Model LB\nid: {config.id}\nname:{config.name}\n"
|
||||
f"tenant_id: {self._tenant_id}\nprovider: {self._provider}\n"
|
||||
f"model_type: {self._model_type.value}\nmodel: {self._model}"
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
@ -490,14 +487,10 @@ class LBModelManager:
|
|||
:return:
|
||||
"""
|
||||
cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
|
||||
self._tenant_id,
|
||||
self._provider,
|
||||
self._model_type.value,
|
||||
self._model,
|
||||
config.id
|
||||
self._tenant_id, self._provider, self._model_type.value, self._model, config.id
|
||||
)
|
||||
|
||||
redis_client.setex(cooldown_cache_key, expire, 'true')
|
||||
redis_client.setex(cooldown_cache_key, expire, "true")
|
||||
|
||||
def in_cooldown(self, config: ModelLoadBalancingConfiguration) -> bool:
|
||||
"""
|
||||
|
@ -506,11 +499,7 @@ class LBModelManager:
|
|||
:return:
|
||||
"""
|
||||
cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
|
||||
self._tenant_id,
|
||||
self._provider,
|
||||
self._model_type.value,
|
||||
self._model,
|
||||
config.id
|
||||
self._tenant_id, self._provider, self._model_type.value, self._model, config.id
|
||||
)
|
||||
|
||||
res = redis_client.exists(cooldown_cache_key)
|
||||
|
@ -518,11 +507,9 @@ class LBModelManager:
|
|||
return res
|
||||
|
||||
@staticmethod
|
||||
def get_config_in_cooldown_and_ttl(tenant_id: str,
|
||||
provider: str,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
config_id: str) -> tuple[bool, int]:
|
||||
def get_config_in_cooldown_and_ttl(
|
||||
tenant_id: str, provider: str, model_type: ModelType, model: str, config_id: str
|
||||
) -> tuple[bool, int]:
|
||||
"""
|
||||
Get model load balancing config is in cooldown and ttl
|
||||
:param tenant_id: workspace id
|
||||
|
@ -533,11 +520,7 @@ class LBModelManager:
|
|||
:return:
|
||||
"""
|
||||
cooldown_cache_key = "model_lb_index:cooldown:{}:{}:{}:{}:{}".format(
|
||||
tenant_id,
|
||||
provider,
|
||||
model_type.value,
|
||||
model,
|
||||
config_id
|
||||
tenant_id, provider, model_type.value, model, config_id
|
||||
)
|
||||
|
||||
ttl = redis_client.ttl(cooldown_cache_key)
|
||||
|
|
|
@ -52,7 +52,7 @@
|
|||
- `mode` (string) voice model.(available for model type `tts`)
|
||||
- `name` (string) voice model display name.(available for model type `tts`)
|
||||
- `language` (string) the voice model supports languages.(available for model type `tts`)
|
||||
- `word_limit` (int) Single conversion word limit, paragraphwise by default(available for model type `tts`)
|
||||
- `word_limit` (int) Single conversion word limit, paragraph-wise by default(available for model type `tts`)
|
||||
- `audio_type` (string) Support audio file extension format, e.g.:mp3,wav(available for model type `tts`)
|
||||
- `max_workers` (int) Number of concurrent workers supporting text and audio conversion(available for model type`tts`)
|
||||
- `max_characters_per_chunk` (int) Maximum characters per chunk (available for model type `moderation`)
|
||||
|
@ -150,7 +150,7 @@
|
|||
|
||||
- `input` (float) Input price, i.e., Prompt price
|
||||
- `output` (float) Output price, i.e., returned content price
|
||||
- `unit` (float) Pricing unit, e.g., if the price is meausred in 1M tokens, the corresponding token amount for the unit price is `0.000001`.
|
||||
- `unit` (float) Pricing unit, e.g., if the price is measured in 1M tokens, the corresponding token amount for the unit price is `0.000001`.
|
||||
- `currency` (string) Currency unit
|
||||
|
||||
### ProviderCredentialSchema
|
||||
|
|
|
@ -33,6 +33,22 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
|
|||
'max': 1.0,
|
||||
'precision': 2,
|
||||
},
|
||||
DefaultParameterName.TOP_K: {
|
||||
'label': {
|
||||
'en_US': 'Top K',
|
||||
'zh_Hans': 'Top K',
|
||||
},
|
||||
'type': 'int',
|
||||
'help': {
|
||||
'en_US': 'Limits the number of tokens to consider for each step by keeping only the k most likely tokens.',
|
||||
'zh_Hans': '通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。',
|
||||
},
|
||||
'required': False,
|
||||
'default': 50,
|
||||
'min': 1,
|
||||
'max': 100,
|
||||
'precision': 0,
|
||||
},
|
||||
DefaultParameterName.PRESENCE_PENALTY: {
|
||||
'label': {
|
||||
'en_US': 'Presence Penalty',
|
||||
|
|
|
@ -63,6 +63,39 @@ class LLMUsage(ModelUsage):
|
|||
latency=0.0
|
||||
)
|
||||
|
||||
def plus(self, other: 'LLMUsage') -> 'LLMUsage':
|
||||
"""
|
||||
Add two LLMUsage instances together.
|
||||
|
||||
:param other: Another LLMUsage instance to add
|
||||
:return: A new LLMUsage instance with summed values
|
||||
"""
|
||||
if self.total_tokens == 0:
|
||||
return other
|
||||
else:
|
||||
return LLMUsage(
|
||||
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
|
||||
prompt_unit_price=other.prompt_unit_price,
|
||||
prompt_price_unit=other.prompt_price_unit,
|
||||
prompt_price=self.prompt_price + other.prompt_price,
|
||||
completion_tokens=self.completion_tokens + other.completion_tokens,
|
||||
completion_unit_price=other.completion_unit_price,
|
||||
completion_price_unit=other.completion_price_unit,
|
||||
completion_price=self.completion_price + other.completion_price,
|
||||
total_tokens=self.total_tokens + other.total_tokens,
|
||||
total_price=self.total_price + other.total_price,
|
||||
currency=other.currency,
|
||||
latency=self.latency + other.latency
|
||||
)
|
||||
|
||||
def __add__(self, other: 'LLMUsage') -> 'LLMUsage':
|
||||
"""
|
||||
Overload the + operator to add two LLMUsage instances.
|
||||
|
||||
:param other: Another LLMUsage instance to add
|
||||
:return: A new LLMUsage instance with summed values
|
||||
"""
|
||||
return self.plus(other)
|
||||
|
||||
class LLMResult(BaseModel):
|
||||
"""
|
||||
|
|
|
@ -85,12 +85,13 @@ class ModelFeature(Enum):
|
|||
STREAM_TOOL_CALL = "stream-tool-call"
|
||||
|
||||
|
||||
class DefaultParameterName(Enum):
|
||||
class DefaultParameterName(str, Enum):
|
||||
"""
|
||||
Enum class for parameter template variable.
|
||||
"""
|
||||
TEMPERATURE = "temperature"
|
||||
TOP_P = "top_p"
|
||||
TOP_K = "top_k"
|
||||
PRESENCE_PENALTY = "presence_penalty"
|
||||
FREQUENCY_PENALTY = "frequency_penalty"
|
||||
MAX_TOKENS = "max_tokens"
|
||||
|
|
|
@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class TTSModel(AIModel):
|
||||
"""
|
||||
Model class for ttstext model.
|
||||
Model class for TTS model.
|
||||
"""
|
||||
model_type: ModelType = ModelType.TTS
|
||||
|
||||
|
|
|
@ -19,9 +19,9 @@ class AnthropicProvider(ModelProvider):
|
|||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
|
||||
# Use `claude-instant-1` model for validate,
|
||||
# Use `claude-3-opus-20240229` model for validate,
|
||||
model_instance.validate_credentials(
|
||||
model='claude-instant-1.2',
|
||||
model='claude-3-opus-20240229',
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
|
|
|
@ -33,3 +33,4 @@ pricing:
|
|||
output: '5.51'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
deprecated: true
|
||||
|
|
|
@ -637,7 +637,19 @@ LLM_BASE_MODELS = [
|
|||
en_US='specifying the format that the model must output'
|
||||
),
|
||||
required=False,
|
||||
options=['text', 'json_object']
|
||||
options=['text', 'json_object', 'json_schema']
|
||||
),
|
||||
ParameterRule(
|
||||
name='json_schema',
|
||||
label=I18nObject(
|
||||
en_US='JSON Schema'
|
||||
),
|
||||
type='text',
|
||||
help=I18nObject(
|
||||
zh_Hans='设置返回的json schema,llm将按照它返回',
|
||||
en_US='Set a response json schema will ensure LLM to adhere it.'
|
||||
),
|
||||
required=False
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
|
@ -800,6 +812,94 @@ LLM_BASE_MODELS = [
|
|||
)
|
||||
)
|
||||
),
|
||||
AzureBaseModel(
|
||||
base_model_name='gpt-4o-2024-08-06',
|
||||
entity=AIModelEntity(
|
||||
model='fake-deployment-name',
|
||||
label=I18nObject(
|
||||
en_US='fake-deployment-name-label',
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
features=[
|
||||
ModelFeature.AGENT_THOUGHT,
|
||||
ModelFeature.VISION,
|
||||
ModelFeature.MULTI_TOOL_CALL,
|
||||
ModelFeature.STREAM_TOOL_CALL,
|
||||
],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
||||
ModelPropertyKey.CONTEXT_SIZE: 128000,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name='temperature',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
|
||||
),
|
||||
ParameterRule(
|
||||
name='top_p',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
|
||||
),
|
||||
ParameterRule(
|
||||
name='presence_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
|
||||
),
|
||||
ParameterRule(
|
||||
name='frequency_penalty',
|
||||
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
|
||||
),
|
||||
_get_max_tokens(default=512, min_val=1, max_val=4096),
|
||||
ParameterRule(
|
||||
name='seed',
|
||||
label=I18nObject(
|
||||
zh_Hans='种子',
|
||||
en_US='Seed'
|
||||
),
|
||||
type='int',
|
||||
help=I18nObject(
|
||||
zh_Hans='如果指定,模型将尽最大努力进行确定性采样,使得重复的具有相同种子和参数的请求应该返回相同的结果。不能保证确定性,您应该参考 system_fingerprint 响应参数来监视变化。',
|
||||
en_US='If specified, model will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor changes in the backend.'
|
||||
),
|
||||
required=False,
|
||||
precision=2,
|
||||
min=0,
|
||||
max=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name='response_format',
|
||||
label=I18nObject(
|
||||
zh_Hans='回复格式',
|
||||
en_US='response_format'
|
||||
),
|
||||
type='string',
|
||||
help=I18nObject(
|
||||
zh_Hans='指定模型必须输出的格式',
|
||||
en_US='specifying the format that the model must output'
|
||||
),
|
||||
required=False,
|
||||
options=['text', 'json_object', 'json_schema']
|
||||
),
|
||||
ParameterRule(
|
||||
name='json_schema',
|
||||
label=I18nObject(
|
||||
en_US='JSON Schema'
|
||||
),
|
||||
type='text',
|
||||
help=I18nObject(
|
||||
zh_Hans='设置返回的json schema,llm将按照它返回',
|
||||
en_US='Set a response json schema will ensure LLM to adhere it.'
|
||||
),
|
||||
required=False
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=5.00,
|
||||
output=15.00,
|
||||
unit=0.000001,
|
||||
currency='USD',
|
||||
)
|
||||
)
|
||||
),
|
||||
AzureBaseModel(
|
||||
base_model_name='gpt-4-turbo',
|
||||
entity=AIModelEntity(
|
||||
|
|
|
@ -138,6 +138,12 @@ model_credential_schema:
|
|||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4o-2024-08-06
|
||||
value: gpt-4o-2024-08-06
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- label:
|
||||
en_US: gpt-4-turbo
|
||||
value: gpt-4-turbo
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import copy
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Optional, Union, cast
|
||||
|
@ -276,12 +277,18 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
|
|||
|
||||
response_format = model_parameters.get("response_format")
|
||||
if response_format:
|
||||
if response_format == "json_object":
|
||||
response_format = {"type": "json_object"}
|
||||
if response_format == "json_schema":
|
||||
json_schema = model_parameters.get("json_schema")
|
||||
if not json_schema:
|
||||
raise ValueError("Must define JSON Schema when the response format is json_schema")
|
||||
try:
|
||||
schema = json.loads(json_schema)
|
||||
except:
|
||||
raise ValueError(f"not correct json_schema format: {json_schema}")
|
||||
model_parameters.pop("json_schema")
|
||||
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
|
||||
else:
|
||||
response_format = {"type": "text"}
|
||||
|
||||
model_parameters["response_format"] = response_format
|
||||
model_parameters["response_format"] = {"type": response_format}
|
||||
|
||||
extra_model_kwargs = {}
|
||||
|
||||
|
|
|
@ -27,11 +27,3 @@ provider_credential_schema:
|
|||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: secret_key
|
||||
label:
|
||||
en_US: Secret Key
|
||||
type: secret-input
|
||||
required: false
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 Secret Key
|
||||
en_US: Enter your Secret Key
|
||||
|
|
|
@ -43,3 +43,4 @@ parameter_rules:
|
|||
zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
|
||||
en_US: Allow the model to perform external search to enhance the generation results.
|
||||
required: false
|
||||
deprecated: true
|
||||
|
|
|
@ -43,3 +43,4 @@ parameter_rules:
|
|||
zh_Hans: 允许模型自行进行外部搜索,以增强生成结果。
|
||||
en_US: Allow the model to perform external search to enhance the generation results.
|
||||
required: false
|
||||
deprecated: true
|
||||
|
|
|
@ -4,36 +4,32 @@ label:
|
|||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- multi-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.3
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 0.85
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
min: 0
|
||||
max: 20
|
||||
default: 5
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8000
|
||||
min: 1
|
||||
max: 192000
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
default: 1
|
||||
min: 1
|
||||
max: 2
|
||||
default: 2048
|
||||
- name: with_search_enhance
|
||||
label:
|
||||
zh_Hans: 搜索增强
|
||||
|
|
|
@ -4,36 +4,44 @@ label:
|
|||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- multi-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.3
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 0.85
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
min: 0
|
||||
max: 20
|
||||
default: 5
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8000
|
||||
min: 1
|
||||
max: 128000
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
default: 1
|
||||
min: 1
|
||||
max: 2
|
||||
default: 2048
|
||||
- name: res_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
- name: with_search_enhance
|
||||
label:
|
||||
zh_Hans: 搜索增强
|
||||
|
|
|
@ -4,36 +4,44 @@ label:
|
|||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- multi-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.3
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 0.85
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
min: 0
|
||||
max: 20
|
||||
default: 5
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8000
|
||||
min: 1
|
||||
max: 32000
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
default: 1
|
||||
min: 1
|
||||
max: 2
|
||||
default: 2048
|
||||
- name: res_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
- name: with_search_enhance
|
||||
label:
|
||||
zh_Hans: 搜索增强
|
||||
|
|
|
@ -4,36 +4,44 @@ label:
|
|||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- multi-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.3
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 0.85
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
min: 0
|
||||
max: 20
|
||||
default: 5
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8000
|
||||
min: 1
|
||||
max: 32000
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
default: 1
|
||||
min: 1
|
||||
max: 2
|
||||
default: 2048
|
||||
- name: res_format
|
||||
label:
|
||||
zh_Hans: 回复格式
|
||||
en_US: response format
|
||||
type: string
|
||||
help:
|
||||
zh_Hans: 指定模型必须输出的格式
|
||||
en_US: specifying the format that the model must output
|
||||
required: false
|
||||
options:
|
||||
- text
|
||||
- json_object
|
||||
- name: with_search_enhance
|
||||
label:
|
||||
zh_Hans: 搜索增强
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
from collections.abc import Generator
|
||||
from enum import Enum
|
||||
from hashlib import md5
|
||||
from json import dumps, loads
|
||||
from typing import Any, Union
|
||||
import json
|
||||
from collections.abc import Iterator
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from requests import post
|
||||
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
||||
BadRequestError,
|
||||
InsufficientAccountBalance,
|
||||
|
@ -16,203 +15,133 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor
|
|||
)
|
||||
|
||||
|
||||
class BaichuanMessage:
|
||||
class Role(Enum):
|
||||
USER = 'user'
|
||||
ASSISTANT = 'assistant'
|
||||
# Baichuan does not have system message
|
||||
_SYSTEM = 'system'
|
||||
|
||||
role: str = Role.USER.value
|
||||
content: str
|
||||
usage: dict[str, int] = None
|
||||
stop_reason: str = ''
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
'role': self.role,
|
||||
'content': self.content,
|
||||
}
|
||||
|
||||
def __init__(self, content: str, role: str = 'user') -> None:
|
||||
self.content = content
|
||||
self.role = role
|
||||
|
||||
class BaichuanModel:
|
||||
api_key: str
|
||||
secret_key: str
|
||||
|
||||
def __init__(self, api_key: str, secret_key: str = '') -> None:
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
self.secret_key = secret_key
|
||||
|
||||
def _model_mapping(self, model: str) -> str:
|
||||
@property
|
||||
def _model_mapping(self) -> dict:
|
||||
return {
|
||||
'baichuan2-turbo': 'Baichuan2-Turbo',
|
||||
'baichuan2-turbo-192k': 'Baichuan2-Turbo-192k',
|
||||
'baichuan2-53b': 'Baichuan2-53B',
|
||||
'baichuan3-turbo': 'Baichuan3-Turbo',
|
||||
'baichuan3-turbo-128k': 'Baichuan3-Turbo-128k',
|
||||
'baichuan4': 'Baichuan4',
|
||||
}[model]
|
||||
"baichuan2-turbo": "Baichuan2-Turbo",
|
||||
"baichuan3-turbo": "Baichuan3-Turbo",
|
||||
"baichuan3-turbo-128k": "Baichuan3-Turbo-128k",
|
||||
"baichuan4": "Baichuan4",
|
||||
}
|
||||
|
||||
def _handle_chat_generate_response(self, response) -> BaichuanMessage:
|
||||
resp = response.json()
|
||||
choices = resp.get('choices', [])
|
||||
message = BaichuanMessage(content='', role='assistant')
|
||||
for choice in choices:
|
||||
message.content += choice['message']['content']
|
||||
message.role = choice['message']['role']
|
||||
if choice['finish_reason']:
|
||||
message.stop_reason = choice['finish_reason']
|
||||
@property
|
||||
def request_headers(self) -> dict[str, Any]:
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + self.api_key,
|
||||
}
|
||||
|
||||
if 'usage' in resp:
|
||||
message.usage = {
|
||||
'prompt_tokens': resp['usage']['prompt_tokens'],
|
||||
'completion_tokens': resp['usage']['completion_tokens'],
|
||||
'total_tokens': resp['usage']['total_tokens'],
|
||||
}
|
||||
def _build_parameters(
|
||||
self,
|
||||
model: str,
|
||||
stream: bool,
|
||||
messages: list[dict],
|
||||
parameters: dict[str, Any],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> dict[str, Any]:
|
||||
if model in self._model_mapping.keys():
|
||||
# the LargeLanguageModel._code_block_mode_wrapper() method will remove the response_format of parameters.
|
||||
# we need to rename it to res_format to get its value
|
||||
if parameters.get("res_format") == "json_object":
|
||||
parameters["response_format"] = {"type": "json_object"}
|
||||
|
||||
return message
|
||||
|
||||
def _handle_chat_stream_generate_response(self, response) -> Generator:
|
||||
for line in response.iter_lines():
|
||||
if not line:
|
||||
continue
|
||||
line = line.decode('utf-8')
|
||||
# remove the first `data: ` prefix
|
||||
if line.startswith('data:'):
|
||||
line = line[5:].strip()
|
||||
try:
|
||||
data = loads(line)
|
||||
except Exception as e:
|
||||
if line.strip() == '[DONE]':
|
||||
return
|
||||
choices = data.get('choices', [])
|
||||
# save stop reason temporarily
|
||||
stop_reason = ''
|
||||
for choice in choices:
|
||||
if choice.get('finish_reason'):
|
||||
stop_reason = choice['finish_reason']
|
||||
if tools or parameters.get("with_search_enhance") is True:
|
||||
parameters["tools"] = []
|
||||
|
||||
if len(choice['delta']['content']) == 0:
|
||||
continue
|
||||
yield BaichuanMessage(**choice['delta'])
|
||||
|
||||
# if there is usage, the response is the last one, yield it and return
|
||||
if 'usage' in data:
|
||||
message = BaichuanMessage(content='', role='assistant')
|
||||
message.usage = {
|
||||
'prompt_tokens': data['usage']['prompt_tokens'],
|
||||
'completion_tokens': data['usage']['completion_tokens'],
|
||||
'total_tokens': data['usage']['total_tokens'],
|
||||
}
|
||||
message.stop_reason = stop_reason
|
||||
yield message
|
||||
|
||||
def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage],
|
||||
parameters: dict[str, Any]) \
|
||||
-> dict[str, Any]:
|
||||
if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
|
||||
or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
|
||||
prompt_messages = []
|
||||
for message in messages:
|
||||
if message.role == BaichuanMessage.Role.USER.value or message.role == BaichuanMessage.Role._SYSTEM.value:
|
||||
# check if the latest message is a user message
|
||||
if len(prompt_messages) > 0 and prompt_messages[-1]['role'] == BaichuanMessage.Role.USER.value:
|
||||
prompt_messages[-1]['content'] += message.content
|
||||
else:
|
||||
prompt_messages.append({
|
||||
'content': message.content,
|
||||
'role': BaichuanMessage.Role.USER.value,
|
||||
})
|
||||
elif message.role == BaichuanMessage.Role.ASSISTANT.value:
|
||||
prompt_messages.append({
|
||||
'content': message.content,
|
||||
'role': message.role,
|
||||
})
|
||||
# [baichuan] frequency_penalty must be between 1 and 2
|
||||
if 'frequency_penalty' in parameters:
|
||||
if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2:
|
||||
parameters['frequency_penalty'] = 1
|
||||
# with_search_enhance is deprecated, use web_search instead
|
||||
if parameters.get("with_search_enhance") is True:
|
||||
parameters["tools"].append(
|
||||
{
|
||||
"type": "web_search",
|
||||
"web_search": {"enable": True},
|
||||
}
|
||||
)
|
||||
if tools:
|
||||
for tool in tools:
|
||||
parameters["tools"].append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# turbo api accepts flat parameters
|
||||
return {
|
||||
'model': self._model_mapping(model),
|
||||
'stream': stream,
|
||||
'messages': prompt_messages,
|
||||
"model": self._model_mapping.get(model),
|
||||
"stream": stream,
|
||||
"messages": messages,
|
||||
**parameters,
|
||||
}
|
||||
else:
|
||||
raise BadRequestError(f"Unknown model: {model}")
|
||||
|
||||
def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
|
||||
if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
|
||||
or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
|
||||
# there is no secret key for turbo api
|
||||
return {
|
||||
'Content-Type': 'application/json',
|
||||
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) ',
|
||||
'Authorization': 'Bearer ' + self.api_key,
|
||||
}
|
||||
else:
|
||||
raise BadRequestError(f"Unknown model: {model}")
|
||||
|
||||
def _calculate_md5(self, input_string):
|
||||
return md5(input_string.encode('utf-8')).hexdigest()
|
||||
|
||||
def generate(self, model: str, stream: bool, messages: list[BaichuanMessage],
|
||||
parameters: dict[str, Any], timeout: int) \
|
||||
-> Union[Generator, BaichuanMessage]:
|
||||
|
||||
if (model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b'
|
||||
or model == 'baichuan3-turbo' or model == 'baichuan3-turbo-128k' or model == 'baichuan4'):
|
||||
api_base = 'https://api.baichuan-ai.com/v1/chat/completions'
|
||||
def generate(
|
||||
self,
|
||||
model: str,
|
||||
stream: bool,
|
||||
messages: list[dict],
|
||||
parameters: dict[str, Any],
|
||||
timeout: int,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> Union[Iterator, dict]:
|
||||
|
||||
if model in self._model_mapping.keys():
|
||||
api_base = "https://api.baichuan-ai.com/v1/chat/completions"
|
||||
else:
|
||||
raise BadRequestError(f"Unknown model: {model}")
|
||||
|
||||
try:
|
||||
data = self._build_parameters(model, stream, messages, parameters)
|
||||
headers = self._build_headers(model, data)
|
||||
except KeyError:
|
||||
raise InternalServerError(f"Failed to build parameters for model: {model}")
|
||||
|
||||
data = self._build_parameters(model, stream, messages, parameters, tools)
|
||||
|
||||
try:
|
||||
response = post(
|
||||
url=api_base,
|
||||
headers=headers,
|
||||
data=dumps(data),
|
||||
headers=self.request_headers,
|
||||
data=json.dumps(data),
|
||||
timeout=timeout,
|
||||
stream=stream
|
||||
stream=stream,
|
||||
)
|
||||
except Exception as e:
|
||||
raise InternalServerError(f"Failed to invoke model: {e}")
|
||||
|
||||
|
||||
if response.status_code != 200:
|
||||
try:
|
||||
resp = response.json()
|
||||
# try to parse error message
|
||||
err = resp['error']['code']
|
||||
msg = resp['error']['message']
|
||||
err = resp["error"]["type"]
|
||||
msg = resp["error"]["message"]
|
||||
except Exception as e:
|
||||
raise InternalServerError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||
raise InternalServerError(
|
||||
f"Failed to convert response to json: {e} with text: {response.text}"
|
||||
)
|
||||
|
||||
if err == 'invalid_api_key':
|
||||
if err == "invalid_api_key":
|
||||
raise InvalidAPIKeyError(msg)
|
||||
elif err == 'insufficient_quota':
|
||||
elif err == "insufficient_quota":
|
||||
raise InsufficientAccountBalance(msg)
|
||||
elif err == 'invalid_authentication':
|
||||
elif err == "invalid_authentication":
|
||||
raise InvalidAuthenticationError(msg)
|
||||
elif 'rate' in err:
|
||||
elif err == "invalid_request_error":
|
||||
raise BadRequestError(msg)
|
||||
elif "rate" in err:
|
||||
raise RateLimitReachedError(msg)
|
||||
elif 'internal' in err:
|
||||
elif "internal" in err:
|
||||
raise InternalServerError(msg)
|
||||
elif err == 'api_key_empty':
|
||||
elif err == "api_key_empty":
|
||||
raise InvalidAPIKeyError(msg)
|
||||
else:
|
||||
raise InternalServerError(f"Unknown error: {err} with message: {msg}")
|
||||
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_stream_generate_response(response)
|
||||
return response.iter_lines()
|
||||
else:
|
||||
return self._handle_chat_generate_response(response)
|
||||
return response.json()
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
from collections.abc import Generator
|
||||
import json
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import cast
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
|
@ -21,7 +26,7 @@ from core.model_runtime.errors.invoke import (
|
|||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanMessage, BaichuanModel
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
||||
BadRequestError,
|
||||
InsufficientAccountBalance,
|
||||
|
@ -32,20 +37,41 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor
|
|||
)
|
||||
|
||||
|
||||
class BaichuanLarguageModel(LargeLanguageModel):
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
|
||||
stream: bool = True, user: str | None = None) \
|
||||
-> LLMResult | Generator:
|
||||
return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
|
||||
class BaichuanLanguageModel(LargeLanguageModel):
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None) -> int:
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: str | None = None,
|
||||
) -> LLMResult | Generator:
|
||||
return self._generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=tools,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
) -> int:
|
||||
return self._num_tokens_from_messages(prompt_messages)
|
||||
|
||||
def _num_tokens_from_messages(self, messages: list[PromptMessage], ) -> int:
|
||||
def _num_tokens_from_messages(
|
||||
self,
|
||||
messages: list[PromptMessage],
|
||||
) -> int:
|
||||
"""Calculate num tokens for baichuan model"""
|
||||
|
||||
def tokens(text: str):
|
||||
|
@ -59,10 +85,10 @@ class BaichuanLarguageModel(LargeLanguageModel):
|
|||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
if isinstance(value, list):
|
||||
text = ''
|
||||
text = ""
|
||||
for item in value:
|
||||
if isinstance(item, dict) and item['type'] == 'text':
|
||||
text += item['text']
|
||||
if isinstance(item, dict) and item["type"] == "text":
|
||||
text += item["text"]
|
||||
|
||||
value = text
|
||||
|
||||
|
@ -84,19 +110,18 @@ class BaichuanLarguageModel(LargeLanguageModel):
|
|||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
message_dict["tool_calls"] = [tool_call.dict() for tool_call in
|
||||
message.tool_calls]
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolPromptMessage):
|
||||
# copy from core/model_runtime/model_providers/anthropic/llm/llm.py
|
||||
message = cast(ToolPromptMessage, message)
|
||||
message_dict = {
|
||||
"role": "user",
|
||||
"content": [{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message.tool_call_id,
|
||||
"content": message.content
|
||||
}]
|
||||
"role": "tool",
|
||||
"content": message.content,
|
||||
"tool_call_id": message.tool_call_id
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown message type {type(message)}")
|
||||
|
@ -105,102 +130,159 @@ class BaichuanLarguageModel(LargeLanguageModel):
|
|||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
# ping
|
||||
instance = BaichuanModel(
|
||||
api_key=credentials['api_key'],
|
||||
secret_key=credentials.get('secret_key', '')
|
||||
)
|
||||
instance = BaichuanModel(api_key=credentials["api_key"])
|
||||
|
||||
try:
|
||||
instance.generate(model=model, stream=False, messages=[
|
||||
BaichuanMessage(content='ping', role='user')
|
||||
], parameters={
|
||||
'max_tokens': 1,
|
||||
}, timeout=60)
|
||||
instance.generate(
|
||||
model=model,
|
||||
stream=False,
|
||||
messages=[{"content": "ping", "role": "user"}],
|
||||
parameters={
|
||||
"max_tokens": 1,
|
||||
},
|
||||
timeout=60,
|
||||
)
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f"Invalid API key: {e}")
|
||||
|
||||
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict, tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
|
||||
-> LLMResult | Generator:
|
||||
if tools is not None and len(tools) > 0:
|
||||
raise InvokeBadRequestError("Baichuan model doesn't support tools")
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stream: bool = True,
|
||||
) -> LLMResult | Generator:
|
||||
|
||||
instance = BaichuanModel(
|
||||
api_key=credentials['api_key'],
|
||||
secret_key=credentials.get('secret_key', '')
|
||||
)
|
||||
|
||||
# convert prompt messages to baichuan messages
|
||||
messages = [
|
||||
BaichuanMessage(
|
||||
content=message.content if isinstance(message.content, str) else ''.join([
|
||||
content.data for content in message.content
|
||||
]),
|
||||
role=message.role.value
|
||||
) for message in prompt_messages
|
||||
]
|
||||
instance = BaichuanModel(api_key=credentials["api_key"])
|
||||
messages = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
|
||||
|
||||
# invoke model
|
||||
response = instance.generate(model=model, stream=stream, messages=messages, parameters=model_parameters,
|
||||
timeout=60)
|
||||
response = instance.generate(
|
||||
model=model,
|
||||
stream=stream,
|
||||
messages=messages,
|
||||
parameters=model_parameters,
|
||||
timeout=60,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_generate_stream_response(model, prompt_messages, credentials, response)
|
||||
return self._handle_chat_generate_stream_response(
|
||||
model, prompt_messages, credentials, response
|
||||
)
|
||||
|
||||
return self._handle_chat_generate_response(model, prompt_messages, credentials, response)
|
||||
return self._handle_chat_generate_response(
|
||||
model, prompt_messages, credentials, response
|
||||
)
|
||||
|
||||
def _handle_chat_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: dict,
|
||||
) -> LLMResult:
|
||||
choices = response.get("choices", [])
|
||||
assistant_message = AssistantPromptMessage(content='', tool_calls=[])
|
||||
if choices and choices[0]["finish_reason"] == "tool_calls":
|
||||
for choice in choices:
|
||||
for tool_call in choice["message"]["tool_calls"]:
|
||||
tool = AssistantPromptMessage.ToolCall(
|
||||
id=tool_call.get("id", ""),
|
||||
type=tool_call.get("type", ""),
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool_call.get("function", {}).get("name", ""),
|
||||
arguments=tool_call.get("function", {}).get("arguments", "")
|
||||
),
|
||||
)
|
||||
assistant_message.tool_calls.append(tool)
|
||||
else:
|
||||
for choice in choices:
|
||||
assistant_message.content += choice["message"]["content"]
|
||||
assistant_message.role = choice["message"]["role"]
|
||||
|
||||
usage = response.get("usage")
|
||||
if usage:
|
||||
# transform usage
|
||||
prompt_tokens = usage["prompt_tokens"]
|
||||
completion_tokens = usage["completion_tokens"]
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_messages(prompt_messages)
|
||||
completion_tokens = self._num_tokens_from_messages([assistant_message])
|
||||
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
def _handle_chat_generate_response(self, model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: BaichuanMessage) -> LLMResult:
|
||||
# convert baichuan message to llm result
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials,
|
||||
prompt_tokens=response.usage['prompt_tokens'],
|
||||
completion_tokens=response.usage['completion_tokens'])
|
||||
return LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(
|
||||
content=response.content,
|
||||
tool_calls=[]
|
||||
),
|
||||
message=assistant_message,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _handle_chat_generate_stream_response(self, model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Generator[BaichuanMessage, None, None]) -> Generator:
|
||||
for message in response:
|
||||
if message.usage:
|
||||
usage = self._calc_response_usage(model=model, credentials=credentials,
|
||||
prompt_tokens=message.usage['prompt_tokens'],
|
||||
completion_tokens=message.usage['completion_tokens'])
|
||||
def _handle_chat_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
prompt_messages: list[PromptMessage],
|
||||
credentials: dict,
|
||||
response: Iterator,
|
||||
) -> Generator:
|
||||
for line in response:
|
||||
if not line:
|
||||
continue
|
||||
line = line.decode("utf-8")
|
||||
# remove the first `data: ` prefix
|
||||
if line.startswith("data:"):
|
||||
line = line[5:].strip()
|
||||
try:
|
||||
data = json.loads(line)
|
||||
except Exception as e:
|
||||
if line.strip() == "[DONE]":
|
||||
return
|
||||
choices = data.get("choices", [])
|
||||
|
||||
stop_reason = ""
|
||||
for choice in choices:
|
||||
if choice.get("finish_reason"):
|
||||
stop_reason = choice["finish_reason"]
|
||||
|
||||
if len(choice["delta"]["content"]) == 0:
|
||||
continue
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=message.content,
|
||||
tool_calls=[]
|
||||
content=choice["delta"]["content"], tool_calls=[]
|
||||
),
|
||||
usage=usage,
|
||||
finish_reason=message.stop_reason if message.stop_reason else None,
|
||||
finish_reason=stop_reason,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
||||
# if there is usage, the response is the last one, yield it and return
|
||||
if "usage" in data:
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_tokens=data["usage"]["prompt_tokens"],
|
||||
completion_tokens=data["usage"]["completion_tokens"],
|
||||
)
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(
|
||||
content=message.content,
|
||||
tool_calls=[]
|
||||
),
|
||||
finish_reason=message.stop_reason if message.stop_reason else None,
|
||||
message=AssistantPromptMessage(content="", tool_calls=[]),
|
||||
usage=usage,
|
||||
finish_reason=stop_reason,
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -215,21 +297,13 @@ class BaichuanLarguageModel(LargeLanguageModel):
|
|||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
RateLimitReachedError
|
||||
],
|
||||
InvokeConnectionError: [],
|
||||
InvokeServerUnavailableError: [InternalServerError],
|
||||
InvokeRateLimitError: [RateLimitReachedError],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalance,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
BadRequestError,
|
||||
KeyError
|
||||
]
|
||||
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||
}
|
||||
|
|
|
@ -60,7 +60,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
|||
token_usage = 0
|
||||
|
||||
for chunk in chunks:
|
||||
# embeding chunk
|
||||
# embedding chunk
|
||||
chunk_embeddings, chunk_usage = self.embedding(
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
|
|
|
@ -793,11 +793,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
|
|||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
|
||||
The value is the md = genai.GenerativeModel(model)error type thrown by the model,
|
||||
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
|
||||
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke emd = genai.GenerativeModel(model)rror mapping
|
||||
:return: Invoke emd = genai.GenerativeModel(model) error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [],
|
||||
|
|
|
@ -130,11 +130,11 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
|||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
|
||||
The value is the md = genai.GenerativeModel(model)error type thrown by the model,
|
||||
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
|
||||
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke emd = genai.GenerativeModel(model)rror mapping
|
||||
:return: Invoke emd = genai.GenerativeModel(model) error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [],
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
|
|
@ -0,0 +1 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="61.1 180.15 377.8 139.718"><path d="M431.911 245.181c3.842 0 6.989 1.952 6.989 4.337v14.776c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-14.776c0-2.385 3.144-4.337 6.99-4.337ZM404.135 250.955c3.846 0 6.989 1.952 6.989 4.337v32.528c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-32.528c0-2.385 3.147-4.337 6.989-4.337ZM376.363 257.688c3.842 0 6.989 1.952 6.989 4.337v36.562c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-36.562c0-2.386 3.147-4.337 6.993-4.337ZM348.587 263.26c3.846 0 6.989 1.952 6.989 4.337v36.159c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-36.159c0-2.385 3.147-4.337 6.989-4.337ZM320.811 268.177c3.846 0 6.989 1.952 6.989 4.337v31.318c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-31.318c0-2.385 3.147-4.337 6.989-4.337ZM293.179 288.148c3.846 0 6.989 1.952 6.989 4.337v9.935c0 2.384-3.147 4.336-6.989 4.336s-6.99-1.951-6.99-4.336v-9.935c0-2.386 3.144-4.337 6.99-4.337Z" style="fill:#b1b3b4;fill-rule:evenodd"></path><path d="M431.911 205.441c3.842 0 6.989 1.952 6.989 4.337v24.459c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-24.459c0-2.385 3.144-4.337 6.99-4.337ZM404.135 189.026c3.846 0 6.989 1.952 6.989 4.337v43.622c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-43.622c0-2.385 3.147-4.337 6.989-4.337ZM376.363 182.848c3.842 0 6.989 1.953 6.989 4.337v56.937c0 2.384-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-56.937c0-2.385 3.147-4.337 6.993-4.337ZM348.587 180.15c3.846 0 6.989 1.952 6.989 4.337v66.619c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-66.619c0-2.385 3.147-4.337 6.989-4.337ZM320.811 181.84c3.846 0 6.989 1.952 6.989 4.337v67.627c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-67.627c0-2.386 3.147-4.337 6.989-4.337ZM293.179 186.076c3.846 0 6.989 1.952 6.989 4.337v84.37c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.951-6.99-4.337v-84.37c0-2.386 3.144-4.337 6.99-4.337ZM264.829 193.262c3.846 0 6.989 1.953 6.989 4.337v95.667c0 2.385-3.143 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-95.667c0-2.385 3.147-4.337 6.99-4.337ZM237.057 205.441c3.842 0 6.989 1.953 6.989 4.337v92.036c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.951-6.99-4.337v-92.036c0-2.385 3.144-4.337 6.99-4.337ZM209.281 221.302c3.846 0 6.989 1.952 6.989 4.337v80.134c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.952-6.99-4.337v-80.134c0-2.386 3.144-4.337 6.99-4.337ZM181.505 232.271c3.846 0 6.993 1.952 6.993 4.336v78.924c0 2.385-3.147 4.337-6.993 4.337-3.842 0-6.989-1.951-6.989-4.337v-78.924c0-2.385 3.147-4.336 6.989-4.336ZM153.873 241.348c3.846 0 6.989 1.953 6.989 4.337v42.009c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-42.009c0-2.385 3.147-4.337 6.99-4.337ZM125.266 200.398c3.842 0 6.989 1.953 6.989 4.337v58.55c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-58.55c0-2.385 3.144-4.337 6.99-4.337ZM96.7 204.231c3.842 0 6.989 1.953 6.989 4.337v18.004c0 2.384-3.147 4.337-6.989 4.337s-6.989-1.952-6.989-4.337v-18.004c0-2.385 3.143-4.337 6.989-4.337ZM68.089 201.81c3.846 0 6.99 1.953 6.99 4.337v8.12c0 2.384-3.147 4.336-6.99 4.336-3.842 0-6.989-1.951-6.989-4.336v-8.12c0-2.385 3.143-4.337 6.989-4.337ZM153.873 194.94c3.846 0 6.989 1.953 6.989 4.337v6.102c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-6.102c0-2.385 3.147-4.337 6.99-4.337Z" style="fill:#000;fill-rule:evenodd"></path></svg>
|
After Width: | Height: | Size: 3.4 KiB |
|
@ -0,0 +1 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="61.1 180.15 377.8 139.718"><path d="M431.911 245.181c3.842 0 6.989 1.952 6.989 4.337v14.776c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-14.776c0-2.385 3.144-4.337 6.99-4.337ZM404.135 250.955c3.846 0 6.989 1.952 6.989 4.337v32.528c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-32.528c0-2.385 3.147-4.337 6.989-4.337ZM376.363 257.688c3.842 0 6.989 1.952 6.989 4.337v36.562c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-36.562c0-2.386 3.147-4.337 6.993-4.337ZM348.587 263.26c3.846 0 6.989 1.952 6.989 4.337v36.159c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-36.159c0-2.385 3.147-4.337 6.989-4.337ZM320.811 268.177c3.846 0 6.989 1.952 6.989 4.337v31.318c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-31.318c0-2.385 3.147-4.337 6.989-4.337ZM293.179 288.148c3.846 0 6.989 1.952 6.989 4.337v9.935c0 2.384-3.147 4.336-6.989 4.336s-6.99-1.951-6.99-4.336v-9.935c0-2.386 3.144-4.337 6.99-4.337Z" style="fill:#b1b3b4;fill-rule:evenodd"></path><path d="M431.911 205.441c3.842 0 6.989 1.952 6.989 4.337v24.459c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-24.459c0-2.385 3.144-4.337 6.99-4.337ZM404.135 189.026c3.846 0 6.989 1.952 6.989 4.337v43.622c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-43.622c0-2.385 3.147-4.337 6.989-4.337ZM376.363 182.848c3.842 0 6.989 1.953 6.989 4.337v56.937c0 2.384-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-56.937c0-2.385 3.147-4.337 6.993-4.337ZM348.587 180.15c3.846 0 6.989 1.952 6.989 4.337v66.619c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-66.619c0-2.385 3.147-4.337 6.989-4.337ZM320.811 181.84c3.846 0 6.989 1.952 6.989 4.337v67.627c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-67.627c0-2.386 3.147-4.337 6.989-4.337ZM293.179 186.076c3.846 0 6.989 1.952 6.989 4.337v84.37c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.951-6.99-4.337v-84.37c0-2.386 3.144-4.337 6.99-4.337ZM264.829 193.262c3.846 0 6.989 1.953 6.989 4.337v95.667c0 2.385-3.143 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-95.667c0-2.385 3.147-4.337 6.99-4.337ZM237.057 205.441c3.842 0 6.989 1.953 6.989 4.337v92.036c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.951-6.99-4.337v-92.036c0-2.385 3.144-4.337 6.99-4.337ZM209.281 221.302c3.846 0 6.989 1.952 6.989 4.337v80.134c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.952-6.99-4.337v-80.134c0-2.386 3.144-4.337 6.99-4.337ZM181.505 232.271c3.846 0 6.993 1.952 6.993 4.336v78.924c0 2.385-3.147 4.337-6.993 4.337-3.842 0-6.989-1.951-6.989-4.337v-78.924c0-2.385 3.147-4.336 6.989-4.336ZM153.873 241.348c3.846 0 6.989 1.953 6.989 4.337v42.009c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-42.009c0-2.385 3.147-4.337 6.99-4.337ZM125.266 200.398c3.842 0 6.989 1.953 6.989 4.337v58.55c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-58.55c0-2.385 3.144-4.337 6.99-4.337ZM96.7 204.231c3.842 0 6.989 1.953 6.989 4.337v18.004c0 2.384-3.147 4.337-6.989 4.337s-6.989-1.952-6.989-4.337v-18.004c0-2.385 3.143-4.337 6.989-4.337ZM68.089 201.81c3.846 0 6.99 1.953 6.99 4.337v8.12c0 2.384-3.147 4.336-6.99 4.336-3.842 0-6.989-1.951-6.989-4.336v-8.12c0-2.385 3.143-4.337 6.989-4.337ZM153.873 194.94c3.846 0 6.989 1.953 6.989 4.337v6.102c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-6.102c0-2.385 3.147-4.337 6.99-4.337Z" style="fill:#000;fill-rule:evenodd"></path></svg>
|
After Width: | Height: | Size: 3.4 KiB |
|
@ -0,0 +1,28 @@
|
|||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FishAudioProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
|
||||
For debugging purposes, this method now always passes validation.
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.TTS)
|
||||
model_instance.validate_credentials(
|
||||
credentials=credentials
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
|
||||
raise ex
|
|
@ -0,0 +1,76 @@
|
|||
provider: fishaudio
|
||||
label:
|
||||
en_US: Fish Audio
|
||||
description:
|
||||
en_US: Models provided by Fish Audio, currently only support TTS.
|
||||
zh_Hans: Fish Audio 提供的模型,目前仅支持 TTS。
|
||||
icon_small:
|
||||
en_US: fishaudio_s_en.svg
|
||||
icon_large:
|
||||
en_US: fishaudio_l_en.svg
|
||||
background: "#E5E7EB"
|
||||
help:
|
||||
title:
|
||||
en_US: Get your API key from Fish Audio
|
||||
zh_Hans: 从 Fish Audio 获取你的 API Key
|
||||
url:
|
||||
en_US: https://fish.audio/go-api/
|
||||
supported_model_types:
|
||||
- tts
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: api_base
|
||||
label:
|
||||
en_US: API URL
|
||||
type: text-input
|
||||
required: false
|
||||
default: https://api.fish.audio
|
||||
placeholder:
|
||||
en_US: Enter your API URL
|
||||
zh_Hans: 在此输入您的 API URL
|
||||
- variable: use_public_models
|
||||
label:
|
||||
en_US: Use Public Models
|
||||
type: select
|
||||
required: false
|
||||
default: "false"
|
||||
placeholder:
|
||||
en_US: Toggle to use public models
|
||||
zh_Hans: 切换以使用公共模型
|
||||
options:
|
||||
- value: "true"
|
||||
label:
|
||||
en_US: Allow Public Models
|
||||
zh_Hans: 使用公共模型
|
||||
- value: "false"
|
||||
label:
|
||||
en_US: Private Models Only
|
||||
zh_Hans: 仅使用私有模型
|
||||
- variable: latency
|
||||
label:
|
||||
en_US: Latency
|
||||
type: select
|
||||
required: false
|
||||
default: "normal"
|
||||
placeholder:
|
||||
en_US: Toggle to choice latency
|
||||
zh_Hans: 切换以调整延迟
|
||||
options:
|
||||
- value: "balanced"
|
||||
label:
|
||||
en_US: Low (may affect quality)
|
||||
zh_Hans: 低延迟 (可能降低质量)
|
||||
- value: "normal"
|
||||
label:
|
||||
en_US: Normal
|
||||
zh_Hans: 标准
|
174
api/core/model_runtime/model_providers/fishaudio/tts/tts.py
Normal file
174
api/core/model_runtime/model_providers/fishaudio/tts/tts.py
Normal file
|
@ -0,0 +1,174 @@
|
|||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.tts_model import TTSModel
|
||||
|
||||
|
||||
class FishAudioText2SpeechModel(TTSModel):
|
||||
"""
|
||||
Model class for Fish.audio Text to Speech model.
|
||||
"""
|
||||
|
||||
def get_tts_model_voices(
|
||||
self, model: str, credentials: dict, language: Optional[str] = None
|
||||
) -> list:
|
||||
api_base = credentials.get("api_base", "https://api.fish.audio")
|
||||
api_key = credentials.get("api_key")
|
||||
use_public_models = credentials.get("use_public_models", "false") == "true"
|
||||
|
||||
params = {
|
||||
"self": str(not use_public_models).lower(),
|
||||
"page_size": "100",
|
||||
}
|
||||
|
||||
if language is not None:
|
||||
if "-" in language:
|
||||
language = language.split("-")[0]
|
||||
params["language"] = language
|
||||
|
||||
results = httpx.get(
|
||||
f"{api_base}/model",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
params=params,
|
||||
)
|
||||
|
||||
results.raise_for_status()
|
||||
data = results.json()
|
||||
|
||||
return [{"name": i["title"], "value": i["_id"]} for i in data["items"]]
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
tenant_id: str,
|
||||
credentials: dict,
|
||||
content_text: str,
|
||||
voice: str,
|
||||
user: Optional[str] = None,
|
||||
) -> any:
|
||||
"""
|
||||
Invoke text2speech model
|
||||
|
||||
:param model: model name
|
||||
:param tenant_id: user tenant id
|
||||
:param credentials: model credentials
|
||||
:param voice: model timbre
|
||||
:param content_text: text content to be translated
|
||||
:param user: unique user id
|
||||
:return: generator yielding audio chunks
|
||||
"""
|
||||
|
||||
return self._tts_invoke_streaming(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
content_text=content_text,
|
||||
voice=voice,
|
||||
)
|
||||
|
||||
def validate_credentials(
|
||||
self, credentials: dict, user: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Validate credentials for text2speech model
|
||||
|
||||
:param credentials: model credentials
|
||||
:param user: unique user id
|
||||
"""
|
||||
|
||||
try:
|
||||
self.get_tts_model_voices(
|
||||
None,
|
||||
credentials={
|
||||
"api_key": credentials["api_key"],
|
||||
"api_base": credentials["api_base"],
|
||||
# Disable public models will trigger a 403 error if user is not logged in
|
||||
"use_public_models": "false",
|
||||
},
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _tts_invoke_streaming(
|
||||
self, model: str, credentials: dict, content_text: str, voice: str
|
||||
) -> any:
|
||||
"""
|
||||
Invoke streaming text2speech model
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param voice: ID of the reference audio (if any)
|
||||
:return: generator yielding audio chunks
|
||||
"""
|
||||
|
||||
try:
|
||||
word_limit = self._get_model_word_limit(model, credentials)
|
||||
if len(content_text) > word_limit:
|
||||
sentences = self._split_text_into_sentences(
|
||||
content_text, max_length=word_limit
|
||||
)
|
||||
else:
|
||||
sentences = [content_text.strip()]
|
||||
|
||||
for i in range(len(sentences)):
|
||||
yield from self._tts_invoke_streaming_sentence(
|
||||
credentials=credentials, content_text=sentences[i], voice=voice
|
||||
)
|
||||
|
||||
except Exception as ex:
|
||||
raise InvokeBadRequestError(str(ex))
|
||||
|
||||
def _tts_invoke_streaming_sentence(
|
||||
self, credentials: dict, content_text: str, voice: Optional[str] = None
|
||||
) -> any:
|
||||
"""
|
||||
Invoke streaming text2speech model
|
||||
|
||||
:param credentials: model credentials
|
||||
:param content_text: text content to be translated
|
||||
:param voice: ID of the reference audio (if any)
|
||||
:return: generator yielding audio chunks
|
||||
"""
|
||||
api_key = credentials.get("api_key")
|
||||
api_url = credentials.get("api_base", "https://api.fish.audio")
|
||||
latency = credentials.get("latency")
|
||||
|
||||
if not api_key:
|
||||
raise InvokeBadRequestError("API key is required")
|
||||
|
||||
with httpx.stream(
|
||||
"POST",
|
||||
api_url + "/v1/tts",
|
||||
json={
|
||||
"text": content_text,
|
||||
"reference_id": voice,
|
||||
"latency": latency
|
||||
},
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
},
|
||||
timeout=None,
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
raise InvokeBadRequestError(
|
||||
f"Error: {response.status_code} - {response.text}"
|
||||
)
|
||||
yield from response.iter_bytes()
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeBadRequestError: [
|
||||
httpx.HTTPStatusError,
|
||||
],
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
model: tts-default
|
||||
model_type: tts
|
||||
model_properties:
|
||||
word_limit: 1000
|
||||
audio_type: 'mp3'
|
|
@ -416,11 +416,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
|
|||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the ermd = genai.GenerativeModel(model)ror type thrown to the caller
|
||||
The value is the md = genai.GenerativeModel(model)error type thrown by the model,
|
||||
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
|
||||
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke emd = genai.GenerativeModel(model)rror mapping
|
||||
:return: Invoke emd = genai.GenerativeModel(model) error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
|
|
|
@ -86,7 +86,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
|
|||
Calculate num tokens for minimax model
|
||||
|
||||
not like ChatGLM, Minimax has a special prompt structure, we could not find a proper way
|
||||
to caculate the num tokens, so we use str() to convert the prompt to string
|
||||
to calculate the num tokens, so we use str() to convert the prompt to string
|
||||
|
||||
Minimax does not provide their own tokenizer of adab5.5 and abab5 model
|
||||
therefore, we use gpt2 tokenizer instead
|
||||
|
|
|
@ -10,6 +10,7 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI
|
|||
class NovitaLargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
|
||||
def _update_endpoint_url(self, credentials: dict):
|
||||
|
||||
credentials['endpoint_url'] = "https://api.novita.ai/v3/openai"
|
||||
credentials['extra_headers'] = { 'X-Novita-Source': 'dify.ai' }
|
||||
return credentials
|
||||
|
|
|
@ -54,7 +54,6 @@ class NvidiaRerankModel(RerankModel):
|
|||
"query": {"text": query},
|
||||
"passages": [{"text": doc} for doc in docs],
|
||||
}
|
||||
|
||||
session = requests.Session()
|
||||
response = session.post(invoke_url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
|
@ -71,7 +70,10 @@ class NvidiaRerankModel(RerankModel):
|
|||
)
|
||||
|
||||
rerank_documents.append(rerank_document)
|
||||
|
||||
if rerank_documents:
|
||||
rerank_documents = sorted(rerank_documents, key=lambda x: x.score, reverse=True)
|
||||
if top_n:
|
||||
rerank_documents = rerank_documents[:top_n]
|
||||
return RerankResult(model=model, docs=rerank_documents)
|
||||
except requests.HTTPError as e:
|
||||
raise InvokeServerUnavailableError(str(e))
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 231 30' preserveAspectRatio='xMinYMid'><path d='M99.61,19.52h15.24l-8.05-13L92,30H85.27l18-28.17a4.29,4.29,0,0,1,7-.05L128.32,30h-6.73l-3.17-5.25H103l-3.36-5.23m69.93,5.23V0.28h-5.72V27.16a2.76,2.76,0,0,0,.85,2,2.89,2.89,0,0,0,2.08.87h26l3.39-5.25H169.54M75,20.38A10,10,0,0,0,75,.28H50V30h5.71V5.54H74.65a4.81,4.81,0,0,1,0,9.62H58.54L75.6,30h8.29L72.43,20.38H75M14.88,30H32.15a14.86,14.86,0,0,0,0-29.71H14.88a14.86,14.86,0,1,0,0,29.71m16.88-5.23H15.26a9.62,9.62,0,0,1,0-19.23h16.5a9.62,9.62,0,1,1,0,19.23M140.25,30h17.63l3.34-5.23H140.64a9.62,9.62,0,1,1,0-19.23h16.75l3.38-5.25H140.25a14.86,14.86,0,1,0,0,29.71m69.87-5.23a9.62,9.62,0,0,1-9.26-7h24.42l3.36-5.24H200.86a9.61,9.61,0,0,1,9.26-7h16.76l3.35-5.25h-20.5a14.86,14.86,0,0,0,0,29.71h17.63l3.35-5.23h-20.6' transform='translate(-0.02 0)' style='fill:#C74634'/></svg>
|
After Width: | Height: | Size: 874 B |
|
@ -0,0 +1 @@
|
|||
<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 231 30' preserveAspectRatio='xMinYMid'><path d='M99.61,19.52h15.24l-8.05-13L92,30H85.27l18-28.17a4.29,4.29,0,0,1,7-.05L128.32,30h-6.73l-3.17-5.25H103l-3.36-5.23m69.93,5.23V0.28h-5.72V27.16a2.76,2.76,0,0,0,.85,2,2.89,2.89,0,0,0,2.08.87h26l3.39-5.25H169.54M75,20.38A10,10,0,0,0,75,.28H50V30h5.71V5.54H74.65a4.81,4.81,0,0,1,0,9.62H58.54L75.6,30h8.29L72.43,20.38H75M14.88,30H32.15a14.86,14.86,0,0,0,0-29.71H14.88a14.86,14.86,0,1,0,0,29.71m16.88-5.23H15.26a9.62,9.62,0,0,1,0-19.23h16.5a9.62,9.62,0,1,1,0,19.23M140.25,30h17.63l3.34-5.23H140.64a9.62,9.62,0,1,1,0-19.23h16.75l3.38-5.25H140.25a14.86,14.86,0,1,0,0,29.71m69.87-5.23a9.62,9.62,0,0,1-9.26-7h24.42l3.36-5.24H200.86a9.61,9.61,0,0,1,9.26-7h16.76l3.35-5.25h-20.5a14.86,14.86,0,0,0,0,29.71h17.63l3.35-5.23h-20.6' transform='translate(-0.02 0)' style='fill:#C74634'/></svg>
|
After Width: | Height: | Size: 874 B |
|
@ -0,0 +1,52 @@
|
|||
model: cohere.command-r-16k
|
||||
label:
|
||||
en_US: cohere.command-r-16k v1.2
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 128000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 1
|
||||
max: 1.0
|
||||
- name: topP
|
||||
use_template: top_p
|
||||
default: 0.75
|
||||
min: 0
|
||||
max: 1
|
||||
- name: topK
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
default: 0
|
||||
min: 0
|
||||
max: 500
|
||||
- name: presencePenalty
|
||||
use_template: presence_penalty
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0
|
||||
- name: frequencyPenalty
|
||||
use_template: frequency_penalty
|
||||
min: 0
|
||||
max: 1
|
||||
default: 0
|
||||
- name: maxTokens
|
||||
use_template: max_tokens
|
||||
default: 600
|
||||
max: 4000
|
||||
pricing:
|
||||
input: '0.004'
|
||||
output: '0.004'
|
||||
unit: '0.0001'
|
||||
currency: USD
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user