Merge branch 'fix/chore-fix' into dev/plugin-deploy

This commit is contained in:
Yeuoly 2024-11-01 16:46:40 +08:00
commit 7cea6c1713
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61
215 changed files with 2741 additions and 6645 deletions

View File

@ -1,3 +1,3 @@
#!/bin/bash
poetry install -C api
cd api && poetry install

View File

@ -50,7 +50,7 @@ jobs:
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Login to Docker Hub
uses: docker/login-action@v2
uses: docker/login-action@v3
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
@ -115,7 +115,7 @@ jobs:
merge-multiple: true
- name: Login to Docker Hub
uses: docker/login-action@v2
uses: docker/login-action@v3
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}

153
README.md
View File

@ -46,6 +46,56 @@
</p>
## Table of Content
0. [Quick-Start🚀](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start)
1. [Intro📖](https://github.com/langgenius/dify?tab=readme-ov-file#intro)
2. [How to use🔧](https://github.com/langgenius/dify?tab=readme-ov-file#using-dify)
3. [Stay Ahead🏃](https://github.com/langgenius/dify?tab=readme-ov-file#staying-ahead)
4. [Next Steps🏹](https://github.com/langgenius/dify?tab=readme-ov-file#next-steps)
5. [Contributing💪](https://github.com/langgenius/dify?tab=readme-ov-file#contributing)
6. [Community and Contact🏠](https://github.com/langgenius/dify?tab=readme-ov-file#community--contact)
7. [Star-History📈](https://github.com/langgenius/dify?tab=readme-ov-file#star-history)
8. [Security🔒](https://github.com/langgenius/dify?tab=readme-ov-file#security-disclosure)
9. [License🤝](https://github.com/langgenius/dify?tab=readme-ov-file#license)
> Make sure you read through this README before you start utilizing Dify😊
## Quick start
The quickest way to deploy Dify locally is to run our [docker-compose.yml](https://github.com/langgenius/dify/blob/main/docker/docker-compose.yaml). Follow the instructions to start in 5 minutes.
> Before installing Dify, make sure your machine meets the following minimum system requirements:
>
>- CPU >= 2 Core
>- RAM >= 4 GiB
>- Docker and Docker Compose Installed
</br>
Run the following command in your terminal to clone the whole repo.
```bash
git clone https://github.com/langgenius/dify.git
```
After cloning,run the following command one by one.
```bash
cd dify
cd docker
cp .env.example .env
docker compose up -d
```
After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process. You will be asked to setup an admin account.
For more info of quick setup, check [here](https://docs.dify.ai/getting-started/install-self-hosted/docker-compose)
## Intro
Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features:
</br> </br>
@ -79,73 +129,6 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic.
## Feature comparison
<table style="width: 100%;">
<tr>
<th align="center">Feature</th>
<th align="center">Dify.AI</th>
<th align="center">LangChain</th>
<th align="center">Flowise</th>
<th align="center">OpenAI Assistants API</th>
</tr>
<tr>
<td align="center">Programming Approach</td>
<td align="center">API + App-oriented</td>
<td align="center">Python Code</td>
<td align="center">App-oriented</td>
<td align="center">API-oriented</td>
</tr>
<tr>
<td align="center">Supported LLMs</td>
<td align="center">Rich Variety</td>
<td align="center">Rich Variety</td>
<td align="center">Rich Variety</td>
<td align="center">OpenAI-only</td>
</tr>
<tr>
<td align="center">RAG Engine</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Agent</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Workflow</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Observability</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Enterprise Features (SSO/Access control)</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
<tr>
<td align="center">Local Deployment</td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
<td align="center"></td>
</tr>
</table>
## Using Dify
- **Cloud </br>**
@ -166,30 +149,21 @@ Star Dify on GitHub and be instantly notified of new releases.
![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4)
## Quick start
> Before installing Dify, make sure your machine meets the following minimum system requirements:
>
>- CPU >= 2 Core
>- RAM >= 4 GiB
</br>
The easiest way to start the Dify server is to run our [docker-compose.yml](docker/docker-compose.yaml) file. Before running the installation command, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
```bash
cd docker
cp .env.example .env
docker compose up -d
```
After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process.
> If you'd like to contribute to Dify or do additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code)
## Next steps
Go to [quick-start](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start) to setup your Dify or setup by source code.
#### If you......
If you forget your admin account, you can refer to this [guide](https://docs.dify.ai/getting-started/install-self-hosted/faqs#id-4.-how-to-reset-the-password-of-the-admin-account) to reset the password.
> Use docker compose up without "-d" to enable logs printing out in your terminal. This might be useful if you have encountered unknow problems when using Dify.
If you encountered system error and would like to acquire help in Github issues, make sure you always paste logs of the error in the request to accerate the conversation. Go to [Community & contact](https://github.com/langgenius/dify?tab=readme-ov-file#community--contact) for more information.
> Please read the [Dify Documentation](https://docs.dify.ai/) for detailed how-to-use guidance. Most of the potential problems are explained in the doc.
> If you'd like to contribute to Dify or make additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code)
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes.
@ -228,6 +202,7 @@ At the same time, please consider supporting Dify by sharing it on social media
* [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
* [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community.
* [X(Twitter)](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community.
* Make sure a log, if possible, is attached to an error reported to maximize solution efficiency.
## Star history

View File

@ -55,7 +55,12 @@ RUN apt-get update \
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
&& apt-get update \
# For Security
&& apt-get install -y --no-install-recommends libsqlite3-0=3.46.1-1 \
&& apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 \
&& if [ "$(dpkg --print-architecture)" = "amd64" ]; then \
apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1+b1; \
else \
apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1; \
fi \
# install a chinese font to support the use of tools like matplotlib
&& apt-get install -y fonts-noto-cjk \
&& apt-get autoremove -y \

View File

@ -10,7 +10,6 @@ from pydantic import (
PositiveInt,
computed_field,
)
from pydantic_extra_types.timezone_name import TimeZoneName
from pydantic_settings import BaseSettings
from configs.feature.hosted_service import HostedServiceConfig
@ -393,9 +392,8 @@ class LoggingConfig(BaseSettings):
default=None,
)
LOG_TZ: Optional[TimeZoneName] = Field(
description="Timezone for log timestamps. Allowed timezone values can be referred to IANA Time Zone Database,"
" e.g., 'America/New_York')",
LOG_TZ: Optional[str] = Field(
description="Timezone for log timestamps (e.g., 'America/New_York')",
default=None,
)

View File

@ -16,6 +16,7 @@ from configs.middleware.storage.supabase_storage_config import SupabaseStorageCo
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.baidu_vector_config import BaiduVectorDBConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.couchbase_config import CouchbaseConfig
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
@ -259,5 +260,6 @@ class MiddlewareConfig(
UpstashConfig,
TidbOnQdrantConfig,
OceanBaseVectorConfig,
BaiduVectorDBConfig,
):
pass

View File

@ -0,0 +1,6 @@
from werkzeug.exceptions import HTTPException
class FilenameNotExistsError(HTTPException):
code = 400
description = "The specified filename does not exist."

View File

@ -0,0 +1,58 @@
import mimetypes
import os
import re
import urllib.parse
from uuid import uuid4
import httpx
from pydantic import BaseModel
class FileInfo(BaseModel):
filename: str
extension: str
mimetype: str
size: int
def guess_file_info_from_response(response: httpx.Response):
url = str(response.url)
# Try to extract filename from URL
parsed_url = urllib.parse.urlparse(url)
url_path = parsed_url.path
filename = os.path.basename(url_path)
# If filename couldn't be extracted, use Content-Disposition header
if not filename:
content_disposition = response.headers.get("Content-Disposition")
if content_disposition:
filename_match = re.search(r'filename="?(.+)"?', content_disposition)
if filename_match:
filename = filename_match.group(1)
# If still no filename, generate a unique one
if not filename:
unique_name = str(uuid4())
filename = f"{unique_name}"
# Guess MIME type from filename first, then URL
mimetype, _ = mimetypes.guess_type(filename)
if mimetype is None:
mimetype, _ = mimetypes.guess_type(url)
if mimetype is None:
# If guessing fails, use Content-Type from response headers
mimetype = response.headers.get("Content-Type", "application/octet-stream")
extension = os.path.splitext(filename)[1]
# Ensure filename has an extension
if not extension:
extension = mimetypes.guess_extension(mimetype) or ".bin"
filename = f"{filename}{extension}"
return FileInfo(
filename=filename,
extension=extension,
mimetype=mimetype,
size=int(response.headers.get("Content-Length", -1)),
)

View File

@ -2,9 +2,21 @@ from flask import Blueprint
from libs.external_api import ExternalApi
from .files import FileApi, FilePreviewApi, FileSupportTypeApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
bp = Blueprint("console", __name__, url_prefix="/console/api")
api = ExternalApi(bp)
# File
api.add_resource(FileApi, "/files/upload")
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
api.add_resource(FileSupportTypeApi, "/files/support-type")
# Remote files
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
# Import other controllers
from . import admin, apikey, extension, feature, ping, setup, version
@ -43,7 +55,6 @@ from .datasets import (
datasets_document,
datasets_segments,
external,
file,
hit_testing,
website,
)

View File

@ -12,8 +12,7 @@ from models.dataset import Dataset
from models.model import ApiToken, App
from . import api
from .setup import setup_required
from .wraps import account_initialization_required
from .wraps import account_initialization_required, setup_required
api_key_fields = {
"id": fields.String,

View File

@ -1,8 +1,7 @@
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService

View File

@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from libs.helper import uuid_value
from libs.login import login_required
from models.model import AppMode

View File

@ -6,8 +6,11 @@ from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.app.error import NoFileUploadedError
from controllers.console.datasets.error import TooManyFilesError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
)
from extensions.ext_redis import redis_client
from fields.annotation_fields import (
annotation_fields,

View File

@ -6,8 +6,11 @@ from werkzeug.exceptions import BadRequest, Forbidden, abort
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
)
from core.ops.ops_trace_manager import OpsTraceManager
from fields.app_fields import (
app_detail_fields,

View File

@ -18,8 +18,7 @@ from controllers.console.app.error import (
UnsupportedAudioTypeError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required

View File

@ -15,8 +15,7 @@ from controllers.console.app.error import (
ProviderQuotaExceededError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom

View File

@ -10,8 +10,7 @@ from werkzeug.exceptions import Forbidden, NotFound
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (

View File

@ -4,8 +4,7 @@ from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.conversation_variable_fields import paginated_conversation_variable_fields
from libs.login import login_required

View File

@ -10,8 +10,7 @@ from controllers.console.app.error import (
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError

View File

@ -14,8 +14,11 @@ from controllers.console.app.error import (
)
from controllers.console.app.wraps import get_app_model
from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
)
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError

View File

@ -6,8 +6,7 @@ from flask_restful import Resource
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.agent.entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager

View File

@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.app.error import TracingConfigCheckError, TracingConfigIsExist, TracingConfigNotExist
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.ops_service import OpsService

View File

@ -7,8 +7,7 @@ from werkzeug.exceptions import Forbidden, NotFound
from constants.languages import supported_language
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.app_fields import app_site_fields
from libs.login import login_required

View File

@ -8,8 +8,7 @@ from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.login import login_required

View File

@ -9,8 +9,7 @@ import services
from controllers.console import api
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from factories import variable_factory

View File

@ -3,8 +3,7 @@ from flask_restful.inputs import int_range
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs.login import login_required
from models import App

View File

@ -3,8 +3,7 @@ from flask_restful.inputs import int_range
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from fields.workflow_run_fields import (
advanced_chat_workflow_run_pagination_fields,
workflow_run_detail_fields,

View File

@ -8,8 +8,7 @@ from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.login import login_required

View File

@ -7,8 +7,7 @@ from controllers.console.auth.error import ApiKeyAuthFailedError
from libs.login import login_required
from services.auth.api_key_auth_service import ApiKeyAuthService
from ..setup import setup_required
from ..wraps import account_initialization_required
from ..wraps import account_initialization_required, setup_required
class ApiKeyAuthDataSource(Resource):

View File

@ -11,8 +11,7 @@ from controllers.console import api
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth
from ..setup import setup_required
from ..wraps import account_initialization_required
from ..wraps import account_initialization_required, setup_required
def get_oauth_providers():

View File

@ -15,7 +15,7 @@ from controllers.console.auth.error import (
PasswordMismatchError,
)
from controllers.console.error import EmailSendIpLimitError, NotAllowedRegister
from controllers.console.setup import setup_required
from controllers.console.wraps import setup_required
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip

View File

@ -20,7 +20,7 @@ from controllers.console.error import (
NotAllowedCreateWorkspace,
NotAllowedRegister,
)
from controllers.console.setup import setup_required
from controllers.console.wraps import setup_required
from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip
from libs.password import valid_password

View File

@ -2,8 +2,7 @@ from flask_login import current_user
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, only_edition_cloud
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import login_required
from services.billing_service import BillingService

View File

@ -9,8 +9,7 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor

View File

@ -10,8 +10,7 @@ from controllers.console import api
from controllers.console.apikey import api_key_fields, api_key_list
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType

View File

@ -23,8 +23,11 @@ from controllers.console.datasets.error import (
InvalidActionError,
InvalidMetadataError,
)
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
)
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,

View File

@ -11,11 +11,11 @@ import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
from controllers.console.setup import setup_required
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_knowledge_limit_check,
cloud_edition_billing_resource_check,
setup_required,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager

View File

@ -6,8 +6,7 @@ from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import api
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from services.dataset_service import DatasetService

View File

@ -2,8 +2,7 @@ from flask_restful import Resource
from controllers.console import api
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required

View File

@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.website_service import WebsiteService

View File

@ -3,8 +3,7 @@ from flask_restful import Resource, marshal_with, reqparse
from constants import HIDDEN_VALUE
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from fields.api_based_extension_fields import api_based_extension_fields
from libs.login import login_required
from models.api_based_extension import APIBasedExtension

View File

@ -5,8 +5,7 @@ from libs.login import login_required
from services.feature_service import FeatureService
from . import api
from .setup import setup_required
from .wraps import account_initialization_required, cloud_utm_record
from .wraps import account_initialization_required, cloud_utm_record, setup_required
class FeatureApi(Resource):

View File

@ -1,25 +1,26 @@
import urllib.parse
from flask import request
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from flask_restful import Resource, marshal_with
import services
from configs import dify_config
from constants import DOCUMENT_EXTENSIONS
from controllers.console import api
from controllers.console.datasets.error import (
from controllers.common.errors import FilenameNotExistsError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
)
from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required
from services.file_service import FileService
from .errors import (
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.helper import ssrf_proxy
from fields.file_fields import file_fields, remote_file_info_fields, upload_config_fields
from libs.login import login_required
from services.file_service import FileService
PREVIEW_WORDS_LIMIT = 3000
@ -44,21 +45,29 @@ class FileApi(Resource):
@marshal_with(file_fields)
@cloud_edition_billing_resource_check("documents")
def post(self):
# get file from request
file = request.files["file"]
source = request.form.get("source")
parser = reqparse.RequestParser()
parser.add_argument("source", type=str, required=False, location="args")
source = parser.parse_args().get("source")
# check file
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
if not file.filename:
raise FilenameNotExistsError
if source not in ("datasets", None):
source = None
try:
upload_file = FileService.upload_file(file=file, user=current_user, source=source)
upload_file = FileService.upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
user=current_user,
source=source,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
@ -83,23 +92,3 @@ class FileSupportTypeApi(Resource):
@account_initialization_required
def get(self):
return {"allowed_extensions": DOCUMENT_EXTENSIONS}
class RemoteFileInfoApi(Resource):
@marshal_with(remote_file_info_fields)
def get(self, url):
decoded_url = urllib.parse.unquote(url)
try:
response = ssrf_proxy.head(decoded_url)
return {
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
"file_length": int(response.headers.get("Content-Length", 0)),
}
except Exception as e:
return {"error": str(e)}, 400
api.add_resource(FileApi, "/files/upload")
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
api.add_resource(FileSupportTypeApi, "/files/support-type")
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")

View File

@ -0,0 +1,25 @@
from libs.exception import BaseHTTPException
class FileTooLargeError(BaseHTTPException):
error_code = "file_too_large"
description = "File size exceeded. {message}"
code = 413
class UnsupportedFileTypeError(BaseHTTPException):
error_code = "unsupported_file_type"
description = "File type not allowed."
code = 415
class TooManyFilesError(BaseHTTPException):
error_code = "too_many_files"
description = "Only one file is allowed."
code = 400
class NoFileUploadedError(BaseHTTPException):
error_code = "no_file_uploaded"
description = "Please upload your file."
code = 400

View File

@ -0,0 +1,71 @@
import urllib.parse
from typing import cast
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from controllers.common import helpers
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from models.account import Account
from services.file_service import FileService
class RemoteFileInfoApi(Resource):
@marshal_with(remote_file_info_fields)
def get(self, url):
decoded_url = urllib.parse.unquote(url)
try:
response = ssrf_proxy.head(decoded_url)
return {
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
"file_length": int(response.headers.get("Content-Length", 0)),
}
except Exception as e:
return {"error": str(e)}, 400
class RemoteFileUploadApi(Resource):
@marshal_with(file_fields_with_signed_url)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("url", type=str, required=True, help="URL is required")
args = parser.parse_args()
url = args["url"]
response = ssrf_proxy.head(url)
response.raise_for_status()
file_info = helpers.guess_file_info_from_response(response)
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
return {"error": "File size exceeded"}, 400
response = ssrf_proxy.get(url)
response.raise_for_status()
content = response.content
try:
user = cast(Account, current_user)
upload_file = FileService.upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=user,
source_url=url,
)
except Exception as e:
return {"error": str(e)}, 400
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at,
}, 201

View File

@ -1,5 +1,3 @@
from functools import wraps
from flask import request
from flask_restful import Resource, reqparse
@ -10,7 +8,7 @@ from models.model import DifySetup, db
from services.account_service import RegisterService, TenantService
from . import api
from .error import AlreadySetupError, NotInitValidateError, NotSetupError
from .error import AlreadySetupError, NotInitValidateError
from .init_validate import get_init_validate_status
from .wraps import only_edition_self_hosted
@ -52,21 +50,6 @@ class SetupApi(Resource):
return {"result": "success"}, 201
def setup_required(view):
@wraps(view)
def decorated(*args, **kwargs):
# check setup
if not get_init_validate_status():
raise NotInitValidateError()
elif not get_setup_status():
raise NotSetupError()
return view(*args, **kwargs)
return decorated
def get_setup_status():
if dify_config.EDITION == "SELF_HOSTED":
return db.session.query(DifySetup).first()

View File

@ -4,8 +4,7 @@ from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from fields.tag_fields import tag_fields
from libs.login import login_required
from models.model import Tag

View File

@ -3,6 +3,7 @@ import logging
import requests
from flask_restful import Resource, reqparse
from packaging import version
from configs import dify_config
@ -47,43 +48,15 @@ class VersionApi(Resource):
def _has_new_version(*, latest_version: str, current_version: str) -> bool:
def parse_version(version: str) -> tuple:
# Split version into parts and pre-release suffix if any
parts = version.split("-")
version_parts = parts[0].split(".")
pre_release = parts[1] if len(parts) > 1 else None
try:
latest = version.parse(latest_version)
current = version.parse(current_version)
# Validate version format
if len(version_parts) != 3:
raise ValueError(f"Invalid version format: {version}")
try:
# Convert version parts to integers
major, minor, patch = map(int, version_parts)
return (major, minor, patch, pre_release)
except ValueError:
raise ValueError(f"Invalid version format: {version}")
latest = parse_version(latest_version)
current = parse_version(current_version)
# Compare major, minor, and patch versions
for latest_part, current_part in zip(latest[:3], current[:3]):
if latest_part > current_part:
return True
elif latest_part < current_part:
return False
# If versions are equal, check pre-release suffixes
if latest[3] is None and current[3] is not None:
return True
elif latest[3] is not None and current[3] is None:
# Compare versions
return latest > current
except version.InvalidVersion:
logging.warning(f"Invalid version format: latest={latest_version}, current={current_version}")
return False
elif latest[3] is not None and current[3] is not None:
# Simple string comparison for pre-release versions
return latest[3] > current[3]
return False
api.add_resource(VersionApi, "/version")

View File

@ -8,14 +8,13 @@ from flask_restful import Resource, fields, marshal_with, reqparse
from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.workspace.error import (
AccountAlreadyInitedError,
CurrentPasswordIncorrectError,
InvalidInvitationCodeError,
RepeatPasswordNotMatchError,
)
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.helper import TimestampField, timezone

View File

@ -3,8 +3,7 @@ from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from services.plugin.endpoint_service import EndpointService

View File

@ -2,8 +2,7 @@ from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from libs.login import current_user, login_required

View File

@ -4,8 +4,11 @@ from flask_restful import Resource, abort, marshal_with, reqparse
import services
from configs import dify_config
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
)
from extensions.ext_database import db
from fields.member_fields import account_with_role_list_fields
from libs.login import login_required

View File

@ -6,8 +6,7 @@ from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder

View File

@ -5,8 +5,7 @@ from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder

View File

@ -7,9 +7,8 @@ from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.login import login_required
from models.account import TenantPluginPermission

View File

@ -7,8 +7,7 @@ from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from libs.helper import alphanumeric, uuid_value
from libs.login import login_required

View File

@ -6,6 +6,7 @@ from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqpa
from werkzeug.exceptions import Unauthorized
import services
from controllers.common.errors import FilenameNotExistsError
from controllers.console import api
from controllers.console.admin import admin_required
from controllers.console.datasets.error import (
@ -15,8 +16,11 @@ from controllers.console.datasets.error import (
UnsupportedFileTypeError,
)
from controllers.console.error import AccountNotLinkTenantError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
setup_required,
)
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import login_required
@ -193,12 +197,20 @@ class WebappLogoWorkspaceApi(Resource):
if len(request.files) > 1:
raise TooManyFilesError()
if not file.filename:
raise FilenameNotExistsError
extension = file.filename.split(".")[-1]
if extension.lower() not in {"svg", "png"}:
raise UnsupportedFileTypeError()
try:
upload_file = FileService.upload_file(file=file, user=current_user)
upload_file = FileService.upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
user=current_user,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)

View File

@ -1,4 +1,5 @@
import json
import os
from functools import wraps
from flask import abort, request
@ -6,9 +7,13 @@ from flask_login import current_user
from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db
from models.model import DifySetup
from services.feature_service import FeatureService
from services.operation_service import OperationService
from .error import NotInitValidateError, NotSetupError
def account_initialization_required(view):
@wraps(view)
@ -124,3 +129,21 @@ def cloud_utm_record(view):
return view(*args, **kwargs)
return decorated
def setup_required(view):
@wraps(view)
def decorated(*args, **kwargs):
# check setup
if (
dify_config.EDITION == "SELF_HOSTED"
and os.environ.get("INIT_PASSWORD")
and not db.session.query(DifySetup).first()
):
raise NotInitValidateError()
elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first():
raise NotSetupError()
return view(*args, **kwargs)
return decorated

View File

@ -1,6 +1,6 @@
from flask_restful import Resource
from controllers.console.setup import setup_required
from controllers.console.wraps import setup_required
from controllers.inner_api import api
from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
from controllers.inner_api.wraps import plugin_inner_api_only

View File

@ -1,6 +1,6 @@
from flask_restful import Resource, reqparse
from controllers.console.setup import setup_required
from controllers.console.wraps import setup_required
from controllers.inner_api import api
from controllers.inner_api.wraps import enterprise_inner_api_only
from events.tenant_event import tenant_was_created

View File

@ -2,6 +2,7 @@ from flask import request
from flask_restful import Resource, marshal_with
import services
from controllers.common.errors import FilenameNotExistsError
from controllers.service_api import api
from controllers.service_api.app.error import (
FileTooLargeError,
@ -31,8 +32,17 @@ class FileApi(Resource):
if len(request.files) > 1:
raise TooManyFilesError()
if not file.filename:
raise FilenameNotExistsError
try:
upload_file = FileService.upload_file(file, end_user)
upload_file = FileService.upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
user=end_user,
source="datasets",
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:

View File

@ -6,6 +6,7 @@ from sqlalchemy import desc
from werkzeug.exceptions import NotFound
import services.dataset_service
from controllers.common.errors import FilenameNotExistsError
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import (
@ -55,7 +56,12 @@ class DocumentAddByTextApi(DatasetApiResource):
if not dataset.indexing_technique and not args["indexing_technique"]:
raise ValueError("indexing_technique is required.")
upload_file = FileService.upload_text(args.get("text"), args.get("name"))
text = args.get("text")
name = args.get("name")
if text is None or name is None:
raise ValueError("Both 'text' and 'name' must be non-null values.")
upload_file = FileService.upload_text(text=str(text), text_name=str(name))
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
@ -104,7 +110,11 @@ class DocumentUpdateByTextApi(DatasetApiResource):
raise ValueError("Dataset is not exist.")
if args["text"]:
upload_file = FileService.upload_text(args.get("text"), args.get("name"))
text = args.get("text")
name = args.get("name")
if text is None or name is None:
raise ValueError("Both text and name must be strings.")
upload_file = FileService.upload_text(text=str(text), text_name=str(name))
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
@ -163,7 +173,16 @@ class DocumentAddByFileApi(DatasetApiResource):
if len(request.files) > 1:
raise TooManyFilesError()
upload_file = FileService.upload_file(file, current_user)
if not file.filename:
raise FilenameNotExistsError
upload_file = FileService.upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
user=current_user,
source="datasets",
)
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
args["data_source"] = data_source
# validate args
@ -212,7 +231,16 @@ class DocumentUpdateByFileApi(DatasetApiResource):
if len(request.files) > 1:
raise TooManyFilesError()
upload_file = FileService.upload_file(file, current_user)
if not file.filename:
raise FilenameNotExistsError
upload_file = FileService.upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
user=current_user,
source="datasets",
)
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
args["data_source"] = data_source
# validate args
@ -331,10 +359,26 @@ class DocumentIndexingStatusApi(DatasetApiResource):
return data
api.add_resource(DocumentAddByTextApi, "/datasets/<uuid:dataset_id>/document/create_by_text")
api.add_resource(DocumentAddByFileApi, "/datasets/<uuid:dataset_id>/document/create_by_file")
api.add_resource(DocumentUpdateByTextApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text")
api.add_resource(DocumentUpdateByFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file")
api.add_resource(
DocumentAddByTextApi,
"/datasets/<uuid:dataset_id>/document/create_by_text",
"/datasets/<uuid:dataset_id>/document/create-by-text",
)
api.add_resource(
DocumentAddByFileApi,
"/datasets/<uuid:dataset_id>/document/create_by_file",
"/datasets/<uuid:dataset_id>/document/create-by-file",
)
api.add_resource(
DocumentUpdateByTextApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_text",
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-text",
)
api.add_resource(
DocumentUpdateByFileApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update_by_file",
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/update-by-file",
)
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents")
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")

View File

@ -14,4 +14,4 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
return self.perform_hit_testing(dataset, args)
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")

View File

@ -2,8 +2,17 @@ from flask import Blueprint
from libs.external_api import ExternalApi
from .files import FileApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
bp = Blueprint("web", __name__, url_prefix="/api")
api = ExternalApi(bp)
# Files
api.add_resource(FileApi, "/files/upload")
from . import app, audio, completion, conversation, feature, file, message, passport, saved_message, site, workflow
# Remote files
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
from . import app, audio, completion, conversation, feature, message, passport, saved_message, site, workflow

View File

@ -1,56 +0,0 @@
import urllib.parse
from flask import request
from flask_restful import marshal_with, reqparse
import services
from controllers.web import api
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
from controllers.web.wraps import WebApiResource
from core.helper import ssrf_proxy
from fields.file_fields import file_fields, remote_file_info_fields
from services.file_service import FileService
class FileApi(WebApiResource):
@marshal_with(file_fields)
def post(self, app_model, end_user):
# get file from request
file = request.files["file"]
parser = reqparse.RequestParser()
parser.add_argument("source", type=str, required=False, location="args")
source = parser.parse_args().get("source")
# check file
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
try:
upload_file = FileService.upload_file(file=file, user=end_user, source=source)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return upload_file, 201
class RemoteFileInfoApi(WebApiResource):
@marshal_with(remote_file_info_fields)
def get(self, url):
decoded_url = urllib.parse.unquote(url)
try:
response = ssrf_proxy.head(decoded_url)
return {
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
"file_length": int(response.headers.get("Content-Length", -1)),
}
except Exception as e:
return {"error": str(e)}, 400
api.add_resource(FileApi, "/files/upload")
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")

View File

@ -0,0 +1,43 @@
from flask import request
from flask_restful import marshal_with
import services
from controllers.common.errors import FilenameNotExistsError
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
from controllers.web.wraps import WebApiResource
from fields.file_fields import file_fields
from services.file_service import FileService
class FileApi(WebApiResource):
@marshal_with(file_fields)
def post(self, app_model, end_user):
file = request.files["file"]
source = request.form.get("source")
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
if not file.filename:
raise FilenameNotExistsError
if source not in ("datasets", None):
source = None
try:
upload_file = FileService.upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
user=end_user,
source=source,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return upload_file, 201

View File

@ -0,0 +1,69 @@
import urllib.parse
from flask_login import current_user
from flask_restful import marshal_with, reqparse
from controllers.common import helpers
from controllers.web.wraps import WebApiResource
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from services.file_service import FileService
class RemoteFileInfoApi(WebApiResource):
@marshal_with(remote_file_info_fields)
def get(self, url):
decoded_url = urllib.parse.unquote(url)
try:
response = ssrf_proxy.head(decoded_url)
return {
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
"file_length": int(response.headers.get("Content-Length", -1)),
}
except Exception as e:
return {"error": str(e)}, 400
class RemoteFileUploadApi(WebApiResource):
@marshal_with(file_fields_with_signed_url)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("url", type=str, required=True, help="URL is required")
args = parser.parse_args()
url = args["url"]
response = ssrf_proxy.head(url)
response.raise_for_status()
file_info = helpers.guess_file_info_from_response(response)
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
return {"error": "File size exceeded"}, 400
response = ssrf_proxy.get(url)
response.raise_for_status()
content = response.content
try:
upload_file = FileService.upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=current_user,
source_url=url,
)
except Exception as e:
return {"error": str(e)}, 400
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"url": file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
"mime_type": upload_file.mime_type,
"created_by": upload_file.created_by,
"created_at": upload_file.created_at,
}, 201

View File

@ -17,6 +17,7 @@ from core.errors.error import ProviderTokenNotInitError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
@ -597,26 +598,9 @@ class IndexingRunner:
rules = DatasetProcessRule.AUTOMATIC_RULES
else:
rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
document_text = CleanProcessor.clean(text, rules)
if "pre_processing_rules" in rules:
pre_processing_rules = rules["pre_processing_rules"]
for pre_processing_rule in pre_processing_rules:
if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
# Remove extra spaces
pattern = r"\n{3,}"
text = re.sub(pattern, "\n\n", text)
pattern = r"[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}"
text = re.sub(pattern, " ", text)
elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
# Remove email
pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
text = re.sub(pattern, "", text)
# Remove URL
pattern = r"https?://[^\s]+"
text = re.sub(pattern, "", text)
return text
return document_text
@staticmethod
def format_split_text(text):

View File

@ -1,9 +0,0 @@
- claude-3-5-sonnet-20241022
- claude-3-5-sonnet-20240620
- claude-3-haiku-20240307
- claude-3-opus-20240229
- claude-3-sonnet-20240229
- claude-2.1
- claude-instant-1.2
- claude-2
- claude-instant-1

View File

@ -1,39 +0,0 @@
model: claude-3-5-sonnet-20241022
label:
en_US: claude-3-5-sonnet-20241022
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 200000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
- name: response_format
use_template: response_format
pricing:
input: '3.00'
output: '15.00'
unit: '0.000001'
currency: USD

View File

@ -1,245 +0,0 @@
provider: azure_openai
label:
en_US: Azure OpenAI Service Model
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.png
background: "#E3F0FF"
help:
title:
en_US: Get your API key from Azure
zh_Hans: 从 Azure 获取 API Key
url:
en_US: https://azure.microsoft.com/en-us/products/ai-services/openai-service
supported_model_types:
- llm
- text-embedding
- speech2text
- tts
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Deployment Name
zh_Hans: 部署名称
placeholder:
en_US: Enter your Deployment Name here, matching the Azure deployment name.
zh_Hans: 在此输入您的部署名称,与 Azure 部署名称匹配。
credential_form_schemas:
- variable: openai_api_base
label:
en_US: API Endpoint URL
zh_Hans: API 域名
type: text-input
required: true
placeholder:
zh_Hans: '在此输入您的 API 域名https://example.com/xxx'
en_US: 'Enter your API Endpoint, eg: https://example.com/xxx'
- variable: openai_api_key
label:
en_US: API Key
zh_Hans: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API key here
- variable: openai_api_version
label:
zh_Hans: API 版本
en_US: API Version
type: select
required: true
options:
- label:
en_US: 2024-10-01-preview
value: 2024-10-01-preview
- label:
en_US: 2024-09-01-preview
value: 2024-09-01-preview
- label:
en_US: 2024-08-01-preview
value: 2024-08-01-preview
- label:
en_US: 2024-07-01-preview
value: 2024-07-01-preview
- label:
en_US: 2024-05-01-preview
value: 2024-05-01-preview
- label:
en_US: 2024-04-01-preview
value: 2024-04-01-preview
- label:
en_US: 2024-03-01-preview
value: 2024-03-01-preview
- label:
en_US: 2024-02-15-preview
value: 2024-02-15-preview
- label:
en_US: 2023-12-01-preview
value: 2023-12-01-preview
- label:
en_US: '2024-02-01'
value: '2024-02-01'
- label:
en_US: '2024-06-01'
value: '2024-06-01'
placeholder:
zh_Hans: 在此选择您的 API 版本
en_US: Select your API Version here
- variable: base_model_name
label:
en_US: Base Model
zh_Hans: 基础模型
type: select
required: true
options:
- label:
en_US: gpt-35-turbo
value: gpt-35-turbo
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-35-turbo-0125
value: gpt-35-turbo-0125
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-35-turbo-16k
value: gpt-35-turbo-16k
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4
value: gpt-4
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-32k
value: gpt-4-32k
show_on:
- variable: __model_type
value: llm
- label:
en_US: o1-mini
value: o1-mini
show_on:
- variable: __model_type
value: llm
- label:
en_US: o1-preview
value: o1-preview
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o-mini
value: gpt-4o-mini
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o-mini-2024-07-18
value: gpt-4o-mini-2024-07-18
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o
value: gpt-4o
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o-2024-05-13
value: gpt-4o-2024-05-13
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o-2024-08-06
value: gpt-4o-2024-08-06
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-turbo
value: gpt-4-turbo
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-turbo-2024-04-09
value: gpt-4-turbo-2024-04-09
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-0125-preview
value: gpt-4-0125-preview
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-1106-preview
value: gpt-4-1106-preview
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-vision-preview
value: gpt-4-vision-preview
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-35-turbo-instruct
value: gpt-35-turbo-instruct
show_on:
- variable: __model_type
value: llm
- label:
en_US: text-embedding-ada-002
value: text-embedding-ada-002
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: text-embedding-3-small
value: text-embedding-3-small
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: text-embedding-3-large
value: text-embedding-3-large
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: whisper-1
value: whisper-1
show_on:
- variable: __model_type
value: speech2text
- label:
en_US: tts-1
value: tts-1
show_on:
- variable: __model_type
value: tts
- label:
en_US: tts-1-hd
value: tts-1-hd
show_on:
- variable: __model_type
value: tts
placeholder:
zh_Hans: 在此输入您的模型版本
en_US: Enter your model version

View File

@ -1,764 +0,0 @@
import copy
import json
import logging
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast
import tiktoken
from openai import AzureOpenAI, Stream
from openai.types import Completion
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageFunction,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS
from core.model_runtime.utils import helper
logger = logging.getLogger(__name__)
class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
base_model_name = self._get_base_model_name(credentials)
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if ai_model_entity and ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
# chat model
return self._chat_generate(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
)
else:
# text completion model
return self._generate(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
stop=stop,
stream=stream,
user=user,
)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
base_model_name = self._get_base_model_name(credentials)
model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not model_entity:
raise ValueError(f"Base Model Name {base_model_name} is invalid")
model_mode = model_entity.entity.model_properties.get(ModelPropertyKey.MODE)
if model_mode == LLMMode.CHAT.value:
# chat model
return self._num_tokens_from_messages(credentials, prompt_messages, tools)
else:
# text completion model, do not support tool calling
content = prompt_messages[0].content
assert isinstance(content, str)
return self._num_tokens_from_string(credentials, content)
def validate_credentials(self, model: str, credentials: dict) -> None:
if "openai_api_base" not in credentials:
raise CredentialsValidateFailedError("Azure OpenAI API Base Endpoint is required")
if "openai_api_key" not in credentials:
raise CredentialsValidateFailedError("Azure OpenAI API key is required")
if "base_model_name" not in credentials:
raise CredentialsValidateFailedError("Base Model Name is required")
base_model_name = self._get_base_model_name(credentials)
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not ai_model_entity:
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid')
try:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
if model.startswith("o1"):
client.chat.completions.create(
messages=[{"role": "user", "content": "ping"}],
model=model,
temperature=1,
max_completion_tokens=20,
stream=False,
)
elif ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value:
# chat model
client.chat.completions.create(
messages=[{"role": "user", "content": "ping"}],
model=model,
temperature=0,
max_tokens=20,
stream=False,
)
else:
# text completion model
client.completions.create(
prompt="ping",
model=model,
temperature=0,
max_tokens=20,
stream=False,
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
base_model_name = self._get_base_model_name(credentials)
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
return ai_model_entity.entity if ai_model_entity else None
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
extra_model_kwargs = {}
if stop:
extra_model_kwargs["stop"] = stop
if user:
extra_model_kwargs["user"] = user
# text completion model
response = client.completions.create(
prompt=prompt_messages[0].content, model=model, stream=stream, **model_parameters, **extra_model_kwargs
)
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(
self, model: str, credentials: dict, response: Completion, prompt_messages: list[PromptMessage]
):
assistant_text = response.choices[0].text
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
# calculate num tokens
if response.usage:
# transform usage
prompt_tokens = response.usage.prompt_tokens
completion_tokens = response.usage.completion_tokens
else:
# calculate num tokens
content = prompt_messages[0].content
assert isinstance(content, str)
prompt_tokens = self._num_tokens_from_string(credentials, content)
completion_tokens = self._num_tokens_from_string(credentials, assistant_text)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
model=response.model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
system_fingerprint=response.system_fingerprint,
)
return result
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: Stream[Completion], prompt_messages: list[PromptMessage]
) -> Generator:
full_text = ""
for chunk in response:
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0]
if delta.finish_reason is None and (delta.text is None or delta.text == ""):
continue
# transform assistant message to prompt message
text = delta.text or ""
assistant_prompt_message = AssistantPromptMessage(content=text)
full_text += text
if delta.finish_reason is not None:
# calculate num tokens
if chunk.usage:
# transform usage
prompt_tokens = chunk.usage.prompt_tokens
completion_tokens = chunk.usage.completion_tokens
else:
# calculate num tokens
content = prompt_messages[0].content
assert isinstance(content, str)
prompt_tokens = self._num_tokens_from_string(credentials, content)
completion_tokens = self._num_tokens_from_string(credentials, full_text)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=chunk.model,
prompt_messages=prompt_messages,
system_fingerprint=chunk.system_fingerprint,
delta=LLMResultChunkDelta(
index=delta.index,
message=assistant_prompt_message,
finish_reason=delta.finish_reason,
usage=usage,
),
)
else:
yield LLMResultChunk(
model=chunk.model,
prompt_messages=prompt_messages,
system_fingerprint=chunk.system_fingerprint,
delta=LLMResultChunkDelta(
index=delta.index,
message=assistant_prompt_message,
),
)
def _chat_generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
client = AzureOpenAI(**self._to_credential_kwargs(credentials))
response_format = model_parameters.get("response_format")
if response_format:
if response_format == "json_schema":
json_schema = model_parameters.get("json_schema")
if not json_schema:
raise ValueError("Must define JSON Schema when the response format is json_schema")
try:
schema = json.loads(json_schema)
except:
raise ValueError(f"not correct json_schema format: {json_schema}")
model_parameters.pop("json_schema")
model_parameters["response_format"] = {"type": "json_schema", "json_schema": schema}
else:
model_parameters["response_format"] = {"type": response_format}
extra_model_kwargs = {}
if tools:
extra_model_kwargs["tools"] = [helper.dump_model(PromptMessageFunction(function=tool)) for tool in tools]
if stop:
extra_model_kwargs["stop"] = stop
if user:
extra_model_kwargs["user"] = user
# clear illegal prompt messages
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
block_as_stream = False
if model.startswith("o1"):
if stream:
block_as_stream = True
stream = False
if "stream_options" in extra_model_kwargs:
del extra_model_kwargs["stream_options"]
if "stop" in extra_model_kwargs:
del extra_model_kwargs["stop"]
# chat model
response = client.chat.completions.create(
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
model=model,
stream=stream,
**model_parameters,
**extra_model_kwargs,
)
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
if block_as_stream:
return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
return block_result
def _handle_chat_block_as_stream_response(
self,
block_result: LLMResult,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
) -> Generator[LLMResultChunk, None, None]:
"""
Handle llm chat response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:param stop: stop words
:return: llm response chunk generator
"""
text = block_result.message.content
text = cast(str, text)
if stop:
text = self.enforce_stop_tokens(text, stop)
yield LLMResultChunk(
model=block_result.model,
prompt_messages=prompt_messages,
system_fingerprint=block_result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=text),
finish_reason="stop",
usage=block_result.usage,
),
)
def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Clear illegal prompt messages for OpenAI API
:param model: model name
:param prompt_messages: prompt messages
:return: cleaned prompt messages
"""
checklist = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"]
if model in checklist:
# count how many user messages are there
user_message_count = len([m for m in prompt_messages if isinstance(m, UserPromptMessage)])
if user_message_count > 1:
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = "\n".join(
[
item.data
if item.type == PromptMessageContentType.TEXT
else "[IMAGE]"
if item.type == PromptMessageContentType.IMAGE
else ""
for item in prompt_message.content
]
)
if model.startswith("o1"):
system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
if system_message_count > 0:
new_prompt_messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message, SystemPromptMessage):
prompt_message = UserPromptMessage(
content=prompt_message.content,
name=prompt_message.name,
)
new_prompt_messages.append(prompt_message)
prompt_messages = new_prompt_messages
return prompt_messages
def _handle_chat_generate_response(
self,
model: str,
credentials: dict,
response: ChatCompletion,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
):
assistant_message = response.choices[0].message
assistant_message_tool_calls = assistant_message.tool_calls
# extract tool calls from response
tool_calls = []
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=assistant_message_tool_calls)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=assistant_message.content, tool_calls=tool_calls)
# calculate num tokens
if response.usage:
# transform usage
prompt_tokens = response.usage.prompt_tokens
completion_tokens = response.usage.completion_tokens
else:
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
completion_tokens = self._num_tokens_from_messages(credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
model=response.model or model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
system_fingerprint=response.system_fingerprint,
)
return result
def _handle_chat_generate_stream_response(
self,
model: str,
credentials: dict,
response: Stream[ChatCompletionChunk],
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
):
index = 0
full_assistant_content = ""
real_model = model
system_fingerprint = None
completion = ""
tool_calls = []
for chunk in response:
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0]
# NOTE: For fix https://github.com/langgenius/dify/issues/5790
if delta.delta is None:
continue
# extract tool calls from response
self._update_tool_calls(tool_calls=tool_calls, tool_calls_response=delta.delta.tool_calls)
# Handling exceptions when content filters' streaming mode is set to asynchronous modified filter
if delta.finish_reason is None and not delta.delta.content:
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=delta.delta.content or "", tool_calls=tool_calls)
full_assistant_content += delta.delta.content or ""
real_model = chunk.model
system_fingerprint = chunk.system_fingerprint
completion += delta.delta.content or ""
yield LLMResultChunk(
model=real_model,
prompt_messages=prompt_messages,
system_fingerprint=system_fingerprint,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
),
)
index += 1
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(credentials, prompt_messages, tools)
full_assistant_prompt_message = AssistantPromptMessage(content=completion)
completion_tokens = self._num_tokens_from_messages(credentials, [full_assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=real_model,
prompt_messages=prompt_messages,
system_fingerprint=system_fingerprint,
delta=LLMResultChunkDelta(
index=index, message=AssistantPromptMessage(content=""), finish_reason="stop", usage=usage
),
)
@staticmethod
def _update_tool_calls(
tool_calls: list[AssistantPromptMessage.ToolCall],
tool_calls_response: Optional[Sequence[ChatCompletionMessageToolCall | ChoiceDeltaToolCall]],
) -> None:
if tool_calls_response:
for response_tool_call in tool_calls_response:
if isinstance(response_tool_call, ChatCompletionMessageToolCall):
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.id, type=response_tool_call.type, function=function
)
tool_calls.append(tool_call)
elif isinstance(response_tool_call, ChoiceDeltaToolCall):
index = response_tool_call.index
if index < len(tool_calls):
tool_calls[index].id = response_tool_call.id or tool_calls[index].id
tool_calls[index].type = response_tool_call.type or tool_calls[index].type
if response_tool_call.function:
tool_calls[index].function.name = (
response_tool_call.function.name or tool_calls[index].function.name
)
tool_calls[index].function.arguments += response_tool_call.function.arguments or ""
else:
assert response_tool_call.id is not None
assert response_tool_call.type is not None
assert response_tool_call.function is not None
assert response_tool_call.function.name is not None
assert response_tool_call.function.arguments is not None
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.function.name, arguments=response_tool_call.function.arguments
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.id, type=response_tool_call.type, function=function
)
tool_calls.append(tool_call)
@staticmethod
def _convert_prompt_message_to_dict(message: PromptMessage):
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
assert message.content is not None
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(TextPromptMessageContent, message_content)
sub_message_dict = {"type": "text", "text": message_content.data}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
# message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls]
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {
"role": "tool",
"name": message.name,
"content": message.content,
"tool_call_id": message.tool_call_id,
}
else:
raise ValueError(f"Got unknown type {message}")
if message.name:
message_dict["name"] = message.name
return message_dict
def _num_tokens_from_string(
self, credentials: dict, text: str, tools: Optional[list[PromptMessageTool]] = None
) -> int:
try:
encoding = tiktoken.encoding_for_model(credentials["base_model_name"])
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = len(encoding.encode(text))
if tools:
num_tokens += self._num_tokens_for_tools(encoding, tools)
return num_tokens
def _num_tokens_from_messages(
self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
model = credentials["base_model_name"]
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
logger.warning("Warning: model not found. Using cl100k_base encoding.")
model = "cl100k_base"
encoding = tiktoken.get_encoding(model)
if model.startswith("gpt-35-turbo-0301"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4") or model.startswith("o1"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
# Cast str(value) in case the message value is not a string
# This occurs with function messages
# TODO: The current token calculation method for the image type is not implemented,
# which need to download the image and then get the resolution for calculation,
# and will increase the request delay
if isinstance(value, list):
text = ""
for item in value:
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
if key == "tool_calls":
for tool_call in value:
assert isinstance(tool_call, dict)
for t_key, t_value in tool_call.items():
num_tokens += len(encoding.encode(t_key))
if t_key == "function":
for f_key, f_value in t_value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(t_key))
num_tokens += len(encoding.encode(t_value))
else:
num_tokens += len(encoding.encode(str(value)))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if tools:
num_tokens += self._num_tokens_for_tools(encoding, tools)
return num_tokens
@staticmethod
def _num_tokens_for_tools(encoding: tiktoken.Encoding, tools: list[PromptMessageTool]) -> int:
num_tokens = 0
for tool in tools:
num_tokens += len(encoding.encode("type"))
num_tokens += len(encoding.encode("function"))
# calculate num tokens for function object
num_tokens += len(encoding.encode("name"))
num_tokens += len(encoding.encode(tool.name))
num_tokens += len(encoding.encode("description"))
num_tokens += len(encoding.encode(tool.description))
parameters = tool.parameters
num_tokens += len(encoding.encode("parameters"))
if "title" in parameters:
num_tokens += len(encoding.encode("title"))
num_tokens += len(encoding.encode(parameters["title"]))
num_tokens += len(encoding.encode("type"))
num_tokens += len(encoding.encode(parameters["type"]))
if "properties" in parameters:
num_tokens += len(encoding.encode("properties"))
for key, value in parameters["properties"].items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == "enum":
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if "required" in parameters:
num_tokens += len(encoding.encode("required"))
for required_field in parameters["required"]:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens
@staticmethod
def _get_ai_model_entity(base_model_name: str, model: str):
for ai_model_entity in LLM_BASE_MODELS:
if ai_model_entity.base_model_name == base_model_name:
ai_model_entity_copy = copy.deepcopy(ai_model_entity)
ai_model_entity_copy.entity.model = model
ai_model_entity_copy.entity.label.en_US = model
ai_model_entity_copy.entity.label.zh_Hans = model
return ai_model_entity_copy
def _get_base_model_name(self, credentials: dict) -> str:
base_model_name = credentials.get("base_model_name")
if not base_model_name:
raise ValueError("Base Model Name is required")
return base_model_name

View File

@ -1,450 +0,0 @@
import base64
import io
import json
import logging
from collections.abc import Generator
from typing import Optional, Union, cast
import google.ai.generativelanguage as glm
import google.generativeai as genai
import requests
from google.api_core import exceptions
from google.generativeai.client import _ClientManager
from google.generativeai.types import ContentType, GenerateContentResponse
from google.generativeai.types.content_types import to_part
from PIL import Image
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
logger = logging.getLogger(__name__)
GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.
<instructions>
{{instructions}}
</instructions>
""" # noqa: E501
class GoogleLargeLanguageModel(LargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:md = genai.GenerativeModel(model)
"""
prompt = self._convert_messages_to_prompt(prompt_messages)
return self._get_num_tokens_by_gpt2(prompt)
def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
"""
Format a list of messages into a full prompt for the Google model
:param messages: List of PromptMessage to combine.
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
messages = messages.copy() # don't mutate the original list
text = "".join(self._convert_one_message_to_text(message) for message in messages)
return text.rstrip()
def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
"""
Convert tool messages to glm tools
:param tools: tool messages
:return: glm tools
"""
function_declarations = []
for tool in tools:
properties = {}
for key, value in tool.parameters.get("properties", {}).items():
properties[key] = {
"type_": glm.Type.STRING,
"description": value.get("description", ""),
"enum": value.get("enum", []),
}
if properties:
parameters = glm.Schema(
type=glm.Type.OBJECT,
properties=properties,
required=tool.parameters.get("required", []),
)
else:
parameters = None
function_declaration = glm.FunctionDeclaration(
name=tool.name,
parameters=parameters,
description=tool.description,
)
function_declarations.append(function_declaration)
return glm.Tool(function_declarations=function_declarations)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
ping_message = SystemPromptMessage(content="ping")
self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: credentials kwargs
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
config_kwargs = model_parameters.copy()
config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
if stop:
config_kwargs["stop_sequences"] = stop
google_model = genai.GenerativeModel(model_name=model)
history = []
# hack for gemini-pro-vision, which currently does not support multi-turn chat
if model == "gemini-pro-vision":
last_msg = prompt_messages[-1]
content = self._format_message_to_glm_content(last_msg)
history.append(content)
else:
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
else:
history.append(content)
# Create a new ClientManager with tenant's API key
new_client_manager = _ClientManager()
new_client_manager.configure(api_key=credentials["google_api_key"])
new_custom_client = new_client_manager.make_client("generative")
google_model._client = new_custom_client
response = google_model.generate_content(
contents=history,
generation_config=genai.types.GenerationConfig(**config_kwargs),
stream=stream,
tools=self._convert_tools_to_glm_tool(tools) if tools else None,
request_options={"timeout": 600},
)
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(
self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
Handle llm response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: llm response
"""
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=response.text)
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
)
return result
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: GenerateContentResponse, prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm stream response
:param model: model name
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:return: llm response chunk generator result
"""
index = -1
for chunk in response:
for part in chunk.parts:
assistant_prompt_message = AssistantPromptMessage(content="")
if part.text:
assistant_prompt_message.content += part.text
if part.function_call:
assistant_prompt_message.tool_calls = [
AssistantPromptMessage.ToolCall(
id=part.function_call.name,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=part.function_call.name,
arguments=json.dumps(dict(part.function_call.args.items())),
),
)
]
index += 1
if not response._done:
# transform assistant message to prompt message
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(index=index, message=assistant_prompt_message),
)
else:
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
finish_reason=str(chunk.candidates[0].finish_reason),
usage=usage,
),
)
def _convert_one_message_to_text(self, message: PromptMessage) -> str:
"""
Convert a single message to a string.
:param message: PromptMessage to convert.
:return: String representation of the message.
"""
human_prompt = "\n\nuser:"
ai_prompt = "\n\nmodel:"
content = message.content
if isinstance(content, list):
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemPromptMessage | ToolPromptMessage):
message_text = f"{human_prompt} {content}"
else:
raise ValueError(f"Got unknown type {message}")
return message_text
def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
"""
Format a single message into glm.Content for Google API
:param message: one PromptMessage
:return: glm Content representation of message
"""
if isinstance(message, UserPromptMessage):
glm_content = {"role": "user", "parts": []}
if isinstance(message.content, str):
glm_content["parts"].append(to_part(message.content))
else:
for c in message.content:
if c.type == PromptMessageContentType.TEXT:
glm_content["parts"].append(to_part(c.data))
elif c.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, c)
if message_content.data.startswith("data:"):
metadata, base64_data = c.data.split(",", 1)
mime_type = metadata.split(";", 1)[0].split(":")[1]
else:
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
glm_content["parts"].append(blob)
return glm_content
elif isinstance(message, AssistantPromptMessage):
glm_content = {"role": "model", "parts": []}
if message.content:
glm_content["parts"].append(to_part(message.content))
if message.tool_calls:
glm_content["parts"].append(
to_part(
glm.FunctionCall(
name=message.tool_calls[0].function.name,
args=json.loads(message.tool_calls[0].function.arguments),
)
)
)
return glm_content
elif isinstance(message, SystemPromptMessage):
return {"role": "user", "parts": [to_part(message.content)]}
elif isinstance(message, ToolPromptMessage):
return {
"role": "function",
"parts": [
glm.Part(
function_response=glm.FunctionResponse(
name=message.name, response={"response": message.content}
)
)
],
}
else:
raise ValueError(f"Got unknown type {message}")
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the ermd = genai.GenerativeModel(model) error type thrown to the caller
The value is the md = genai.GenerativeModel(model) error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke emd = genai.GenerativeModel(model) error mapping
"""
return {
InvokeConnectionError: [exceptions.RetryError],
InvokeServerUnavailableError: [
exceptions.ServiceUnavailable,
exceptions.InternalServerError,
exceptions.BadGateway,
exceptions.GatewayTimeout,
exceptions.DeadlineExceeded,
],
InvokeRateLimitError: [exceptions.ResourceExhausted, exceptions.TooManyRequests],
InvokeAuthorizationError: [
exceptions.Unauthenticated,
exceptions.PermissionDenied,
exceptions.Unauthenticated,
exceptions.Forbidden,
],
InvokeBadRequestError: [
exceptions.BadRequest,
exceptions.InvalidArgument,
exceptions.FailedPrecondition,
exceptions.OutOfRange,
exceptions.NotFound,
exceptions.MethodNotAllowed,
exceptions.Conflict,
exceptions.AlreadyExists,
exceptions.Aborted,
exceptions.LengthRequired,
exceptions.PreconditionFailed,
exceptions.RequestRangeNotSatisfiable,
exceptions.Cancelled,
],
}

View File

@ -1,330 +0,0 @@
import json
from collections.abc import Generator
from typing import Optional, Union, cast
import requests
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
self._add_function_call(model, credentials)
user = user[:32] if user else None
# {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
if "response_format" in model_parameters:
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials)
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
return AIModelEntity(
model=model,
label=I18nObject(en_US=model, zh_Hans=model),
model_type=ModelType.LLM,
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
if credentials.get("function_calling_type") == "tool_call"
else [],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)),
ModelPropertyKey.MODE: LLMMode.CHAT.value,
},
parameter_rules=[
ParameterRule(
name="temperature",
use_template="temperature",
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
type=ParameterType.FLOAT,
),
ParameterRule(
name="max_tokens",
use_template="max_tokens",
default=512,
min=1,
max=int(credentials.get("max_tokens", 4096)),
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
type=ParameterType.INT,
),
ParameterRule(
name="top_p",
use_template="top_p",
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
type=ParameterType.FLOAT,
),
],
)
def _add_custom_parameters(self, credentials: dict) -> None:
credentials["mode"] = "chat"
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "":
credentials["endpoint_url"] = "https://api.moonshot.cn/v1"
def _add_function_call(self, model: str, credentials: dict) -> None:
model_schema = self.get_model_schema(model, credentials)
if model_schema and {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}.intersection(
model_schema.features or []
):
credentials["function_calling_type"] = "tool_call"
def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict:
"""
Convert PromptMessage to dict for OpenAI API format
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(PromptMessageContent, message_content)
sub_message_dict = {"type": "text", "text": message_content.data}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = []
for function_call in message.tool_calls:
message_dict["tool_calls"].append(
{
"id": function_call.id,
"type": function_call.type,
"function": {
"name": function_call.function.name,
"arguments": function_call.function.arguments,
},
}
)
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
if message.name:
message_dict["name"] = message.name
return message_dict
def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
:param response_tool_calls: response tool calls
:return: list of tool calls
"""
tool_calls = []
if response_tool_calls:
for response_tool_call in response_tool_calls:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call["function"]["name"]
if response_tool_call.get("function", {}).get("name")
else "",
arguments=response_tool_call["function"]["arguments"]
if response_tool_call.get("function", {}).get("arguments")
else "",
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call["id"] if response_tool_call.get("id") else "",
type=response_tool_call["type"] if response_tool_call.get("type") else "",
function=function,
)
tool_calls.append(tool_call)
return tool_calls
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm stream response
:param model: model name
:param credentials: model credentials
:param response: streamed response
:param prompt_messages: prompt messages
:return: llm response chunk generator
"""
full_assistant_content = ""
chunk_index = 0
def create_final_llm_result_chunk(
index: int, message: AssistantPromptMessage, finish_reason: str
) -> LLMResultChunk:
# calculate num tokens
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
return LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
)
tools_calls: list[AssistantPromptMessage.ToolCall] = []
finish_reason = "Unknown"
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
def get_tool_call(tool_name: str):
if not tool_name:
return tools_calls[-1]
tool_call = next((tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id="",
type="",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
)
tools_calls.append(tool_call)
return tool_call
for new_tool_call in new_tool_calls:
# get tool call
tool_call = get_tool_call(new_tool_call.function.name)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
for chunk in response.iter_lines(decode_unicode=True, delimiter="\n\n"):
if chunk:
# ignore sse comments
if chunk.startswith(":"):
continue
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
chunk_json = None
try:
chunk_json = json.loads(decoded_chunk)
# stream ended
except json.JSONDecodeError as e:
yield create_final_llm_result_chunk(
index=chunk_index + 1,
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered.",
)
break
if not chunk_json or len(chunk_json["choices"]) == 0:
continue
choice = chunk_json["choices"][0]
finish_reason = chunk_json["choices"][0].get("finish_reason")
chunk_index += 1
if "delta" in choice:
delta = choice["delta"]
delta_content = delta.get("content")
assistant_message_tool_calls = delta.get("tool_calls", None)
# assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
increase_tool_call(tool_calls)
if delta_content is None or delta_content == "":
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta_content, tool_calls=tool_calls if assistant_message_tool_calls else []
)
full_assistant_content += delta_content
elif "text" in choice:
choice_text = choice.get("text", "")
if choice_text == "":
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
full_assistant_content += choice_text
else:
continue
# check payload indicator for completion
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=assistant_prompt_message,
),
)
chunk_index += 1
if tools_calls:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=AssistantPromptMessage(tool_calls=tools_calls, content=""),
),
)
yield create_final_llm_result_chunk(
index=chunk_index, message=AssistantPromptMessage(content=""), finish_reason=finish_reason
)

View File

@ -1,847 +0,0 @@
import json
import logging
from collections.abc import Generator
from decimal import Decimal
from typing import Optional, Union, cast
from urllib.parse import urljoin
import requests
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageFunction,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
AIModelEntity,
DefaultParameterName,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
PriceConfig,
)
from core.model_runtime.errors.invoke import InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
from core.model_runtime.utils import helper
logger = logging.getLogger(__name__)
class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
"""
Model class for OpenAI large language model.
"""
def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
# text completion model
return self._generate(
model=model,
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user,
)
def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
:param model:
:param credentials:
:param prompt_messages:
:param tools: tools for tool calling
:return:
"""
return self._num_tokens_from_messages(model, prompt_messages, tools, credentials)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials using requests to ensure compatibility with all providers following
OpenAI's API standard.
:param model: model name
:param credentials: model credentials
:return:
"""
try:
headers = {"Content-Type": "application/json"}
api_key = credentials.get("api_key")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials["endpoint_url"]
if not endpoint_url.endswith("/"):
endpoint_url += "/"
# prepare the payload for a simple ping to the model
data = {"model": model, "max_tokens": 5}
completion_type = LLMMode.value_of(credentials["mode"])
if completion_type is LLMMode.CHAT:
data["messages"] = [
{"role": "user", "content": "ping"},
]
endpoint_url = urljoin(endpoint_url, "chat/completions")
elif completion_type is LLMMode.COMPLETION:
data["prompt"] = "ping"
endpoint_url = urljoin(endpoint_url, "completions")
else:
raise ValueError("Unsupported completion type for model configuration.")
# send a post request to validate the credentials
response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300))
if response.status_code != 200:
raise CredentialsValidateFailedError(
f"Credentials validation failed with status code {response.status_code}"
)
try:
json_result = response.json()
except json.JSONDecodeError as e:
raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error")
if completion_type is LLMMode.CHAT and json_result.get("object", "") == "":
json_result["object"] = "chat.completion"
elif completion_type is LLMMode.COMPLETION and json_result.get("object", "") == "":
json_result["object"] = "text_completion"
if completion_type is LLMMode.CHAT and (
"object" not in json_result or json_result["object"] != "chat.completion"
):
raise CredentialsValidateFailedError(
"Credentials validation failed: invalid response object, must be 'chat.completion'"
)
elif completion_type is LLMMode.COMPLETION and (
"object" not in json_result or json_result["object"] != "text_completion"
):
raise CredentialsValidateFailedError(
"Credentials validation failed: invalid response object, must be 'text_completion'"
)
except CredentialsValidateFailedError:
raise
except Exception as ex:
raise CredentialsValidateFailedError(f"An error occurred during credentials validation: {str(ex)}")
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
"""
generate custom model entities from credentials
"""
features = []
function_calling_type = credentials.get("function_calling_type", "no_call")
if function_calling_type == "function_call":
features.append(ModelFeature.TOOL_CALL)
elif function_calling_type == "tool_call":
features.append(ModelFeature.MULTI_TOOL_CALL)
stream_function_calling = credentials.get("stream_function_calling", "supported")
if stream_function_calling == "supported":
features.append(ModelFeature.STREAM_TOOL_CALL)
vision_support = credentials.get("vision_support", "not_support")
if vision_support == "support":
features.append(ModelFeature.VISION)
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
features=features,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "4096")),
ModelPropertyKey.MODE: credentials.get("mode"),
},
parameter_rules=[
ParameterRule(
name=DefaultParameterName.TEMPERATURE.value,
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
help=I18nObject(
en_US="Kernel sampling threshold. Used to determine the randomness of the results."
"The higher the value, the stronger the randomness."
"The higher the possibility of getting different answers to the same question.",
zh_Hans="核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。",
),
type=ParameterType.FLOAT,
default=float(credentials.get("temperature", 0.7)),
min=0,
max=2,
precision=2,
),
ParameterRule(
name=DefaultParameterName.TOP_P.value,
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
help=I18nObject(
en_US="The probability threshold of the nucleus sampling method during the generation process."
"The larger the value is, the higher the randomness of generation will be."
"The smaller the value is, the higher the certainty of generation will be.",
zh_Hans="生成过程中核采样方法概率阈值。取值越大,生成的随机性越高;取值越小,生成的确定性越高。",
),
type=ParameterType.FLOAT,
default=float(credentials.get("top_p", 1)),
min=0,
max=1,
precision=2,
),
ParameterRule(
name=DefaultParameterName.FREQUENCY_PENALTY.value,
label=I18nObject(en_US="Frequency Penalty", zh_Hans="频率惩罚"),
help=I18nObject(
en_US="For controlling the repetition rate of words used by the model."
"Increasing this can reduce the repetition of the same words in the model's output.",
zh_Hans="用于控制模型已使用字词的重复率。 提高此项可以降低模型在输出中重复相同字词的重复度。",
),
type=ParameterType.FLOAT,
default=float(credentials.get("frequency_penalty", 0)),
min=-2,
max=2,
),
ParameterRule(
name=DefaultParameterName.PRESENCE_PENALTY.value,
label=I18nObject(en_US="Presence Penalty", zh_Hans="存在惩罚"),
help=I18nObject(
en_US="Used to control the repetition rate when generating models."
"Increasing this can reduce the repetition rate of model generation.",
zh_Hans="用于控制模型生成时的重复度。提高此项可以降低模型生成的重复度。",
),
type=ParameterType.FLOAT,
default=float(credentials.get("presence_penalty", 0)),
min=-2,
max=2,
),
ParameterRule(
name=DefaultParameterName.MAX_TOKENS.value,
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
help=I18nObject(
en_US="Maximum length of tokens for the model response.", zh_Hans="模型回答的tokens的最大长度。"
),
type=ParameterType.INT,
default=512,
min=1,
max=int(credentials.get("max_tokens_to_sample", 4096)),
),
],
pricing=PriceConfig(
input=Decimal(credentials.get("input_price", 0)),
output=Decimal(credentials.get("output_price", 0)),
unit=Decimal(credentials.get("unit", 0)),
currency=credentials.get("currency", "USD"),
),
)
if credentials["mode"] == "chat":
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
elif credentials["mode"] == "completion":
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
else:
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
return entity
# validate_credentials method has been rewritten to use the requests library for compatibility with all providers
# following OpenAI's API standard.
def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke llm completion model
:param model: model name
:param credentials: credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
headers = {
"Content-Type": "application/json",
"Accept-Charset": "utf-8",
}
extra_headers = credentials.get("extra_headers")
if extra_headers is not None:
headers = {
**headers,
**extra_headers,
}
api_key = credentials.get("api_key")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials["endpoint_url"]
if not endpoint_url.endswith("/"):
endpoint_url += "/"
data = {"model": model, "stream": stream, **model_parameters}
completion_type = LLMMode.value_of(credentials["mode"])
if completion_type is LLMMode.CHAT:
endpoint_url = urljoin(endpoint_url, "chat/completions")
data["messages"] = [self._convert_prompt_message_to_dict(m, credentials) for m in prompt_messages]
elif completion_type is LLMMode.COMPLETION:
endpoint_url = urljoin(endpoint_url, "completions")
data["prompt"] = prompt_messages[0].content
else:
raise ValueError("Unsupported completion type for model configuration.")
# annotate tools with names, descriptions, etc.
function_calling_type = credentials.get("function_calling_type", "no_call")
formatted_tools = []
if tools:
if function_calling_type == "function_call":
data["functions"] = [
{"name": tool.name, "description": tool.description, "parameters": tool.parameters}
for tool in tools
]
elif function_calling_type == "tool_call":
data["tool_choice"] = "auto"
for tool in tools:
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
data["tools"] = formatted_tools
if stop:
data["stop"] = stop
if user:
data["user"] = user
response = requests.post(endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream)
if response.encoding is None or response.encoding == "ISO-8859-1":
response.encoding = "utf-8"
if response.status_code != 200:
raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
if stream:
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm stream response
:param model: model name
:param credentials: model credentials
:param response: streamed response
:param prompt_messages: prompt messages
:return: llm response chunk generator
"""
full_assistant_content = ""
chunk_index = 0
def create_final_llm_result_chunk(
id: Optional[str], index: int, message: AssistantPromptMessage, finish_reason: str, usage: dict
) -> LLMResultChunk:
# calculate num tokens
prompt_tokens = usage and usage.get("prompt_tokens")
if prompt_tokens is None:
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
completion_tokens = usage and usage.get("completion_tokens")
if completion_tokens is None:
completion_tokens = self._num_tokens_from_string(model, full_assistant_content)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
return LLMResultChunk(
id=id,
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(index=index, message=message, finish_reason=finish_reason, usage=usage),
)
# delimiter for stream response, need unicode_escape
import codecs
delimiter = credentials.get("stream_mode_delimiter", "\n\n")
delimiter = codecs.decode(delimiter, "unicode_escape")
tools_calls: list[AssistantPromptMessage.ToolCall] = []
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
def get_tool_call(tool_call_id: str):
if not tool_call_id:
return tools_calls[-1]
tool_call = next((tool_call for tool_call in tools_calls if tool_call.id == tool_call_id), None)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id=tool_call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
)
tools_calls.append(tool_call)
return tool_call
for new_tool_call in new_tool_calls:
# get tool call
tool_call = get_tool_call(new_tool_call.function.name)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
finish_reason = None # The default value of finish_reason is None
message_id, usage = None, None
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
chunk = chunk.strip()
if chunk:
# ignore sse comments
if chunk.startswith(":"):
continue
decoded_chunk = chunk.strip().lstrip("data: ").lstrip()
if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]"
continue
try:
chunk_json: dict = json.loads(decoded_chunk)
# stream ended
except json.JSONDecodeError as e:
yield create_final_llm_result_chunk(
id=message_id,
index=chunk_index + 1,
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered.",
usage=usage,
)
break
if chunk_json:
if u := chunk_json.get("usage"):
usage = u
if not chunk_json or len(chunk_json["choices"]) == 0:
continue
choice = chunk_json["choices"][0]
finish_reason = chunk_json["choices"][0].get("finish_reason")
message_id = chunk_json.get("id")
chunk_index += 1
if "delta" in choice:
delta = choice["delta"]
delta_content = delta.get("content")
assistant_message_tool_calls = None
if "tool_calls" in delta and credentials.get("function_calling_type", "no_call") == "tool_call":
assistant_message_tool_calls = delta.get("tool_calls", None)
elif (
"function_call" in delta
and credentials.get("function_calling_type", "no_call") == "function_call"
):
assistant_message_tool_calls = [
{"id": "tool_call_id", "type": "function", "function": delta.get("function_call", {})}
]
# assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
increase_tool_call(tool_calls)
if delta_content is None or delta_content == "":
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta_content,
)
# reset tool calls
tool_calls = []
full_assistant_content += delta_content
elif "text" in choice:
choice_text = choice.get("text", "")
if choice_text == "":
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(content=choice_text)
full_assistant_content += choice_text
else:
continue
yield LLMResultChunk(
id=message_id,
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=assistant_prompt_message,
),
)
chunk_index += 1
if tools_calls:
yield LLMResultChunk(
id=message_id,
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk_index,
message=AssistantPromptMessage(tool_calls=tools_calls, content=""),
),
)
yield create_final_llm_result_chunk(
id=message_id,
index=chunk_index,
message=AssistantPromptMessage(content=""),
finish_reason=finish_reason,
usage=usage,
)
def _handle_generate_response(
self, model: str, credentials: dict, response: requests.Response, prompt_messages: list[PromptMessage]
) -> LLMResult:
response_json: dict = response.json()
completion_type = LLMMode.value_of(credentials["mode"])
output = response_json["choices"][0]
message_id = response_json.get("id")
response_content = ""
tool_calls = None
function_calling_type = credentials.get("function_calling_type", "no_call")
if completion_type is LLMMode.CHAT:
response_content = output.get("message", {})["content"]
if function_calling_type == "tool_call":
tool_calls = output.get("message", {}).get("tool_calls")
elif function_calling_type == "function_call":
tool_calls = output.get("message", {}).get("function_call")
elif completion_type is LLMMode.COMPLETION:
response_content = output["text"]
assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[])
if tool_calls:
if function_calling_type == "tool_call":
assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
elif function_calling_type == "function_call":
assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)]
usage = response_json.get("usage")
if usage:
# transform usage
prompt_tokens = usage["prompt_tokens"]
completion_tokens = usage["completion_tokens"]
else:
# calculate num tokens
prompt_tokens = self._num_tokens_from_string(model, prompt_messages[0].content)
completion_tokens = self._num_tokens_from_string(model, assistant_message.content)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
# transform response
result = LLMResult(
id=message_id,
model=response_json["model"],
prompt_messages=prompt_messages,
message=assistant_message,
usage=usage,
)
return result
def _convert_prompt_message_to_dict(self, message: PromptMessage, credentials: Optional[dict] = None) -> dict:
"""
Convert PromptMessage to dict for OpenAI API format
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
sub_messages = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(PromptMessageContent, message_content)
sub_message_dict = {"type": "text", "text": message_content.data}
sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
sub_message_dict = {
"type": "image_url",
"image_url": {"url": message_content.data, "detail": message_content.detail.value},
}
sub_messages.append(sub_message_dict)
message_dict = {"role": "user", "content": sub_messages}
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
function_calling_type = credentials.get("function_calling_type", "no_call")
if function_calling_type == "tool_call":
message_dict["tool_calls"] = [tool_call.dict() for tool_call in message.tool_calls]
elif function_calling_type == "function_call":
function_call = message.tool_calls[0]
message_dict["function_call"] = {
"name": function_call.function.name,
"arguments": function_call.function.arguments,
}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
function_calling_type = credentials.get("function_calling_type", "no_call")
if function_calling_type == "tool_call":
message_dict = {"role": "tool", "content": message.content, "tool_call_id": message.tool_call_id}
elif function_calling_type == "function_call":
message_dict = {"role": "function", "content": message.content, "name": message.tool_call_id}
else:
raise ValueError(f"Got unknown type {message}")
if message.name and message_dict.get("role", "") != "tool":
message_dict["name"] = message.name
return message_dict
def _num_tokens_from_string(
self, model: str, text: Union[str, list[PromptMessageContent]], tools: Optional[list[PromptMessageTool]] = None
) -> int:
"""
Approximate num tokens for model with gpt2 tokenizer.
:param model: model name
:param text: prompt text
:param tools: tools for tool calling
:return: number of tokens
"""
if isinstance(text, str):
full_text = text
else:
full_text = ""
for message_content in text:
if message_content.type == PromptMessageContentType.TEXT:
message_content = cast(PromptMessageContent, message_content)
full_text += message_content.data
num_tokens = self._get_num_tokens_by_gpt2(full_text)
if tools:
num_tokens += self._num_tokens_for_tools(tools)
return num_tokens
def _num_tokens_from_messages(
self,
model: str,
messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
credentials: Optional[dict] = None,
) -> int:
"""
Approximate num tokens with GPT2 tokenizer.
"""
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m, credentials) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
# Cast str(value) in case the message value is not a string
# This occurs with function messages
# TODO: The current token calculation method for the image type is not implemented,
# which need to download the image and then get the resolution for calculation,
# and will increase the request delay
if isinstance(value, list):
text = ""
for item in value:
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
if key == "tool_calls":
for tool_call in value:
for t_key, t_value in tool_call.items():
num_tokens += self._get_num_tokens_by_gpt2(t_key)
if t_key == "function":
for f_key, f_value in t_value.items():
num_tokens += self._get_num_tokens_by_gpt2(f_key)
num_tokens += self._get_num_tokens_by_gpt2(f_value)
else:
num_tokens += self._get_num_tokens_by_gpt2(t_key)
num_tokens += self._get_num_tokens_by_gpt2(t_value)
else:
num_tokens += self._get_num_tokens_by_gpt2(str(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if tools:
num_tokens += self._num_tokens_for_tools(tools)
return num_tokens
def _num_tokens_for_tools(self, tools: list[PromptMessageTool]) -> int:
"""
Calculate num tokens for tool calling with tiktoken package.
:param tools: tools for tool calling
:return: number of tokens
"""
num_tokens = 0
for tool in tools:
num_tokens += self._get_num_tokens_by_gpt2("type")
num_tokens += self._get_num_tokens_by_gpt2("function")
num_tokens += self._get_num_tokens_by_gpt2("function")
# calculate num tokens for function object
num_tokens += self._get_num_tokens_by_gpt2("name")
num_tokens += self._get_num_tokens_by_gpt2(tool.name)
num_tokens += self._get_num_tokens_by_gpt2("description")
num_tokens += self._get_num_tokens_by_gpt2(tool.description)
parameters = tool.parameters
num_tokens += self._get_num_tokens_by_gpt2("parameters")
if "title" in parameters:
num_tokens += self._get_num_tokens_by_gpt2("title")
num_tokens += self._get_num_tokens_by_gpt2(parameters.get("title"))
num_tokens += self._get_num_tokens_by_gpt2("type")
num_tokens += self._get_num_tokens_by_gpt2(parameters.get("type"))
if "properties" in parameters:
num_tokens += self._get_num_tokens_by_gpt2("properties")
for key, value in parameters.get("properties").items():
num_tokens += self._get_num_tokens_by_gpt2(key)
for field_key, field_value in value.items():
num_tokens += self._get_num_tokens_by_gpt2(field_key)
if field_key == "enum":
for enum_field in field_value:
num_tokens += 3
num_tokens += self._get_num_tokens_by_gpt2(enum_field)
else:
num_tokens += self._get_num_tokens_by_gpt2(field_key)
num_tokens += self._get_num_tokens_by_gpt2(str(field_value))
if "required" in parameters:
num_tokens += self._get_num_tokens_by_gpt2("required")
for required_field in parameters["required"]:
num_tokens += 3
num_tokens += self._get_num_tokens_by_gpt2(required_field)
return num_tokens
def _extract_response_tool_calls(self, response_tool_calls: list[dict]) -> list[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
:param response_tool_calls: response tool calls
:return: list of tool calls
"""
tool_calls = []
if response_tool_calls:
for response_tool_call in response_tool_calls:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.get("function", {}).get("name", ""),
arguments=response_tool_call.get("function", {}).get("arguments", ""),
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.get("id", ""), type=response_tool_call.get("type", ""), function=function
)
tool_calls.append(tool_call)
return tool_calls
def _extract_response_function_call(self, response_function_call) -> AssistantPromptMessage.ToolCall:
"""
Extract function call from response
:param response_function_call: response function call
:return: tool call
"""
tool_call = None
if response_function_call:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_function_call.get("name", ""), arguments=response_function_call.get("arguments", "")
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_function_call.get("id", ""), type="function", function=function
)
return tool_call

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

View File

@ -0,0 +1,3 @@
<svg width="1200" height="925" viewBox="0 0 1200 925" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M780.152 250.999L907.882 462.174C907.882 462.174 880.925 510.854 867.43 535.21C834.845 594.039 764.171 612.49 710.442 508.333L420.376 0H0L459.926 803.307C552.303 964.663 787.366 964.663 879.743 803.307C989.874 610.952 1089.87 441.97 1200 249.646L1052.28 0H639.519L780.152 250.999Z" fill="#3366FF"/>
</svg>

After

Width:  |  Height:  |  Size: 417 B

View File

@ -0,0 +1,83 @@
from decimal import Decimal
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import (
AIModelEntity,
DefaultParameterName,
FetchFrom,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
PriceConfig,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class VesslAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
features = []
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
features=features,
model_properties={
ModelPropertyKey.MODE: credentials.get("mode"),
},
parameter_rules=[
ParameterRule(
name=DefaultParameterName.TEMPERATURE.value,
label=I18nObject(en_US="Temperature"),
type=ParameterType.FLOAT,
default=float(credentials.get("temperature", 0.7)),
min=0,
max=2,
precision=2,
),
ParameterRule(
name=DefaultParameterName.TOP_P.value,
label=I18nObject(en_US="Top P"),
type=ParameterType.FLOAT,
default=float(credentials.get("top_p", 1)),
min=0,
max=1,
precision=2,
),
ParameterRule(
name=DefaultParameterName.TOP_K.value,
label=I18nObject(en_US="Top K"),
type=ParameterType.INT,
default=int(credentials.get("top_k", 50)),
min=-2147483647,
max=2147483647,
precision=0,
),
ParameterRule(
name=DefaultParameterName.MAX_TOKENS.value,
label=I18nObject(en_US="Max Tokens"),
type=ParameterType.INT,
default=512,
min=1,
max=int(credentials.get("max_tokens_to_sample", 4096)),
),
],
pricing=PriceConfig(
input=Decimal(credentials.get("input_price", 0)),
output=Decimal(credentials.get("output_price", 0)),
unit=Decimal(credentials.get("unit", 0)),
currency=credentials.get("currency", "USD"),
),
)
if credentials["mode"] == "chat":
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.CHAT.value
elif credentials["mode"] == "completion":
entity.model_properties[ModelPropertyKey.MODE] = LLMMode.COMPLETION.value
else:
raise ValueError(f"Unknown completion type {credentials['completion_type']}")
return entity

View File

@ -0,0 +1,10 @@
import logging
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class VesslAIProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
pass

View File

@ -0,0 +1,56 @@
provider: vessl_ai
label:
en_US: vessl_ai
icon_small:
en_US: icon_s_en.svg
icon_large:
en_US: icon_l_en.png
background: "#F1EFED"
help:
title:
en_US: How to deploy VESSL AI LLM Model Endpoint
url:
en_US: https://docs.vessl.ai/guides/get-started/llama3-deployment
supported_model_types:
- llm
configurate_methods:
- customizable-model
model_credential_schema:
model:
label:
en_US: Model Name
placeholder:
en_US: Enter your model name
credential_form_schemas:
- variable: endpoint_url
label:
en_US: endpoint url
type: text-input
required: true
placeholder:
en_US: Enter the url of your endpoint url
- variable: api_key
required: true
label:
en_US: API Key
type: secret-input
placeholder:
en_US: Enter your VESSL AI secret key
- variable: mode
show_on:
- variable: __model_type
value: llm
label:
en_US: Completion mode
type: select
required: false
default: chat
placeholder:
en_US: Select completion mode
options:
- value: completion
label:
en_US: Completion
- value: chat
label:
en_US: Chat

View File

@ -413,7 +413,7 @@ class PluginModelManager(BasePluginManager):
"""
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/model/voices",
path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
type=PluginVoicesResponse,
data=jsonable_encoder(
{
@ -434,8 +434,10 @@ class PluginModelManager(BasePluginManager):
)
for resp in response:
voices = []
for voice in resp.voices:
return [{"name": voice.name, "value": voice.value}]
voices.append({"name": voice.name, "value": voice.value})
return voices
return []

View File

@ -34,6 +34,8 @@ class RetrievalService:
reranking_mode: Optional[str] = "reranking_model",
weights: Optional[dict] = None,
):
if not query:
return []
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
return []

View File

@ -3,11 +3,13 @@ import time
import uuid
from typing import Any
import numpy as np
from pydantic import BaseModel, model_validator
from pymochow import MochowClient
from pymochow.auth.bce_credentials import BceCredentials
from pymochow.configuration import Configuration
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, TableState
from pymochow.exception import ServerError
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, ServerErrCode, TableState
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
@ -116,6 +118,7 @@ class BaiduVector(BaseVector):
self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
anns = AnnSearch(
vector_field=self.field_vector,
vector_floats=query_vector,
@ -149,7 +152,13 @@ class BaiduVector(BaseVector):
return docs
def delete(self) -> None:
self._db.drop_table(table_name=self._collection_name)
try:
self._db.drop_table(table_name=self._collection_name)
except ServerError as e:
if e.code == ServerErrCode.TABLE_NOT_EXIST:
pass
else:
raise
def _init_client(self, config) -> MochowClient:
config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint)
@ -166,7 +175,14 @@ class BaiduVector(BaseVector):
if exists:
return self._client.database(self._client_config.database)
else:
return self._client.create_database(database_name=self._client_config.database)
try:
self._client.create_database(database_name=self._client_config.database)
except ServerError as e:
if e.code == ServerErrCode.DB_ALREADY_EXIST:
pass
else:
raise
return
def _table_existed(self) -> bool:
tables = self._db.list_table()
@ -175,7 +191,7 @@ class BaiduVector(BaseVector):
def _create_table(self, dimension: int) -> None:
# Try to grab distributed lock and create table
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
with redis_client.lock(lock_name, timeout=60):
table_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(table_exist_cache_key):
return
@ -238,15 +254,14 @@ class BaiduVector(BaseVector):
description="Table for Dify",
)
# Wait for table created
while True:
time.sleep(1)
table = self._db.describe_table(self._collection_name)
if table.state == TableState.NORMAL:
break
redis_client.set(table_exist_cache_key, 1, ex=3600)
# Wait for table created
while True:
time.sleep(1)
table = self._db.describe_table(self._collection_name)
if table.state == TableState.NORMAL:
break
class BaiduVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector:

View File

@ -37,7 +37,7 @@ class TidbService:
}
spending_limit = {
"monthly": 100,
"monthly": dify_config.TIDB_SPEND_LIMIT,
}
password = str(uuid.uuid4()).replace("-", "")[:16]
display_name = str(uuid.uuid4()).replace("-", "")[:16]

View File

@ -27,18 +27,17 @@ class RerankModelRunner(BaseRerankRunner):
:return:
"""
docs = []
doc_id = []
doc_id = set()
unique_documents = []
dify_documents = [item for item in documents if item.provider == "dify"]
external_documents = [item for item in documents if item.provider == "external"]
for document in dify_documents:
if document.metadata["doc_id"] not in doc_id:
doc_id.append(document.metadata["doc_id"])
for document in documents:
if document.provider == "dify" and document.metadata["doc_id"] not in doc_id:
doc_id.add(document.metadata["doc_id"])
docs.append(document.page_content)
unique_documents.append(document)
for document in external_documents:
docs.append(document.page_content)
unique_documents.append(document)
elif document.provider == "external":
if document not in unique_documents:
docs.append(document.page_content)
unique_documents.append(document)
documents = unique_documents

View File

@ -116,10 +116,8 @@ class ToolInvokeMessage(BaseModel):
class VariableMessage(BaseModel):
variable_name: str = Field(..., description="The name of the variable")
variable_value: str = Field(...,
description="The value of the variable")
stream: bool = Field(
default=False, description="Whether the variable is streamed")
variable_value: str = Field(..., description="The value of the variable")
stream: bool = Field(default=False, description="Whether the variable is streamed")
@field_validator("variable_value", mode="before")
@classmethod
@ -133,8 +131,7 @@ class ToolInvokeMessage(BaseModel):
# if stream is true, the value must be a string
if values.get("stream"):
if not isinstance(value, str):
raise ValueError(
"When 'stream' is True, 'variable_value' must be a string.")
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
return value
@ -271,8 +268,7 @@ class ToolParameter(BaseModel):
return str(value)
except Exception:
raise ValueError(
f"The tool parameter value {value} is not in correct type.")
raise ValueError(f"The tool parameter value {value} is not in correct type of {self.as_normal_type()}.")
class ToolParameterForm(Enum):
SCHEMA = "schema" # should be set while adding tool
@ -280,17 +276,12 @@ class ToolParameter(BaseModel):
LLM = "llm" # will be set by LLM
name: str = Field(..., description="The name of the parameter")
label: I18nObject = Field(...,
description="The label presented to the user")
human_description: Optional[I18nObject] = Field(
default=None, description="The description presented to the user")
placeholder: Optional[I18nObject] = Field(
default=None, description="The placeholder presented to the user")
type: ToolParameterType = Field(...,
description="The type of the parameter")
label: I18nObject = Field(..., description="The label presented to the user")
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user")
type: ToolParameterType = Field(..., description="The type of the parameter")
scope: AppSelectorScope | ModelConfigScope | None = None
form: ToolParameterForm = Field(...,
description="The form of the parameter, schema/form/llm")
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
llm_description: Optional[str] = None
required: Optional[bool] = False
default: Optional[Union[float, int, str]] = None
@ -346,8 +337,7 @@ class ToolParameter(BaseModel):
class ToolProviderIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
description: I18nObject = Field(...,
description="The description of the tool")
description: I18nObject = Field(..., description="The description of the tool")
icon: str = Field(..., description="The icon of the tool")
label: I18nObject = Field(..., description="The label of the tool")
tags: Optional[list[ToolLabelEnum]] = Field(
@ -365,8 +355,7 @@ class ToolIdentity(BaseModel):
class ToolDescription(BaseModel):
human: I18nObject = Field(...,
description="The description presented to the user")
human: I18nObject = Field(..., description="The description presented to the user")
llm: str = Field(..., description="The description presented to the LLM")
@ -375,8 +364,7 @@ class ToolEntity(BaseModel):
parameters: list[ToolParameter] = Field(default_factory=list)
description: Optional[ToolDescription] = None
output_schema: Optional[dict] = None
has_runtime_parameters: bool = Field(
default=False, description="Whether the tool has runtime parameters")
has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@ -403,10 +391,8 @@ class WorkflowToolParameterConfiguration(BaseModel):
"""
name: str = Field(..., description="The name of the parameter")
description: str = Field(...,
description="The description of the parameter")
form: ToolParameter.ToolParameterForm = Field(
..., description="The form of the parameter")
description: str = Field(..., description="The description of the parameter")
form: ToolParameter.ToolParameterForm = Field(..., description="The form of the parameter")
class ToolInvokeMeta(BaseModel):
@ -414,8 +400,7 @@ class ToolInvokeMeta(BaseModel):
Tool invoke meta
"""
time_cost: float = Field(...,
description="The time cost of the tool invoke")
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: Optional[str] = None
tool_config: Optional[dict] = None
@ -474,5 +459,4 @@ class ToolProviderID:
if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value):
raise ValueError("Invalid plugin id")
self.organization, self.plugin_name, self.provider_name = value.split(
"/")
self.organization, self.plugin_name, self.provider_name = value.split("/")

View File

@ -1,42 +0,0 @@
from typing import Any
import requests
class AliYuqueTool:
# yuque service url
server_url = "https://www.yuque.com"
@staticmethod
def auth(token):
session = requests.Session()
session.headers.update({"Accept": "application/json", "X-Auth-Token": token})
login = session.request("GET", AliYuqueTool.server_url + "/api/v2/user")
login.raise_for_status()
resp = login.json()
return resp
def request(self, method: str, token, tool_parameters: dict[str, Any], path: str) -> str:
if not token:
raise Exception("token is required")
session = requests.Session()
session.headers.update({"accept": "application/json", "X-Auth-Token": token})
new_params = {**tool_parameters}
replacements = {k: v for k, v in new_params.items() if f"{{{k}}}" in path}
for key, value in replacements.items():
path = path.replace(f"{{{key}}}", str(value))
del new_params[key]
if method.upper() in {"POST", "PUT"}:
session.headers.update(
{
"Content-Type": "application/json",
}
)
response = session.request(method.upper(), self.server_url + path, json=new_params)
else:
response = session.request(method, self.server_url + path, params=new_params)
response.raise_for_status()
return response.text

View File

@ -1,15 +0,0 @@
from typing import Any, Union
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
from core.tools.tool.builtin_tool import BuiltinTool
class AliYuqueCreateDocumentTool(AliYuqueTool, BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
token = self.runtime.credentials.get("token", None)
if not token:
raise Exception("token is required")
return self.create_text_message(self.request("POST", token, tool_parameters, "/api/v2/repos/{book_id}/docs"))

View File

@ -1,99 +0,0 @@
identity:
name: aliyuque_create_document
author: 佐井
label:
en_US: Create Document
zh_Hans: 创建文档
icon: icon.svg
description:
human:
en_US: Creates a new document within a knowledge base without automatic addition to the table of contents. Requires a subsequent call to the "knowledge base directory update API". Supports setting visibility, format, and content. # 接口英文描述
zh_Hans: 在知识库中创建新文档,但不会自动加入目录,需额外调用“知识库目录更新接口”。允许设置公开性、格式及正文内容。
llm: Creates docs in a KB.
parameters:
- name: book_id
type: string
required: true
form: llm
label:
en_US: Knowledge Base ID
zh_Hans: 知识库ID
human_description:
en_US: The unique identifier of the knowledge base where the document will be created.
zh_Hans: 文档将被创建的知识库的唯一标识。
llm_description: ID of the target knowledge base.
- name: title
type: string
required: false
form: llm
label:
en_US: Title
zh_Hans: 标题
human_description:
en_US: The title of the document, defaults to 'Untitled' if not provided.
zh_Hans: 文档标题,默认为'无标题'如未提供。
llm_description: Title of the document, defaults to 'Untitled'.
- name: public
type: select
required: false
form: llm
options:
- value: 0
label:
en_US: Private
zh_Hans: 私密
- value: 1
label:
en_US: Public
zh_Hans: 公开
- value: 2
label:
en_US: Enterprise-only
zh_Hans: 企业内公开
label:
en_US: Visibility
zh_Hans: 公开性
human_description:
en_US: Document visibility (0 Private, 1 Public, 2 Enterprise-only).
zh_Hans: 文档可见性0 私密, 1 公开, 2 企业内公开)。
llm_description: Doc visibility options, 0-private, 1-public, 2-enterprise.
- name: format
type: select
required: false
form: llm
options:
- value: markdown
label:
en_US: markdown
zh_Hans: markdown
- value: html
label:
en_US: html
zh_Hans: html
- value: lake
label:
en_US: lake
zh_Hans: lake
label:
en_US: Content Format
zh_Hans: 内容格式
human_description:
en_US: Format of the document content (markdown, HTML, Lake).
zh_Hans: 文档内容格式markdown, HTML, Lake
llm_description: Content format choices, markdown, HTML, Lake.
- name: body
type: string
required: true
form: llm
label:
en_US: Body Content
zh_Hans: 正文内容
human_description:
en_US: The actual content of the document.
zh_Hans: 文档的实际内容。
llm_description: Content of the document.

View File

@ -1,17 +0,0 @@
from typing import Any, Union
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
from core.tools.tool.builtin_tool import BuiltinTool
class AliYuqueDeleteDocumentTool(AliYuqueTool, BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
token = self.runtime.credentials.get("token", None)
if not token:
raise Exception("token is required")
return self.create_text_message(
self.request("DELETE", token, tool_parameters, "/api/v2/repos/{book_id}/docs/{id}")
)

View File

@ -1,37 +0,0 @@
identity:
name: aliyuque_delete_document
author: 佐井
label:
en_US: Delete Document
zh_Hans: 删除文档
icon: icon.svg
description:
human:
en_US: Delete Document
zh_Hans: 根据id删除文档
llm: Delete document.
parameters:
- name: book_id
type: string
required: true
form: llm
label:
en_US: Knowledge Base ID
zh_Hans: 知识库ID
human_description:
en_US: The unique identifier of the knowledge base where the document will be created.
zh_Hans: 文档将被创建的知识库的唯一标识。
llm_description: ID of the target knowledge base.
- name: id
type: string
required: true
form: llm
label:
en_US: Document ID or Path
zh_Hans: 文档 ID or 路径
human_description:
en_US: Document ID or path.
zh_Hans: 文档 ID or 路径。
llm_description: Document ID or path.

View File

@ -1,17 +0,0 @@
from typing import Any, Union
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
from core.tools.tool.builtin_tool import BuiltinTool
class AliYuqueDescribeBookIndexPageTool(AliYuqueTool, BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
token = self.runtime.credentials.get("token", None)
if not token:
raise Exception("token is required")
return self.create_text_message(
self.request("GET", token, tool_parameters, "/api/v2/repos/{group_login}/{book_slug}/index_page")
)

View File

@ -1,15 +0,0 @@
from typing import Any, Union
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
from core.tools.tool.builtin_tool import BuiltinTool
class YuqueDescribeBookTableOfContentsTool(AliYuqueTool, BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> (Union)[ToolInvokeMessage, list[ToolInvokeMessage]]:
token = self.runtime.credentials.get("token", None)
if not token:
raise Exception("token is required")
return self.create_text_message(self.request("GET", token, tool_parameters, "/api/v2/repos/{book_id}/toc"))

View File

@ -1,25 +0,0 @@
identity:
name: aliyuque_describe_book_table_of_contents
author: 佐井
label:
en_US: Get Book's Table of Contents
zh_Hans: 获取知识库的目录
icon: icon.svg
description:
human:
en_US: Get Book's Table of Contents.
zh_Hans: 获取知识库的目录。
llm: Get Book's Table of Contents.
parameters:
- name: book_id
type: string
required: true
form: llm
label:
en_US: Book ID
zh_Hans: 知识库 ID
human_description:
en_US: Book ID.
zh_Hans: 知识库 ID。
llm_description: Book ID.

View File

@ -1,52 +0,0 @@
import json
from typing import Any, Union
from urllib.parse import urlparse
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
from core.tools.tool.builtin_tool import BuiltinTool
class AliYuqueDescribeDocumentContentTool(AliYuqueTool, BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
new_params = {**tool_parameters}
token = new_params.pop("token")
if not token or token.lower() == "none":
token = self.runtime.credentials.get("token", None)
if not token:
raise Exception("token is required")
new_params = {**tool_parameters}
url = new_params.pop("url")
if not url or not url.startswith("http"):
raise Exception("url is not valid")
parsed_url = urlparse(url)
path_parts = parsed_url.path.strip("/").split("/")
if len(path_parts) < 3:
raise Exception("url is not correct")
doc_id = path_parts[-1]
book_slug = path_parts[-2]
group_id = path_parts[-3]
new_params["group_login"] = group_id
new_params["book_slug"] = book_slug
index_page = json.loads(
self.request("GET", token, new_params, "/api/v2/repos/{group_login}/{book_slug}/index_page")
)
book_id = index_page.get("data", {}).get("book", {}).get("id")
if not book_id:
raise Exception(f"can not parse book_id from {index_page}")
new_params["book_id"] = book_id
new_params["id"] = doc_id
data = self.request("GET", token, new_params, "/api/v2/repos/{book_id}/docs/{id}")
data = json.loads(data)
body_only = tool_parameters.get("body_only") or ""
if body_only.lower() == "true":
return self.create_text_message(data.get("data").get("body"))
else:
raw = data.get("data")
del raw["body_lake"]
del raw["body_html"]
return self.create_text_message(json.dumps(data))

View File

@ -1,17 +0,0 @@
from typing import Any, Union
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
from core.tools.tool.builtin_tool import BuiltinTool
class AliYuqueDescribeDocumentsTool(AliYuqueTool, BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
token = self.runtime.credentials.get("token", None)
if not token:
raise Exception("token is required")
return self.create_text_message(
self.request("GET", token, tool_parameters, "/api/v2/repos/{book_id}/docs/{id}")
)

View File

@ -1,38 +0,0 @@
identity:
name: aliyuque_describe_documents
author: 佐井
label:
en_US: Get Doc Detail
zh_Hans: 获取文档详情
icon: icon.svg
description:
human:
en_US: Retrieves detailed information of a specific document identified by its ID or path within a knowledge base.
zh_Hans: 根据知识库ID和文档ID或路径获取文档详细信息。
llm: Fetches detailed doc info using ID/path from a knowledge base; supports doc lookup in Yuque.
parameters:
- name: book_id
type: string
required: true
form: llm
label:
en_US: Knowledge Base ID
zh_Hans: 知识库 ID
human_description:
en_US: Identifier for the knowledge base where the document resides.
zh_Hans: 文档所属知识库的唯一标识。
llm_description: ID of the knowledge base holding the document.
- name: id
type: string
required: true
form: llm
label:
en_US: Document ID or Path
zh_Hans: 文档 ID 或路径
human_description:
en_US: The unique identifier or path of the document to retrieve.
zh_Hans: 需要获取的文档的ID或其在知识库中的路径。
llm_description: Unique doc ID or its path for retrieval.

View File

@ -1,21 +0,0 @@
from typing import Any, Union
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool
from core.tools.tool.builtin_tool import BuiltinTool
class YuqueDescribeBookTableOfContentsTool(AliYuqueTool, BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> (Union)[ToolInvokeMessage, list[ToolInvokeMessage]]:
token = self.runtime.credentials.get("token", None)
if not token:
raise Exception("token is required")
doc_ids = tool_parameters.get("doc_ids")
if doc_ids:
doc_ids = [int(doc_id.strip()) for doc_id in doc_ids.split(",")]
tool_parameters["doc_ids"] = doc_ids
return self.create_text_message(self.request("PUT", token, tool_parameters, "/api/v2/repos/{book_id}/toc"))

Some files were not shown because too many files have changed in this diff Show More