commit db896255d625e2418787f1b53f0dac091c57df1e
Author: John Wang
Date: Mon May 15 08:51:32 2023 +0800

    Initial commit

diff --git a/.github/workflows/build-api-image.sh b/.github/workflows/build-api-image.sh
new file mode 100644
index 0000000000..b7c92525fb
--- /dev/null
+++ b/.github/workflows/build-api-image.sh
@@ -0,0 +1,61 @@
+#!/usr/bin/env bash
+
+set -eo pipefail
+
+SHA=$(git rev-parse HEAD)
+REPO_NAME=langgenius/dify
+API_REPO_NAME="${REPO_NAME}-api"
+
+# Derive the image tags from the GitHub event that triggered the build.
+if [[ "${GITHUB_EVENT_NAME}" == "pull_request" ]]; then
+  REFSPEC=$(echo "${GITHUB_HEAD_REF}" | sed 's/[^a-zA-Z0-9]/-/g' | head -c 40)
+  PR_NUM=$(echo "${GITHUB_REF}" | sed 's:refs/pull/::' | sed 's:/merge::')
+  LATEST_TAG="pr-${PR_NUM}"
+  CACHE_FROM_TAG="latest"
+elif [[ "${GITHUB_EVENT_NAME}" == "release" ]]; then
+  REFSPEC=$(echo "${GITHUB_REF}" | sed 's:refs/tags/::' | head -c 40)
+  LATEST_TAG="${REFSPEC}"
+  CACHE_FROM_TAG="latest"
+else
+  REFSPEC=$(echo "${GITHUB_REF}" | sed 's:refs/heads/::' | sed 's/[^a-zA-Z0-9]/-/g' | head -c 40)
+  LATEST_TAG="${REFSPEC}"
+  CACHE_FROM_TAG="${REFSPEC}"
+fi
+
+if [[ "${REFSPEC}" == "main" ]]; then
+  LATEST_TAG="latest"
+  CACHE_FROM_TAG="latest"
+fi
+
+# Try to reuse a previously pushed image as the Docker build cache.
+echo "Pulling cache image ${API_REPO_NAME}:${CACHE_FROM_TAG}"
+if docker pull "${API_REPO_NAME}:${CACHE_FROM_TAG}"; then
+  API_CACHE_FROM_SCRIPT="--cache-from ${API_REPO_NAME}:${CACHE_FROM_TAG}"
+else
+  echo "WARNING: Failed to pull ${API_REPO_NAME}:${CACHE_FROM_TAG}, disabling the build image cache."
+  API_CACHE_FROM_SCRIPT=""
+fi
+
+cat <<EOF

+```
+git clone git@github.com:<your_github_username>/langgenius-gateway.git
+```
+
+### Install backend
+
+To learn how to install the backend application, please refer to the [Backend README](api/README.md).
+
+### Install frontend
+
+To learn how to install the frontend application, please refer to the [Frontend README](web/README.md).
+
+### Visit Dify in your browser
+
+Finally, you can now visit [http://localhost:3000](http://localhost:3000) to view [Dify](https://dify.ai) in your local environment.
+
+## Create a pull request
+
+After making your changes, open a pull request (PR). Once you submit your pull request, others from the Dify team/community will review it with you.
+
+Ran into an issue, like a merge conflict, or unsure how to open a pull request? Check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests) on how to resolve merge conflicts and other issues. Once your PR has been merged, you will be proudly listed as a contributor in the [contributor chart](https://github.com/langgenius/langgenius-gateway/graphs/contributors).
+
+## Community channels
+
+Stuck somewhere? Have any questions? Join the [Discord Community Server](https://discord.gg/AhzKf7dNgk). We are here to help!
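For a quick sanity check after the setup above, you can probe both services from a script. This is a minimal sketch, assuming the default ports used throughout this commit (the frontend on 3000, the API on 5001, whose `/health` route is defined in `api/app.py` later in this diff) and the third-party `requests` package; the script itself is not part of the commit:

```python
# smoke_test.py - illustrative only, not part of this commit
import requests

def check(url: str) -> None:
    resp = requests.get(url, timeout=5)
    print(f"{url} -> HTTP {resp.status_code}")

if __name__ == "__main__":
    check("http://localhost:5001/health")  # API health endpoint (see api/app.py below)
    check("http://localhost:3000")         # frontend web app
```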
diff --git a/CONTRIBUTING_CN.md b/CONTRIBUTING_CN.md
new file mode 100644
index 0000000000..51327d24b7
--- /dev/null
+++ b/CONTRIBUTING_CN.md
@@ -0,0 +1,53 @@
+# 贡献
+
+感谢您对 [Dify](https://dify.ai) 的兴趣,并希望您能够做出贡献!在开始之前,请先阅读[行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)并查看[现有问题](https://github.com/langgenius/dify/issues)。
+本文档介绍了如何设置开发环境以构建和测试 [Dify](https://dify.ai)。
+
+### 安装依赖项
+
+您需要在计算机上安装和配置以下依赖项才能构建 [Dify](https://dify.ai):
+
+- [Git](http://git-scm.com/)
+- [Docker](https://www.docker.com/)
+- [Docker Compose](https://docs.docker.com/compose/install/)
+- [Node.js v18.x (LTS)](http://nodejs.org)
+- [npm](https://www.npmjs.com/) 版本 8.x.x 或 [Yarn](https://yarnpkg.com/)
+- [Python](https://www.python.org/) 版本 3.10.x
+
+## 本地开发
+
+要设置一个可工作的开发环境,只需 fork 项目的 git 存储库,并使用适当的软件包管理器安装后端和前端依赖项,然后创建并运行 docker-compose 堆栈。
+
+### Fork 存储库
+
+您需要 fork [存储库](https://github.com/langgenius/dify)。
+
+### 克隆存储库
+
+克隆您在 GitHub 上 fork 的存储库:
+
+```
+git clone git@github.com:<your_github_username>/dify.git
+```
+
+### 安装后端
+
+要了解如何安装后端应用程序,请参阅[后端 README](api/README.md)。
+
+### 安装前端
+
+要了解如何安装前端应用程序,请参阅[前端 README](web/README.md)。
+
+### 在浏览器中访问 Dify
+
+最后,您现在可以访问 [http://localhost:3000](http://localhost:3000) 在本地环境中查看 [Dify](https://dify.ai)。
+
+## 创建拉取请求
+
+在进行更改后,打开一个拉取请求(PR)。提交拉取请求后,Dify 团队/社区的其他人将与您一起审查它。
+
+如果遇到问题,比如合并冲突或不知道如何打开拉取请求,请查看 GitHub 的[拉取请求教程](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests),了解如何解决合并冲突和其他问题。一旦您的 PR 被合并,您将自豪地被列为[贡献者表](https://github.com/langgenius/dify/graphs/contributors)中的一员。
+
+## 社区渠道
+
+遇到困难了吗?有任何问题吗?加入 [Discord Community Server](https://discord.gg/AhzKf7dNgk),我们将为您提供帮助。

diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000..d5d166643f
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,46 @@
+# Dify Open Source License
+
+The Dify project uses a combination of the Apache License 2.0, MIT License, and an additional agreement to protect against direct competition with Dify Cloud services.
+
+As a contributor, you should agree that your contributed code:
+a. Might be subject to a more permissive open source license in the future.
+b. Can be used for commercial purposes, such as Dify's cloud business.
+
+The following components are open source under the MIT license, allowing you to build and develop applications based on them:
+- WebApp elements, e.g., web/app/components/share
+- Derived WebApp Template projects
+
+The remaining parts of the project are open source under the Apache License 2.0.
+
+With the Apache License 2.0, MIT License, and this supplementary agreement, anyone can freely use, modify, and distribute Dify, provided that:
+
+- If you use Dify solely as a backend service for other applications, no authorization is needed for commercial or closed source purposes.
+- If you wish to use Dify for commercial and closed source SaaS services similar to Dify Cloud, please contact us for authorization.
+
+The interactive design of this product is protected by appearance patent.
+
+© 2023 LangGenius, Inc.
+
+----------
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+----------
+The MIT License
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

diff --git a/README.md b/README.md
new file mode 100644
index 0000000000..169618b79f
--- /dev/null
+++ b/README.md
@@ -0,0 +1,115 @@
+![](./images/describe-en.png)
+
+English | [简体中文](./README_CN.md)
+
+[Website](http://dify.ai) • [Docs](https://docs.dify.ai) • [Twitter](https://twitter.com/dify_ai)
+
+**Dify** is an easy-to-use LLMOps platform designed to empower more people to create sustainable, AI-native applications. With visual orchestration for various application types, Dify offers out-of-the-box, ready-to-use applications that can also serve as Backend-as-a-Service APIs. Unify your development process with one API for plugin and dataset integration, and streamline your operations using a single interface for prompt engineering, visual analytics, and continuous improvement.
+
+Applications created with Dify include:
+
+- Out-of-the-box websites supporting form mode and chat conversation mode
+- A single API encompassing plugin capabilities, context enhancement, and more, saving you backend coding effort
+- Visual data analysis, log review, and annotation for applications
+
+Dify is compatible with Langchain, meaning we'll gradually support multiple LLMs. Currently supported:
+
+- GPT 3 (text-davinci-003)
+- GPT 3.5 Turbo (ChatGPT)
+- GPT-4
+
+## Use Cloud Services
+
+Visit [Dify.ai](http://dify.ai)
+
+## Install the Community Edition
+
+### System Requirements
+
+Before installing Dify, make sure your machine meets the following minimum system requirements:
+
+- CPU >= 1 Core
+- RAM >= 4GB
+
+### Quick Start
+
+The easiest way to start the Dify server is to run our [docker-compose.yml](docker/docker-compose.yaml) file. Before running the installation command, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
+
+```bash
+cd docker
+docker-compose up -d
+```
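The containers can take a little while to come up on first run. If you would rather script the wait than refresh the browser, a small standard-library-only Python sketch like the one below works; it assumes the console is published on plain `http://localhost` as described just below, and is illustrative rather than part of this commit:

```python
# wait_for_dify.py - illustrative only, not part of this commit
import time
import urllib.error
import urllib.request

def wait_for(url: str, attempts: int = 30, delay: float = 2.0) -> bool:
    """Poll `url` until the server answers at all, or attempts run out."""
    for _ in range(attempts):
        try:
            urllib.request.urlopen(url, timeout=3)
            return True          # got a 2xx/3xx response
        except urllib.error.HTTPError:
            return True          # server answered, just with an error status
        except OSError:
            time.sleep(delay)    # connection refused/timeout: still starting
    return False

if __name__ == "__main__":
    print("up" if wait_for("http://localhost") else "still down after waiting")
```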
+After running, you can access the Dify console in your browser at [http://localhost](http://localhost) and start the initialization process.
+
+### Configuration
+
+If you need to customize the configuration, please refer to the comments in our [docker-compose.yml](docker/docker-compose.yaml) file and manually set the environment configuration. After making the changes, please run `docker-compose up -d` again.
+
+## Roadmap
+
+Features under development:
+
+- **Datasets**, supporting more datasets, e.g. syncing content from Notion or webpages.
+We will support more datasets, including text, webpages, and even Notion content. Users can build AI applications based on their own data sources.
+- **Plugins**, introducing ChatGPT Plugin-standard plugins for applications, or using Dify-produced plugins.
+We will release plugins complying with the ChatGPT standard, or Dify's own plugins, to enable more capabilities in applications.
+- **Open-source models**, e.g. adopting Llama as a model provider or for further fine-tuning.
+We will work with excellent open-source models like Llama, by providing them as model options in our platform, or using them for further fine-tuning.
+
+## Q&A
+
+**Q: What can I do with Dify?**
+
+A: Dify is a simple yet powerful LLM development and operations tool. You can use it to build commercial-grade applications or personal assistants. If you want to develop your own applications, Dify can save you the backend work of integrating with OpenAI and offers visual operations capabilities, allowing you to continuously improve and train your GPT model.
+
+**Q: How do I use Dify to "train" my own model?**
+
+A: A valuable application consists of Prompt Engineering, context enhancement, and Fine-tuning.
+We've created a hybrid programming approach combining Prompts with programming languages (similar to a template engine), making it easy to accomplish long-text embedding or capturing subtitles from a user-input YouTube video - all of which will be submitted as context for LLMs to process. We place great emphasis on application operability, with data generated by users during App usage available for analysis, annotation, and continuous training. Without the right tools, these steps can be time-consuming.
+
+**Q: What do I need to prepare if I want to create my own application?**
+
+A: We assume you already have an OpenAI API Key; if not, please register for one. If you already have some content that can serve as training context, that's great!
+
+**Q: What interface languages are available?**
+
+A: English and Chinese are currently supported, and you can contribute language packs to us.
+
+## Star History
+
+[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date)
+
+## Contact Us
+
+If you have any questions, suggestions, or partnership inquiries, feel free to contact us through the following channels:
+
+- Submit an Issue or PR on our GitHub Repo
+- Join the discussion in our [Discord](https://discord.gg/AhzKf7dNgk) Community
+- Send an email to hello@dify.ai
+
+We're eager to assist you and together create more fun and useful AI applications!
+
+## Contributing
+
+To ensure proper review, all code contributions - including those from contributors with direct commit access - must be submitted via pull requests and approved by the core development team prior to being merged.
+
+We welcome all pull requests! If you'd like to help, check out the [Contribution Guide](CONTRIBUTING.md) for more information on how to get started.
+
+## Security
+
+To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer.
+
+## Citation
+
+This software uses the following open-source software:
+
+- Chase, H. (2022). LangChain [Computer software]. https://github.com/hwchase17/langchain
+- Liu, J. (2022). LlamaIndex [Computer software]. doi: 10.5281/zenodo.1234.
+
+For more information, please refer to the official website or license text of the respective software.
+
+## License
+
+This repository is available under the [Dify Open Source License](LICENSE).
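The "train" answer above describes a template-engine-style mix of prompts and variables; later in this commit, `api/constants/model_template.py` stores pre-prompts with `{{variable}}` placeholders such as `{{target_language}}`. A minimal sketch of that substitution style, purely illustrative and not the engine Dify actually uses:

```python
# prompt_render.py - illustrative sketch, not part of this commit
import re

def render_prompt(template: str, variables: dict[str, str]) -> str:
    """Replace {{name}} placeholders with user-supplied values."""
    return re.sub(
        r"\{\{(\w+)\}\}",
        lambda m: variables.get(m.group(1), m.group(0)),  # leave unknown keys as-is
        template,
    )

print(render_prompt(
    "Please translate the following text into {{target_language}}:\n",
    {"target_language": "Chinese"},
))
```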
diff --git a/README_CN.md b/README_CN.md
new file mode 100644
index 0000000000..3788c857fc
--- /dev/null
+++ b/README_CN.md
@@ -0,0 +1,114 @@
+![](./images/describe-cn.jpg)
+
+[English](./README.md) | 简体中文
+
+[官方网站](http://dify.ai) • [文档](https://docs.dify.ai/v/zh-hans) • [Twitter](https://twitter.com/dify_ai)
+
+**Dify** 是一个易用的 LLMOps 平台,旨在让更多人可以创建可持续运营的原生 AI 应用。Dify 提供多种类型应用的可视化编排,应用可开箱即用,也能以“后端即服务”的 API 提供服务。
+
+通过 Dify 创建的应用包含了:
+
+- 开箱即用的 Web 站点,支持表单模式和聊天对话模式
+- 一套 API 即可包含插件、上下文增强等能力,替你省下了后端代码的编写工作
+- 可视化的对应用进行数据分析,查阅日志或进行标注
+
+Dify 兼容 Langchain,这意味着我们将逐步支持多种 LLMs,目前已支持:
+
+- GPT 3 (text-davinci-003)
+- GPT 3.5 Turbo (ChatGPT)
+- GPT-4
+
+## 使用云服务
+
+访问 [Dify.ai](http://cloud.dify.ai)
+
+## 安装社区版
+
+### 系统要求
+
+在安装 Dify 之前,请确保您的机器满足以下最低系统要求:
+
+- CPU >= 1 Core
+- RAM >= 4GB
+
+### 快速启动
+
+启动 Dify 服务器的最简单方法是运行我们的 [docker-compose.yml](docker/docker-compose.yaml) 文件。在运行安装命令之前,请确保您的机器上安装了 [Docker](https://docs.docker.com/get-docker/) 和 [Docker Compose](https://docs.docker.com/compose/install/):
+
+```bash
+cd docker
+docker-compose up -d
+```
+
+运行后,可以在浏览器上访问 [http://localhost](http://localhost) 进入 Dify 控制台,并开始初始化操作。
+
+### 配置
+
+如需自定义配置,请参考我们的 [docker-compose.yml](docker/docker-compose.yaml) 文件中的注释,并手动设置环境配置。修改完毕后,请再次执行 `docker-compose up -d`。
+
+## Roadmap
+
+我们正在开发中的功能:
+
+- **数据集**,支持更多的数据集,例如同步 Notion 或网页的内容。
+我们将支持更多的数据集,包括文本、网页,甚至 Notion 内容。用户可以根据自己的数据源构建 AI 应用程序。
+- **插件**,推出符合 ChatGPT 标准的插件,或使用 Dify 产生的插件。
+我们将发布符合 ChatGPT 标准的插件,或者 Dify 自己的插件,以在应用程序中启用更多功能。
+- **开源模型**,例如采用 Llama 作为模型提供者,或进行进一步的微调。
+我们将与优秀的开源模型如 Llama 合作,通过在我们的平台中提供它们作为模型选项,或使用它们进行进一步的微调。
+
+## Q&A
+
+**Q: 我能用 Dify 做什么?**
+
+A: Dify 是一个简单且能力丰富的 LLM 开发和运营工具。你可以用它搭建商用级应用或个人助理。如果你想自己开发应用,Dify 也能为你省下接入 OpenAI 的后端工作,使用我们逐步提供的可视化运营能力,你可以持续改进和训练你的 GPT 模型。
+
+**Q: 如何使用 Dify “训练”自己的模型?**
+
+A: 一个有价值的应用由 Prompt Engineering、上下文增强和 Fine-tune 三个环节组成。我们创造了一种 Prompt 结合编程语言的 Hybrid 编程方式(类似一个模版引擎),你可以轻松完成长文本嵌入,或抓取用户输入的一个 YouTube 视频的字幕——这些都将作为上下文提交给 LLMs 进行计算。我们十分注重应用的可运营性,你的用户在使用 App 期间产生的数据,可进行分析、标记和持续训练。以上环节如果没有好的工具支持,可能会消耗你大量的时间。
+
+**Q: 如果要创建一个自己的应用,我需要准备什么?**
+
+A: 我们假定你已经有了 OpenAI API Key,如果没有请去注册一个。如果你已经有了一些内容可以作为训练上下文,就太好了。
+
+**Q: 提供哪些界面语言?**
+
+A: 现已支持英文与中文,你可以为我们贡献语言包。
+
+## Star History
+
+[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date)
+
+## 联系我们
+
+如果您有任何问题、建议或合作意向,欢迎通过以下方式联系我们:
+
+- 在我们的 [GitHub Repo](https://github.com/langgenius/dify) 上提交 Issue 或 PR
+- 在我们的 [Discord 社区](https://discord.gg/AhzKf7dNgk) 上加入讨论
+- 发送邮件至 hello@dify.ai
+
+## 贡献代码
+
+为了确保正确审查,所有代码贡献 - 包括来自具有直接提交更改权限的贡献者 - 都必须提交 PR 请求,并在合并分支之前得到核心开发人员的批准。
+
+我们欢迎所有人提交 PR!如果您愿意提供帮助,可以在[贡献指南](CONTRIBUTING_CN.md)中了解有关如何为项目做出贡献的更多信息。
+
+## 安全
+
+为了保护您的隐私,请避免在 GitHub 上发布安全问题。请将问题发送至 security@dify.ai,我们将为您做更细致的解答。
+
+## Citation
+
+本软件使用了以下开源软件:
+
+- Chase, H. (2022). LangChain [Computer software]. https://github.com/hwchase17/langchain
+- Liu, J. (2022). LlamaIndex [Computer software]. doi: 10.5281/zenodo.1234.
+
+更多信息,请参考相应软件的官方网站或许可证文本。
+
+## License
+
+本仓库遵循 [Dify Open Source License](LICENSE) 开源协议。

diff --git a/api/.dockerignore b/api/.dockerignore
new file mode 100644
index 0000000000..9b5050396d
--- /dev/null
+++ b/api/.dockerignore
@@ -0,0 +1,2 @@
+.env
+storage/privkeys/*
\ No newline at end of file
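The `.env.example` below recommends `openssl rand -base64 42` for generating `SECRET_KEY`. If OpenSSL isn't handy, the Python standard library produces a key of the same shape, using the same `secrets`/`base64` pattern this commit's `api/commands.py` applies to password salts; this snippet is illustrative and not part of the commit:

```python
import base64
import secrets

# 42 random bytes, base64-encoded - equivalent in shape to `openssl rand -base64 42`
print(base64.b64encode(secrets.token_bytes(42)).decode())
```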
diff --git a/api/.env.example b/api/.env.example
new file mode 100644
index 0000000000..5f307dc106
--- /dev/null
+++ b/api/.env.example
@@ -0,0 +1,85 @@
+# Server Edition
+EDITION=SELF_HOSTED
+
+# Your App secret key will be used for securely signing the session cookie
+# Make sure you are changing this key for your deployment with a strong key.
+# You can generate a strong key using `openssl rand -base64 42`.
+# Alternatively you can set it with the `SECRET_KEY` environment variable.
+SECRET_KEY=
+
+# Console API base URL
+CONSOLE_URL=http://127.0.0.1:5001
+
+# Service API base URL
+API_URL=http://127.0.0.1:5001
+
+# Web APP base URL
+APP_URL=http://127.0.0.1:5001
+
+# Celery configuration
+CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
+
+# Redis configuration
+REDIS_HOST=localhost
+REDIS_PORT=6379
+REDIS_PASSWORD=difyai123456
+REDIS_DB=0
+
+# PostgreSQL database configuration
+DB_USERNAME=postgres
+DB_PASSWORD=difyai123456
+DB_HOST=localhost
+DB_PORT=5432
+DB_DATABASE=dify
+
+# Storage configuration
+# used to store uploaded files, private keys...
+# storage type: local, s3
+STORAGE_TYPE=local
+STORAGE_LOCAL_PATH=storage
+S3_ENDPOINT=https://your-bucket-name.storage.s3.cloudflare.com
+S3_BUCKET_NAME=your-bucket-name
+S3_ACCESS_KEY=your-access-key
+S3_SECRET_KEY=your-secret-key
+S3_REGION=your-region
+
+# CORS configuration
+WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
+CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
+
+# Cookie configuration
+COOKIE_HTTPONLY=true
+COOKIE_SAMESITE=None
+COOKIE_SECURE=true
+
+# Session configuration
+SESSION_PERMANENT=true
+SESSION_USE_SIGNER=true
+
+## supports redis, sqlalchemy
+SESSION_TYPE=redis
+
+# Session redis configuration
+SESSION_REDIS_HOST=localhost
+SESSION_REDIS_PORT=6379
+SESSION_REDIS_PASSWORD=difyai123456
+SESSION_REDIS_DB=2
+
+# Vector database configuration, supports: weaviate, qdrant
+VECTOR_STORE=weaviate
+
+# Weaviate configuration
+WEAVIATE_ENDPOINT=http://localhost:8080
+WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
+WEAVIATE_GRPC_ENABLED=false
+
+# Qdrant configuration, use `path:` prefix for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
+QDRANT_URL=path:storage/qdrant
+QDRANT_API_KEY=your-qdrant-api-key
+
+# Sentry configuration
+SENTRY_DSN=
+
+# DEBUG
+DEBUG=false
+SQLALCHEMY_ECHO=false

diff --git a/api/Dockerfile b/api/Dockerfile
new file mode 100644
index 0000000000..fb451129d1
--- /dev/null
+++ b/api/Dockerfile
@@ -0,0 +1,28 @@
+FROM langgenius/base:1.0.0-bullseye-slim as langgenius-api
+
+LABEL maintainer="takatost@gmail.com"
+
+ENV FLASK_APP app.py
+ENV EDITION SELF_HOSTED
+ENV DEPLOY_ENV PRODUCTION
+ENV CONSOLE_URL http://127.0.0.1:5001
+ENV API_URL http://127.0.0.1:5001
+ENV APP_URL http://127.0.0.1:5001
+
+EXPOSE 5001
+
+WORKDIR /app/api
+
+COPY requirements.txt /app/api/requirements.txt
+
+RUN pip install -r requirements.txt
+
+COPY . /app/api/
+
+COPY docker/entrypoint.sh /entrypoint.sh
+RUN chmod +x /entrypoint.sh
+
+ARG COMMIT_SHA
+ENV COMMIT_SHA ${COMMIT_SHA}
+
+ENTRYPOINT ["/entrypoint.sh"]
\ No newline at end of file

diff --git a/api/README.md b/api/README.md
new file mode 100644
index 0000000000..97f09fc700
--- /dev/null
+++ b/api/README.md
@@ -0,0 +1,35 @@
+# Dify Backend API
+
+## Usage
+
+1. Start the docker-compose stack
+
+   The backend requires some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
+
+   ```bash
+   cd ../docker
+   docker-compose -f docker-compose.middleware.yaml up -d
+   cd ../api
+   ```
+2. Copy `.env.example` to `.env`
+3. Generate a `SECRET_KEY` and set it in the `.env` file.
+
+   ```bash
+   openssl rand -base64 42
+   ```
+4. Install dependencies
+   ```bash
+   pip install -r requirements.txt
+   ```
+5. Run migrations
+
+   Before the first launch, migrate the database to the latest version.
+
+   ```bash
+   flask db upgrade
+   ```
+6. Start the backend:
+   ```bash
+   flask run --host 0.0.0.0 --port=5001 --debug
+   ```
+7. 
Setup your application by visiting http://localhost:5001/console/api/setup or other apis... diff --git a/api/app.py b/api/app.py new file mode 100644 index 0000000000..b3fcbc220a --- /dev/null +++ b/api/app.py @@ -0,0 +1,222 @@ +# -*- coding:utf-8 -*- +import os +if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true': + from gevent import monkey + monkey.patch_all() + +import logging +import json +import threading + +from flask import Flask, request, Response, session +import flask_login +from flask_cors import CORS + +from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_vector_store, ext_migrate, \ + ext_database, ext_storage +from extensions.ext_database import db +from extensions.ext_login import login_manager + +# DO NOT REMOVE BELOW +from models import model, account, dataset, web, task +from events import event_handlers +# DO NOT REMOVE ABOVE + +import core +from config import Config, CloudEditionConfig +from commands import register_commands +from models.account import TenantAccountJoin +from models.model import Account, EndUser, App + +import warnings +warnings.simplefilter("ignore", ResourceWarning) + + +class DifyApp(Flask): + pass + +# ------------- +# Configuration +# ------------- + + +config_type = os.getenv('EDITION', default='SELF_HOSTED') # ce edition first + +# ---------------------------- +# Application Factory Function +# ---------------------------- + + +def create_app(test_config=None) -> Flask: + app = DifyApp(__name__) + + if test_config: + app.config.from_object(test_config) + else: + if config_type == "CLOUD": + app.config.from_object(CloudEditionConfig()) + else: + app.config.from_object(Config()) + + app.secret_key = app.config['SECRET_KEY'] + + logging.basicConfig(level=app.config.get('LOG_LEVEL', 'INFO')) + + initialize_extensions(app) + register_blueprints(app) + register_commands(app) + + core.init_app(app) + + return app + + +def initialize_extensions(app): + # Since the application instance is now created, pass it to each Flask + # extension instance to bind it to the Flask application instance (app) + ext_database.init_app(app) + ext_migrate.init(app, db) + ext_redis.init_app(app) + ext_vector_store.init_app(app) + ext_storage.init_app(app) + ext_celery.init_app(app) + ext_session.init_app(app) + ext_login.init_app(app) + ext_sentry.init_app(app) + + +# Flask-Login configuration +@login_manager.user_loader +def load_user(user_id): + """Load user based on the user_id.""" + if request.blueprint == 'console': + # Check if the user_id contains a dot, indicating the old format + if '.' 
in user_id: + tenant_id, account_id = user_id.split('.') + else: + account_id = user_id + + account = db.session.query(Account).filter(Account.id == account_id).first() + + if account: + workspace_id = session.get('workspace_id') + if workspace_id: + tenant_account_join = db.session.query(TenantAccountJoin).filter( + TenantAccountJoin.account_id == account.id, + TenantAccountJoin.tenant_id == workspace_id + ).first() + + if not tenant_account_join: + tenant_account_join = db.session.query(TenantAccountJoin).filter( + TenantAccountJoin.account_id == account.id).first() + + if tenant_account_join: + account.current_tenant_id = tenant_account_join.tenant_id + session['workspace_id'] = account.current_tenant_id + else: + account.current_tenant_id = workspace_id + else: + tenant_account_join = db.session.query(TenantAccountJoin).filter( + TenantAccountJoin.account_id == account.id).first() + if tenant_account_join: + account.current_tenant_id = tenant_account_join.tenant_id + session['workspace_id'] = account.current_tenant_id + + # Log in the user with the updated user_id + flask_login.login_user(account, remember=True) + + return account + else: + return None + + +@login_manager.unauthorized_handler +def unauthorized_handler(): + """Handle unauthorized requests.""" + return Response(json.dumps({ + 'code': 'unauthorized', + 'message': "Unauthorized." + }), status=401, content_type="application/json") + + +# register blueprint routers +def register_blueprints(app): + from controllers.service_api import bp as service_api_bp + from controllers.web import bp as web_bp + from controllers.console import bp as console_app_bp + + app.register_blueprint(service_api_bp) + + CORS(web_bp, + resources={ + r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}}, + supports_credentials=True, + allow_headers=['Content-Type', 'Authorization'], + methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'], + expose_headers=['X-Version', 'X-Env'] + ) + + app.register_blueprint(web_bp) + + CORS(console_app_bp, + resources={ + r"/*": {"origins": app.config['CONSOLE_CORS_ALLOW_ORIGINS']}}, + supports_credentials=True, + allow_headers=['Content-Type', 'Authorization'], + methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'], + expose_headers=['X-Version', 'X-Env'] + ) + + app.register_blueprint(console_app_bp) + + +# create app +app = create_app() +celery = app.extensions["celery"] + + +if app.config['TESTING']: + print("App is running in TESTING mode") + + +@app.after_request +def after_request(response): + """Add Version headers to the response.""" + response.headers.add('X-Version', app.config['CURRENT_VERSION']) + response.headers.add('X-Env', app.config['DEPLOY_ENV']) + return response + + +@app.route('/health') +def health(): + return Response(json.dumps({ + 'status': 'ok', + 'version': app.config['CURRENT_VERSION'] + }), status=200, content_type="application/json") + + +@app.route('/threads') +def threads(): + num_threads = threading.active_count() + threads = threading.enumerate() + + thread_list = [] + for thread in threads: + thread_name = thread.name + thread_id = thread.ident + is_alive = thread.is_alive() + + thread_list.append({ + 'name': thread_name, + 'id': thread_id, + 'is_alive': is_alive + }) + + return { + 'thread_num': num_threads, + 'threads': thread_list + } + + +if __name__ == '__main__': + app.run(host='0.0.0.0', port=5001) diff --git a/api/commands.py b/api/commands.py new file mode 100644 index 0000000000..b67b4f8676 --- /dev/null +++ b/api/commands.py @@ -0,0 +1,160 @@ 
+import datetime +import json +import random +import string + +import click + +from libs.password import password_pattern, valid_password, hash_password +from libs.helper import email as email_validate +from extensions.ext_database import db +from models.account import InvitationCode +from models.model import Account, AppModelConfig, ApiToken, Site, App, RecommendedApp +import secrets +import base64 + + +@click.command('reset-password', help='Reset the account password.') +@click.option('--email', prompt=True, help='The email address of the account whose password you need to reset') +@click.option('--new-password', prompt=True, help='the new password.') +@click.option('--password-confirm', prompt=True, help='the new password confirm.') +def reset_password(email, new_password, password_confirm): + if str(new_password).strip() != str(password_confirm).strip(): + click.echo(click.style('sorry. The two passwords do not match.', fg='red')) + return + account = db.session.query(Account). \ + filter(Account.email == email). \ + one_or_none() + if not account: + click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red')) + return + try: + valid_password(new_password) + except: + click.echo( + click.style('sorry. The passwords must match {} '.format(password_pattern), fg='red')) + return + + # generate password salt + salt = secrets.token_bytes(16) + base64_salt = base64.b64encode(salt).decode() + + # encrypt password with salt + password_hashed = hash_password(new_password, salt) + base64_password_hashed = base64.b64encode(password_hashed).decode() + account.password = base64_password_hashed + account.password_salt = base64_salt + db.session.commit() + click.echo(click.style('Congratulations!, password has been reset.', fg='green')) + + +@click.command('reset-email', help='Reset the account email.') +@click.option('--email', prompt=True, help='The old email address of the account whose email you need to reset') +@click.option('--new-email', prompt=True, help='the new email.') +@click.option('--email-confirm', prompt=True, help='the new email confirm.') +def reset_email(email, new_email, email_confirm): + if str(new_email).strip() != str(email_confirm).strip(): + click.echo(click.style('Sorry, new email and confirm email do not match.', fg='red')) + return + account = db.session.query(Account). \ + filter(Account.email == email). \ + one_or_none() + if not account: + click.echo(click.style('sorry. the account: [{}] not exist .'.format(email), fg='red')) + return + try: + email_validate(new_email) + except: + click.echo( + click.style('sorry. {} is not a valid email. '.format(email), fg='red')) + return + + account.email = new_email + db.session.commit() + click.echo(click.style('Congratulations!, email has been reset.', fg='green')) + + +@click.command('generate-invitation-codes', help='Generate invitation codes.') +@click.option('--batch', help='The batch of invitation codes.') +@click.option('--count', prompt=True, help='Invitation codes count.') +def generate_invitation_codes(batch, count): + if not batch: + now = datetime.datetime.now() + batch = now.strftime('%Y%m%d%H%M%S') + + if not count or int(count) <= 0: + click.echo(click.style('sorry. 
the count must be greater than 0.', fg='red')) + return + + count = int(count) + + click.echo('Start generate {} invitation codes for batch {}.'.format(count, batch)) + + codes = '' + for i in range(count): + code = generate_invitation_code() + invitation_code = InvitationCode( + code=code, + batch=batch + ) + db.session.add(invitation_code) + click.echo(code) + + codes += code + "\n" + db.session.commit() + + filename = 'storage/invitation-codes-{}.txt'.format(batch) + + with open(filename, 'w') as f: + f.write(codes) + + click.echo(click.style( + 'Congratulations! Generated {} invitation codes for batch {} and saved to the file \'{}\''.format(count, batch, + filename), + fg='green')) + + +def generate_invitation_code(): + code = generate_upper_string() + while db.session.query(InvitationCode).filter(InvitationCode.code == code).count() > 0: + code = generate_upper_string() + + return code + + +def generate_upper_string(): + letters_digits = string.ascii_uppercase + string.digits + result = "" + for i in range(8): + result += random.choice(letters_digits) + + return result + + +@click.command('gen-recommended-apps', help='Number of records to generate') +def generate_recommended_apps(): + print('Generating recommended app data...') + apps = App.query.all() + for app in apps: + recommended_app = RecommendedApp( + app_id=app.id, + description={ + 'en': 'Description for ' + app.name, + 'zh': '描述 ' + app.name + }, + copyright='Copyright ' + str(random.randint(1990, 2020)), + privacy_policy='https://privacypolicy.example.com', + category=random.choice(['Games', 'News', 'Music', 'Sports']), + position=random.randint(1, 100), + install_count=random.randint(100, 100000) + ) + db.session.add(recommended_app) + db.session.commit() + print('Done!') + + +def register_commands(app): + app.cli.add_command(reset_password) + app.cli.add_command(reset_email) + app.cli.add_command(generate_invitation_codes) + app.cli.add_command(generate_recommended_apps) diff --git a/api/config.py b/api/config.py new file mode 100644 index 0000000000..04c44f2447 --- /dev/null +++ b/api/config.py @@ -0,0 +1,200 @@ +# -*- coding:utf-8 -*- +import os +from datetime import timedelta + +import dotenv + +from extensions.ext_database import db +from extensions.ext_redis import redis_client + +dotenv.load_dotenv() + +DEFAULTS = { + 'COOKIE_HTTPONLY': 'True', + 'COOKIE_SECURE': 'True', + 'COOKIE_SAMESITE': 'None', + 'DB_USERNAME': 'postgres', + 'DB_PASSWORD': '', + 'DB_HOST': 'localhost', + 'DB_PORT': '5432', + 'DB_DATABASE': 'dify', + 'REDIS_HOST': 'localhost', + 'REDIS_PORT': '6379', + 'REDIS_DB': '0', + 'SESSION_REDIS_HOST': 'localhost', + 'SESSION_REDIS_PORT': '6379', + 'SESSION_REDIS_DB': '2', + 'OAUTH_REDIRECT_PATH': '/console/api/oauth/authorize', + 'OAUTH_REDIRECT_INDEX_PATH': '/', + 'CONSOLE_URL': 'https://cloud.dify.ai', + 'API_URL': 'https://api.dify.ai', + 'APP_URL': 'https://udify.app', + 'STORAGE_TYPE': 'local', + 'STORAGE_LOCAL_PATH': 'storage', + 'CHECK_UPDATE_URL': 'https://updates.dify.ai', + 'SESSION_TYPE': 'sqlalchemy', + 'SESSION_PERMANENT': 'True', + 'SESSION_USE_SIGNER': 'True', + 'DEPLOY_ENV': 'PRODUCTION', + 'SQLALCHEMY_POOL_SIZE': 30, + 'SQLALCHEMY_ECHO': 'False', + 'SENTRY_TRACES_SAMPLE_RATE': 1.0, + 'SENTRY_PROFILES_SAMPLE_RATE': 1.0, + 'WEAVIATE_GRPC_ENABLED': 'True', + 'CELERY_BACKEND': 'database', + 'PDF_PREVIEW': 'True', + 'LOG_LEVEL': 'INFO', +} + + +def get_env(key): + return os.environ.get(key, DEFAULTS.get(key)) + + +def get_bool_env(key): + return get_env(key).lower() == 'true' + + +def 
get_cors_allow_origins(env, default): + cors_allow_origins = [] + if get_env(env): + for origin in get_env(env).split(','): + cors_allow_origins.append(origin) + else: + cors_allow_origins = [default] + + return cors_allow_origins + + +class Config: + """Application configuration class.""" + + def __init__(self): + # app settings + self.CONSOLE_URL = get_env('CONSOLE_URL') + self.API_URL = get_env('API_URL') + self.APP_URL = get_env('APP_URL') + self.CURRENT_VERSION = "0.2.0" + self.COMMIT_SHA = get_env('COMMIT_SHA') + self.EDITION = "SELF_HOSTED" + self.DEPLOY_ENV = get_env('DEPLOY_ENV') + self.TESTING = False + self.LOG_LEVEL = get_env('LOG_LEVEL') + self.PDF_PREVIEW = get_bool_env('PDF_PREVIEW') + + # Your App secret key will be used for securely signing the session cookie + # Make sure you are changing this key for your deployment with a strong key. + # You can generate a strong key using `openssl rand -base64 42`. + # Alternatively you can set it with `SECRET_KEY` environment variable. + self.SECRET_KEY = get_env('SECRET_KEY') + + # cookie settings + self.REMEMBER_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY') + self.SESSION_COOKIE_HTTPONLY = get_bool_env('COOKIE_HTTPONLY') + self.REMEMBER_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE') + self.SESSION_COOKIE_SAMESITE = get_env('COOKIE_SAMESITE') + self.REMEMBER_COOKIE_SECURE = get_bool_env('COOKIE_SECURE') + self.SESSION_COOKIE_SECURE = get_bool_env('COOKIE_SECURE') + self.PERMANENT_SESSION_LIFETIME = timedelta(days=7) + + # session settings, only support sqlalchemy, redis + self.SESSION_TYPE = get_env('SESSION_TYPE') + self.SESSION_PERMANENT = get_bool_env('SESSION_PERMANENT') + self.SESSION_USE_SIGNER = get_bool_env('SESSION_USE_SIGNER') + + # redis settings + self.REDIS_HOST = get_env('REDIS_HOST') + self.REDIS_PORT = get_env('REDIS_PORT') + self.REDIS_PASSWORD = get_env('REDIS_PASSWORD') + self.REDIS_DB = get_env('REDIS_DB') + + # session redis settings + self.SESSION_REDIS_HOST = get_env('SESSION_REDIS_HOST') + self.SESSION_REDIS_PORT = get_env('SESSION_REDIS_PORT') + self.SESSION_REDIS_PASSWORD = get_env('SESSION_REDIS_PASSWORD') + self.SESSION_REDIS_DB = get_env('SESSION_REDIS_DB') + + # storage settings + self.STORAGE_TYPE = get_env('STORAGE_TYPE') + self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH') + self.S3_ENDPOINT = get_env('S3_ENDPOINT') + self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME') + self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY') + self.S3_SECRET_KEY = get_env('S3_SECRET_KEY') + self.S3_REGION = get_env('S3_REGION') + + # vector store settings, only support weaviate, qdrant + self.VECTOR_STORE = get_env('VECTOR_STORE') + + # weaviate settings + self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT') + self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY') + self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED') + + # qdrant settings + self.QDRANT_URL = get_env('QDRANT_URL') + self.QDRANT_API_KEY = get_env('QDRANT_API_KEY') + + # cors settings + self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins( + 'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_URL) + self.WEB_API_CORS_ALLOW_ORIGINS = get_cors_allow_origins( + 'WEB_API_CORS_ALLOW_ORIGINS', '*') + + # sentry settings + self.SENTRY_DSN = get_env('SENTRY_DSN') + self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE')) + self.SENTRY_PROFILES_SAMPLE_RATE = float(get_env('SENTRY_PROFILES_SAMPLE_RATE')) + + # check update url + self.CHECK_UPDATE_URL = get_env('CHECK_UPDATE_URL') + + # database settings + db_credentials = { + key: 
get_env(key) for key in + ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_DATABASE'] + } + + self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}" + self.SQLALCHEMY_ENGINE_OPTIONS = {'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE'))} + + self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO') + + # celery settings + self.CELERY_BROKER_URL = get_env('CELERY_BROKER_URL') + self.CELERY_BACKEND = get_env('CELERY_BACKEND') + self.CELERY_RESULT_BACKEND = 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \ + if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL + + # hosted provider credentials + self.OPENAI_API_KEY = get_env('OPENAI_API_KEY') + + +class CloudEditionConfig(Config): + + def __init__(self): + super().__init__() + + self.EDITION = "CLOUD" + + self.GITHUB_CLIENT_ID = get_env('GITHUB_CLIENT_ID') + self.GITHUB_CLIENT_SECRET = get_env('GITHUB_CLIENT_SECRET') + self.GOOGLE_CLIENT_ID = get_env('GOOGLE_CLIENT_ID') + self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET') + self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH') + + +class TestConfig(Config): + + def __init__(self): + super().__init__() + + self.EDITION = "SELF_HOSTED" + self.TESTING = True + + db_credentials = { + key: get_env(key) for key in ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT'] + } + + # use a different database for testing: dify_test + self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/dify_test" diff --git a/api/constants/__init__.py b/api/constants/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/constants/model_template.py b/api/constants/model_template.py new file mode 100644 index 0000000000..f8d7e0b74a --- /dev/null +++ b/api/constants/model_template.py @@ -0,0 +1,322 @@ +import json + +from models.model import AppModelConfig, App + +model_templates = { + # completion default mode + 'completion_default': { + 'app': { + 'mode': 'completion', + 'enable_site': True, + 'enable_api': True, + 'is_demo': False, + 'api_rpm': 0, + 'api_rph': 0, + 'status': 'normal' + }, + 'model_config': { + 'provider': 'openai', + 'model_id': 'text-davinci-003', + 'configs': { + 'prompt_template': '', + 'prompt_variables': [], + 'completion_params': { + 'max_token': 512, + 'temperature': 1, + 'top_p': 1, + 'presence_penalty': 0, + 'frequency_penalty': 0, + } + }, + 'model': json.dumps({ + "provider": "openai", + "name": "text-davinci-003", + "completion_params": { + "max_tokens": 512, + "temperature": 1, + "top_p": 1, + "presence_penalty": 0, + "frequency_penalty": 0 + } + }) + } + }, + + # chat default mode + 'chat_default': { + 'app': { + 'mode': 'chat', + 'enable_site': True, + 'enable_api': True, + 'is_demo': False, + 'api_rpm': 0, + 'api_rph': 0, + 'status': 'normal' + }, + 'model_config': { + 'provider': 'openai', + 'model_id': 'gpt-3.5-turbo', + 'configs': { + 'prompt_template': '', + 'prompt_variables': [], + 'completion_params': { + 'max_token': 512, + 'temperature': 1, + 'top_p': 1, + 'presence_penalty': 0, + 'frequency_penalty': 0, + } + }, + 'model': json.dumps({ + "provider": "openai", + "name": "gpt-3.5-turbo", + "completion_params": { + "max_tokens": 512, + "temperature": 1, + "top_p": 1, + "presence_penalty": 0, + "frequency_penalty": 0 + } + }) + } + }, +} + + +demo_model_templates = { + 'en-US': [ + { 
+ 'name': 'Translation Assistant', + 'icon': '', + 'icon_background': '', + 'description': 'A multilingual translator that provides translation capabilities in multiple languages, translating user input into the language they need.', + 'mode': 'completion', + 'model_config': AppModelConfig( + provider='openai', + model_id='text-davinci-003', + configs={ + 'prompt_template': "Please translate the following text into {{target_language}}:\n", + 'prompt_variables': [ + { + "key": "target_language", + "name": "Target Language", + "description": "The language you want to translate into.", + "type": "select", + "default": "Chinese", + 'options': [ + 'Chinese', + 'English', + 'Japanese', + 'French', + 'Russian', + 'German', + 'Spanish', + 'Korean', + 'Italian', + ] + } + ], + 'completion_params': { + 'max_token': 1000, + 'temperature': 0, + 'top_p': 0, + 'presence_penalty': 0.1, + 'frequency_penalty': 0.1, + } + }, + opening_statement='', + suggested_questions=None, + pre_prompt="Please translate the following text into {{target_language}}:\n", + model=json.dumps({ + "provider": "openai", + "name": "text-davinci-003", + "completion_params": { + "max_tokens": 1000, + "temperature": 0, + "top_p": 0, + "presence_penalty": 0.1, + "frequency_penalty": 0.1 + } + }), + user_input_form=json.dumps([ + { + "select": { + "label": "Target Language", + "variable": "target_language", + "description": "The language you want to translate into.", + "default": "Chinese", + "required": True, + 'options': [ + 'Chinese', + 'English', + 'Japanese', + 'French', + 'Russian', + 'German', + 'Spanish', + 'Korean', + 'Italian', + ] + } + } + ]) + ) + }, + { + 'name': 'AI Front-end Interviewer', + 'icon': '', + 'icon_background': '', + 'description': 'A simulated front-end interviewer that tests the skill level of front-end development through questioning.', + 'mode': 'chat', + 'model_config': AppModelConfig( + provider='openai', + model_id='gpt-3.5-turbo', + configs={ + 'introduction': 'Hi, welcome to our interview. I am the interviewer for this technology company, and I will test your web front-end development skills. Next, I will ask you some technical questions. Please answer them as thoroughly as possible. ', + 'prompt_template': "You will play the role of an interviewer for a technology company, examining the user's web front-end development skills and posing 5-10 sharp technical questions.\n\nPlease note:\n- Only ask one question at a time.\n- After the user answers a question, ask the next question directly, without trying to correct any mistakes made by the candidate.\n- If you think the user has not answered correctly for several consecutive questions, ask fewer questions.\n- After asking the last question, you can ask this question: Why did you leave your last job? After the user answers this question, please express your understanding and support.\n", + 'prompt_variables': [], + 'completion_params': { + 'max_token': 300, + 'temperature': 0.8, + 'top_p': 0.9, + 'presence_penalty': 0.1, + 'frequency_penalty': 0.1, + } + }, + opening_statement='Hi, welcome to our interview. I am the interviewer for this technology company, and I will test your web front-end development skills. Next, I will ask you some technical questions. Please answer them as thoroughly as possible. 
', + suggested_questions=None, + pre_prompt="You will play the role of an interviewer for a technology company, examining the user's web front-end development skills and posing 5-10 sharp technical questions.\n\nPlease note:\n- Only ask one question at a time.\n- After the user answers a question, ask the next question directly, without trying to correct any mistakes made by the candidate.\n- If you think the user has not answered correctly for several consecutive questions, ask fewer questions.\n- After asking the last question, you can ask this question: Why did you leave your last job? After the user answers this question, please express your understanding and support.\n", + model=json.dumps({ + "provider": "openai", + "name": "gpt-3.5-turbo", + "completion_params": { + "max_tokens": 300, + "temperature": 0.8, + "top_p": 0.9, + "presence_penalty": 0.1, + "frequency_penalty": 0.1 + } + }), + user_input_form=None + ) + } + ], + + 'zh-Hans': [ + { + 'name': '翻译助手', + 'icon': '', + 'icon_background': '', + 'description': '一个多语言翻译器,提供多种语言翻译能力,将用户输入的文本翻译成他们需要的语言。', + 'mode': 'completion', + 'model_config': AppModelConfig( + provider='openai', + model_id='text-davinci-003', + configs={ + 'prompt_template': "请将以下文本翻译为{{target_language}}:\n", + 'prompt_variables': [ + { + "key": "target_language", + "name": "目标语言", + "description": "翻译的目标语言", + "type": "select", + "default": "中文", + "options": [ + "中文", + "英文", + "日语", + "法语", + "俄语", + "德语", + "西班牙语", + "韩语", + "意大利语", + ] + } + ], + 'completion_params': { + 'max_token': 1000, + 'temperature': 0, + 'top_p': 0, + 'presence_penalty': 0.1, + 'frequency_penalty': 0.1, + } + }, + opening_statement='', + suggested_questions=None, + pre_prompt="请将以下文本翻译为{{target_language}}:\n", + model=json.dumps({ + "provider": "openai", + "name": "text-davinci-003", + "completion_params": { + "max_tokens": 1000, + "temperature": 0, + "top_p": 0, + "presence_penalty": 0.1, + "frequency_penalty": 0.1 + } + }), + user_input_form=json.dumps([ + { + "select": { + "label": "目标语言", + "variable": "target_language", + "description": "翻译的目标语言", + "default": "中文", + "required": True, + 'options': [ + "中文", + "英文", + "日语", + "法语", + "俄语", + "德语", + "西班牙语", + "韩语", + "意大利语", + ] + } + } + ]) + ) + }, + { + 'name': 'AI 前端面试官', + 'icon': '', + 'icon_background': '', + 'description': '一个模拟的前端面试官,通过提问的方式对前端开发的技能水平进行检验。', + 'mode': 'chat', + 'model_config': AppModelConfig( + provider='openai', + model_id='gpt-3.5-turbo', + configs={ + 'introduction': '你好,欢迎来参加我们的面试,我是这家科技公司的面试官,我将考察你的 Web 前端开发技能。接下来我会向您提出一些技术问题,请您尽可能详尽地回答。', + 'prompt_template': "你将扮演一个科技公司的面试官,考察用户作为候选人的 Web 前端开发水平,提出 5-10 个犀利的技术问题。\n\n请注意:\n- 每次只问一个问题\n- 用户回答问题后请直接问下一个问题,而不要试图纠正候选人的错误;\n- 如果你认为用户连续几次回答的都不对,就少问一点;\n- 问完最后一个问题后,你可以问这样一个问题:上一份工作为什么离职?用户回答该问题后,请表示理解与支持。\n", + 'prompt_variables': [], + 'completion_params': { + 'max_token': 300, + 'temperature': 0.8, + 'top_p': 0.9, + 'presence_penalty': 0.1, + 'frequency_penalty': 0.1, + } + }, + opening_statement='你好,欢迎来参加我们的面试,我是这家科技公司的面试官,我将考察你的 Web 前端开发技能。接下来我会向您提出一些技术问题,请您尽可能详尽地回答。', + suggested_questions=None, + pre_prompt="你将扮演一个科技公司的面试官,考察用户作为候选人的 Web 前端开发水平,提出 5-10 个犀利的技术问题。\n\n请注意:\n- 每次只问一个问题\n- 用户回答问题后请直接问下一个问题,而不要试图纠正候选人的错误;\n- 如果你认为用户连续几次回答的都不对,就少问一点;\n- 问完最后一个问题后,你可以问这样一个问题:上一份工作为什么离职?用户回答该问题后,请表示理解与支持。\n", + model=json.dumps({ + "provider": "openai", + "name": "gpt-3.5-turbo", + "completion_params": { + "max_tokens": 300, + "temperature": 0.8, + "top_p": 0.9, + "presence_penalty": 0.1, + "frequency_penalty": 0.1 + } + }), + user_input_form=None + ) + } + 
], +} diff --git a/api/controllers/__init__.py b/api/controllers/__init__.py new file mode 100644 index 0000000000..2c0485b18d --- /dev/null +++ b/api/controllers/__init__.py @@ -0,0 +1,4 @@ +# -*- coding:utf-8 -*- + + + diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py new file mode 100644 index 0000000000..971e489971 --- /dev/null +++ b/api/controllers/console/__init__.py @@ -0,0 +1,20 @@ +from flask import Blueprint + +from libs.external_api import ExternalApi + +bp = Blueprint('console', __name__, url_prefix='/console/api') +api = ExternalApi(bp) + +# Import app controllers +from .app import app, site, explore, completion, model_config, statistic, conversation, message + +# Import auth controllers +from .auth import login, oauth + +# Import datasets controllers +from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing + +# Import other controllers +from . import setup, version, apikey + +from .workspace import workspace, members, providers, account diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py new file mode 100644 index 0000000000..e576a4d848 --- /dev/null +++ b/api/controllers/console/apikey.py @@ -0,0 +1,175 @@ +from flask_login import login_required, current_user +import flask_restful +from flask_restful import Resource, fields, marshal_with +from werkzeug.exceptions import Forbidden + +from extensions.ext_database import db +from models.model import App, ApiToken +from models.dataset import Dataset + +from . import api +from .setup import setup_required +from .wraps import account_initialization_required +from libs.helper import TimestampField + +api_key_fields = { + 'id': fields.String, + 'type': fields.String, + 'token': fields.String, + 'last_used_at': TimestampField, + 'created_at': TimestampField +} + +api_key_list = { + 'data': fields.List(fields.Nested(api_key_fields), attribute="items") +} + + +def _get_resource(resource_id, tenant_id, resource_model): + resource = resource_model.query.filter_by( + id=resource_id, tenant_id=tenant_id + ).first() + + if resource is None: + flask_restful.abort( + 404, message=f"{resource_model.__name__} not found.") + + return resource + + +class BaseApiKeyListResource(Resource): + method_decorators = [account_initialization_required, login_required, setup_required] + + resource_type = None + resource_model = None + resource_id_field = None + token_prefix = None + max_keys = 10 + + @marshal_with(api_key_list) + def get(self, resource_id): + resource_id = str(resource_id) + _get_resource(resource_id, current_user.current_tenant_id, + self.resource_model) + keys = db.session.query(ApiToken). \ + filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \ + all() + return {"items": keys} + + @marshal_with(api_key_fields) + def post(self, resource_id): + resource_id = str(resource_id) + _get_resource(resource_id, current_user.current_tenant_id, + self.resource_model) + + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + current_key_count = db.session.query(ApiToken). \ + filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). 
\ + count() + + if current_key_count >= self.max_keys: + flask_restful.abort( + 400, + message=f"Cannot create more than {self.max_keys} API keys for this resource type.", + code='max_keys_exceeded' + ) + + key = ApiToken.generate_api_key(self.token_prefix, 24) + api_token = ApiToken() + setattr(api_token, self.resource_id_field, resource_id) + api_token.token = key + api_token.type = self.resource_type + db.session.add(api_token) + db.session.commit() + return api_token, 201 + + +class BaseApiKeyResource(Resource): + method_decorators = [account_initialization_required, login_required, setup_required] + + resource_type = None + resource_model = None + resource_id_field = None + + def delete(self, resource_id, api_key_id): + resource_id = str(resource_id) + api_key_id = str(api_key_id) + _get_resource(resource_id, current_user.current_tenant_id, + self.resource_model) + + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + key = db.session.query(ApiToken). \ + filter(getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id). \ + first() + + if key is None: + flask_restful.abort(404, message='API key not found') + + db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() + db.session.commit() + + return {'result': 'success'}, 204 + + +class AppApiKeyListResource(BaseApiKeyListResource): + + def after_request(self, resp): + resp.headers['Access-Control-Allow-Origin'] = '*' + resp.headers['Access-Control-Allow-Credentials'] = 'true' + return resp + + resource_type = 'app' + resource_model = App + resource_id_field = 'app_id' + token_prefix = 'app-' + + +class AppApiKeyResource(BaseApiKeyResource): + + def after_request(self, resp): + resp.headers['Access-Control-Allow-Origin'] = '*' + resp.headers['Access-Control-Allow-Credentials'] = 'true' + return resp + + resource_type = 'app' + resource_model = App + resource_id_field = 'app_id' + + +class DatasetApiKeyListResource(BaseApiKeyListResource): + + def after_request(self, resp): + resp.headers['Access-Control-Allow-Origin'] = '*' + resp.headers['Access-Control-Allow-Credentials'] = 'true' + return resp + + resource_type = 'dataset' + resource_model = Dataset + resource_id_field = 'dataset_id' + token_prefix = 'ds-' + + +class DatasetApiKeyResource(BaseApiKeyResource): + + def after_request(self, resp): + resp.headers['Access-Control-Allow-Origin'] = '*' + resp.headers['Access-Control-Allow-Credentials'] = 'true' + return resp + resource_type = 'dataset' + resource_model = Dataset + resource_id_field = 'dataset_id' + + +api.add_resource(AppApiKeyListResource, '/apps//api-keys') +api.add_resource(AppApiKeyResource, + '/apps//api-keys/') +api.add_resource(DatasetApiKeyListResource, + '/datasets//api-keys') +api.add_resource(DatasetApiKeyResource, + '/datasets//api-keys/') diff --git a/api/controllers/console/app/__init__.py b/api/controllers/console/app/__init__.py new file mode 100644 index 0000000000..1f22ab30c6 --- /dev/null +++ b/api/controllers/console/app/__init__.py @@ -0,0 +1,22 @@ +from flask_login import current_user +from werkzeug.exceptions import NotFound + +from controllers.console.app.error import AppUnavailableError +from extensions.ext_database import db +from models.model import App + + +def _get_app(app_id, mode=None): + app = db.session.query(App).filter( + App.id == app_id, + App.tenant_id == current_user.current_tenant_id, + 
App.status == 'normal' + ).first() + + if not app: + raise NotFound("App not found") + + if mode and app.mode != mode: + raise AppUnavailableError() + + return app diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py new file mode 100644 index 0000000000..fbb28fb4ae --- /dev/null +++ b/api/controllers/console/app/app.py @@ -0,0 +1,518 @@ +# -*- coding:utf-8 -*- +import json +from datetime import datetime + +import flask +from flask_login import login_required, current_user +from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs +from werkzeug.exceptions import Unauthorized, Forbidden + +from constants.model_template import model_templates, demo_model_templates +from controllers.console import api +from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError, ProviderQuotaExceededError, \ + CompletionRequestError, ProviderModelCurrentlyNotSupportError +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from core.generator.llm_generator import LLMGenerator +from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \ + LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError +from events.app_event import app_was_created, app_was_deleted +from libs.helper import TimestampField +from extensions.ext_database import db +from models.model import App, AppModelConfig, Site, InstalledApp +from services.account_service import TenantService +from services.app_model_config_service import AppModelConfigService + +model_config_fields = { + 'opening_statement': fields.String, + 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), + 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), + 'more_like_this': fields.Raw(attribute='more_like_this_dict'), + 'model': fields.Raw(attribute='model_dict'), + 'user_input_form': fields.Raw(attribute='user_input_form_list'), + 'pre_prompt': fields.String, + 'agent_mode': fields.Raw(attribute='agent_mode_dict'), +} + +app_detail_fields = { + 'id': fields.String, + 'name': fields.String, + 'mode': fields.String, + 'icon': fields.String, + 'icon_background': fields.String, + 'enable_site': fields.Boolean, + 'enable_api': fields.Boolean, + 'api_rpm': fields.Integer, + 'api_rph': fields.Integer, + 'is_demo': fields.Boolean, + 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), + 'created_at': TimestampField +} + + +def _get_app(app_id, tenant_id): + app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first() + if not app: + raise AppNotFoundError + return app + + +class AppListApi(Resource): + prompt_config_fields = { + 'prompt_template': fields.String, + } + + model_config_partial_fields = { + 'model': fields.Raw(attribute='model_dict'), + 'pre_prompt': fields.String, + } + + app_partial_fields = { + 'id': fields.String, + 'name': fields.String, + 'mode': fields.String, + 'icon': fields.String, + 'icon_background': fields.String, + 'enable_site': fields.Boolean, + 'enable_api': fields.Boolean, + 'is_demo': fields.Boolean, + 'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'), + 'created_at': TimestampField + } + + app_pagination_fields = { + 'page': fields.Integer, + 'limit': fields.Integer(attribute='per_page'), + 'total': fields.Integer, + 'has_more': 
fields.Boolean(attribute='has_next'), + 'data': fields.List(fields.Nested(app_partial_fields), attribute='items') + } + + @setup_required + @login_required + @account_initialization_required + @marshal_with(app_pagination_fields) + def get(self): + """Get app list""" + parser = reqparse.RequestParser() + parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args') + parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args') + args = parser.parse_args() + + app_models = db.paginate( + db.select(App).where(App.tenant_id == current_user.current_tenant_id).order_by(App.created_at.desc()), + page=args['page'], + per_page=args['limit'], + error_out=False) + + return app_models + + @setup_required + @login_required + @account_initialization_required + @marshal_with(app_detail_fields) + def post(self): + """Create app""" + parser = reqparse.RequestParser() + parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument('mode', type=str, choices=['completion', 'chat'], location='json') + parser.add_argument('icon', type=str, location='json') + parser.add_argument('icon_background', type=str, location='json') + parser.add_argument('model_config', type=dict, location='json') + args = parser.parse_args() + + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + if args['model_config'] is not None: + # validate config + model_configuration = AppModelConfigService.validate_configuration( + account=current_user, + config=args['model_config'], + mode=args['mode'] + ) + + app = App( + enable_site=True, + enable_api=True, + is_demo=False, + api_rpm=0, + api_rph=0, + status='normal' + ) + + app_model_config = AppModelConfig( + provider="", + model_id="", + configs={}, + opening_statement=model_configuration['opening_statement'], + suggested_questions=json.dumps(model_configuration['suggested_questions']), + suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']), + more_like_this=json.dumps(model_configuration['more_like_this']), + model=json.dumps(model_configuration['model']), + user_input_form=json.dumps(model_configuration['user_input_form']), + pre_prompt=model_configuration['pre_prompt'], + agent_mode=json.dumps(model_configuration['agent_mode']), + ) + else: + if 'mode' not in args or args['mode'] is None: + abort(400, message="mode is required") + + model_config_template = model_templates[args['mode'] + '_default'] + + app = App(**model_config_template['app']) + app_model_config = AppModelConfig(**model_config_template['model_config']) + + app.name = args['name'] + app.mode = args['mode'] + app.icon = args['icon'] + app.icon_background = args['icon_background'] + app.tenant_id = current_user.current_tenant_id + + db.session.add(app) + db.session.flush() + + app_model_config.app_id = app.id + db.session.add(app_model_config) + db.session.flush() + + app.app_model_config_id = app_model_config.id + + account = current_user + + site = Site( + app_id=app.id, + title=app.name, + default_language=account.interface_language, + customize_token_strategy='not_allow', + code=Site.generate_code(16) + ) + + db.session.add(site) + db.session.commit() + + app_was_created.send(app) + + return app, 201 + + +class AppTemplateApi(Resource): + template_fields = { + 'name': fields.String, + 'icon': fields.String, + 'icon_background': fields.String, 
+ 'description': fields.String, + 'mode': fields.String, + 'model_config': fields.Nested(model_config_fields), + } + + template_list_fields = { + 'data': fields.List(fields.Nested(template_fields)), + } + + @setup_required + @login_required + @account_initialization_required + @marshal_with(template_list_fields) + def get(self): + """Get app demo templates""" + account = current_user + interface_language = account.interface_language + + return {'data': demo_model_templates.get(interface_language)} + + +class AppApi(Resource): + site_fields = { + 'access_token': fields.String(attribute='code'), + 'code': fields.String, + 'title': fields.String, + 'icon': fields.String, + 'icon_background': fields.String, + 'description': fields.String, + 'default_language': fields.String, + 'customize_domain': fields.String, + 'copyright': fields.String, + 'privacy_policy': fields.String, + 'customize_token_strategy': fields.String, + 'prompt_public': fields.Boolean, + 'app_base_url': fields.String, + } + + app_detail_fields_with_site = { + 'id': fields.String, + 'name': fields.String, + 'mode': fields.String, + 'icon': fields.String, + 'icon_background': fields.String, + 'enable_site': fields.Boolean, + 'enable_api': fields.Boolean, + 'api_rpm': fields.Integer, + 'api_rph': fields.Integer, + 'is_demo': fields.Boolean, + 'model_config': fields.Nested(model_config_fields, attribute='app_model_config'), + 'site': fields.Nested(site_fields), + 'api_base_url': fields.String, + 'created_at': TimestampField + } + + @setup_required + @login_required + @account_initialization_required + @marshal_with(app_detail_fields_with_site) + def get(self, app_id): + """Get app detail""" + app_id = str(app_id) + app = _get_app(app_id, current_user.current_tenant_id) + + return app + + @setup_required + @login_required + @account_initialization_required + def delete(self, app_id): + """Delete app""" + app_id = str(app_id) + app = _get_app(app_id, current_user.current_tenant_id) + + db.session.delete(app) + db.session.commit() + + # todo delete related data?? 
+ # model_config, site, api_token, conversation, message, message_feedback, message_annotation + + app_was_deleted.send(app) + + return {'result': 'success'}, 204 + + +class AppNameApi(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(app_detail_fields) + def post(self, app_id): + + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument('name', type=str, required=True, location='json') + args = parser.parse_args() + + app = db.get_or_404(App, str(app_id)) + if app.tenant_id != flask.session.get('tenant_id'): + raise Unauthorized() + + app.name = args.get('name') + app.updated_at = datetime.utcnow() + db.session.commit() + return app + + +class AppIconApi(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(app_detail_fields) + def post(self, app_id): + + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument('icon', type=str, location='json') + parser.add_argument('icon_background', type=str, location='json') + args = parser.parse_args() + + app = db.get_or_404(App, str(app_id)) + if app.tenant_id != flask.session.get('tenant_id'): + raise Unauthorized() + + app.icon = args.get('icon') + app.icon_background = args.get('icon_background') + app.updated_at = datetime.utcnow() + db.session.commit() + + return app + + +class AppSiteStatus(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(app_detail_fields) + def post(self, app_id): + parser = reqparse.RequestParser() + parser.add_argument('enable_site', type=bool, required=True, location='json') + args = parser.parse_args() + app_id = str(app_id) + app = db.session.query(App).filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id).first() + if not app: + raise AppNotFoundError + + if args.get('enable_site') == app.enable_site: + return app + + app.enable_site = args.get('enable_site') + app.updated_at = datetime.utcnow() + db.session.commit() + return app + + +class AppApiStatus(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(app_detail_fields) + def post(self, app_id): + parser = reqparse.RequestParser() + parser.add_argument('enable_api', type=bool, required=True, location='json') + args = parser.parse_args() + + app_id = str(app_id) + app = _get_app(app_id, current_user.current_tenant_id) + + if args.get('enable_api') == app.enable_api: + return app + + app.enable_api = args.get('enable_api') + app.updated_at = datetime.utcnow() + db.session.commit() + return app + + +class AppRateLimit(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(app_detail_fields) + def post(self, app_id): + parser = reqparse.RequestParser() + parser.add_argument('api_rpm', type=inputs.natural, required=False, location='json') + parser.add_argument('api_rph', type=inputs.natural, required=False, location='json') + args = parser.parse_args() + + app_id = str(app_id) + app = _get_app(app_id, current_user.current_tenant_id) + + if args.get('api_rpm'): + app.api_rpm = args.get('api_rpm') + if args.get('api_rph'): + app.api_rph = args.get('api_rph') + app.updated_at = datetime.utcnow() + 
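+        # persist the updated api_rpm / api_rph limits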
db.session.commit()
+        return app
+
+
+class AppCopy(Resource):
+    @staticmethod
+    def create_app_copy(app):
+        copy_app = App(
+            name=app.name + ' copy',
+            icon=app.icon,
+            icon_background=app.icon_background,
+            tenant_id=app.tenant_id,
+            mode=app.mode,
+            app_model_config_id=app.app_model_config_id,
+            enable_site=app.enable_site,
+            enable_api=app.enable_api,
+            api_rpm=app.api_rpm,
+            api_rph=app.api_rph
+        )
+        return copy_app
+
+    @staticmethod
+    def create_app_model_config_copy(app_config, copy_app_id):
+        copy_app_model_config = AppModelConfig(
+            app_id=copy_app_id,
+            provider=app_config.provider,
+            model_id=app_config.model_id,
+            configs=app_config.configs,
+            opening_statement=app_config.opening_statement,
+            suggested_questions=app_config.suggested_questions,
+            suggested_questions_after_answer=app_config.suggested_questions_after_answer,
+            more_like_this=app_config.more_like_this,
+            model=app_config.model,
+            user_input_form=app_config.user_input_form,
+            pre_prompt=app_config.pre_prompt,
+            agent_mode=app_config.agent_mode
+        )
+        return copy_app_model_config
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(app_detail_fields)
+    def post(self, app_id):
+        app_id = str(app_id)
+        app = _get_app(app_id, current_user.current_tenant_id)
+
+        copy_app = self.create_app_copy(app)
+        db.session.add(copy_app)
+
+        app_config = db.session.query(AppModelConfig). \
+            filter(AppModelConfig.app_id == app_id). \
+            one_or_none()
+
+        if app_config:
+            copy_app_model_config = self.create_app_model_config_copy(app_config, copy_app.id)
+            db.session.add(copy_app_model_config)
+            db.session.commit()
+            copy_app.app_model_config_id = copy_app_model_config.id
+        db.session.commit()
+
+        return copy_app, 201
+
+
+class AppExport(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, app_id):
+        # todo
+        pass
+
+
+class IntroductionGenerateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('prompt_template', type=str, required=True, location='json')
+        args = parser.parse_args()
+
+        account = current_user
+
+        try:
+            answer = LLMGenerator.generate_introduction(
+                account.current_tenant_id,
+                args['prompt_template']
+            )
+        except ProviderTokenNotInitError:
+            raise ProviderNotInitializeError()
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
+                LLMRateLimitError, LLMAuthorizationError) as e:
+            raise CompletionRequestError(str(e))
+
+        return {'introduction': answer}
+
+
+api.add_resource(AppListApi, '/apps')
+api.add_resource(AppTemplateApi, '/app-templates')
+api.add_resource(AppApi, '/apps/<uuid:app_id>')
+api.add_resource(AppCopy, '/apps/<uuid:app_id>/copy')
+api.add_resource(AppNameApi, '/apps/<uuid:app_id>/name')
+api.add_resource(AppSiteStatus, '/apps/<uuid:app_id>/site-enable')
+api.add_resource(AppApiStatus, '/apps/<uuid:app_id>/api-enable')
+api.add_resource(AppRateLimit, '/apps/<uuid:app_id>/rate-limit')
+api.add_resource(IntroductionGenerateApi, '/introduction-generate')
diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py
new file mode 100644
index 0000000000..552271a9ec
--- /dev/null
+++ b/api/controllers/console/app/completion.py
@@ -0,0 +1,206 @@
+# -*- coding:utf-8 -*-
+import json
+import logging
+from typing import Generator, Union
+
+import flask_login
+from flask import Response, 
stream_with_context +from flask_login import login_required +from werkzeug.exceptions import InternalServerError, NotFound + +import services +from controllers.console import api +from controllers.console.app import _get_app +from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, \ + ProviderNotInitializeError, CompletionRequestError, ProviderQuotaExceededError, \ + ProviderModelCurrentlyNotSupportError +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from core.conversation_message_task import PubHandler +from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ + LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from libs.helper import uuid_value +from flask_restful import Resource, reqparse + +from services.completion_service import CompletionService + + +# define completion message api for user +class CompletionMessageApi(Resource): + + @setup_required + @login_required + @account_initialization_required + def post(self, app_id): + app_id = str(app_id) + + # get app info + app_model = _get_app(app_id, 'completion') + + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, required=True, location='json') + parser.add_argument('query', type=str, location='json') + parser.add_argument('model_config', type=dict, required=True, location='json') + args = parser.parse_args() + + account = flask_login.current_user + + try: + response = CompletionService.completion( + app_model=app_model, + user=account, + args=args, + from_source='console', + streaming=True, + is_model_config_override=True + ) + + return compact_response(response) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except services.errors.app_model_config.AppModelConfigBrokenError: + logging.exception("App model config broken.") + raise AppUnavailableError() + except ProviderTokenNotInitError: + raise ProviderNotInitializeError() + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, + LLMRateLimitError, LLMAuthorizationError) as e: + raise CompletionRequestError(str(e)) + except ValueError as e: + raise e + except Exception as e: + logging.exception("internal server error.") + raise InternalServerError() + + +class CompletionMessageStopApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, app_id, task_id): + app_id = str(app_id) + + # get app info + _get_app(app_id, 'completion') + + account = flask_login.current_user + + PubHandler.stop(account, task_id) + + return {'result': 'success'}, 200 + + +class ChatMessageApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, app_id): + app_id = str(app_id) + + # get app info + app_model = _get_app(app_id, 'chat') + + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, required=True, location='json') + parser.add_argument('query', type=str, required=True, location='json') + parser.add_argument('model_config', type=dict, required=True, location='json') + parser.add_argument('conversation_id', 
type=uuid_value, location='json')
+        args = parser.parse_args()
+
+        account = flask_login.current_user
+
+        try:
+            response = CompletionService.completion(
+                app_model=app_model,
+                user=account,
+                args=args,
+                from_source='console',
+                streaming=True,
+                is_model_config_override=True
+            )
+
+            return compact_response(response)
+        except services.errors.conversation.ConversationNotExistsError:
+            raise NotFound("Conversation Not Exists.")
+        except services.errors.conversation.ConversationCompletedError:
+            raise ConversationCompletedError()
+        except services.errors.app_model_config.AppModelConfigBrokenError:
+            logging.exception("App model config broken.")
+            raise AppUnavailableError()
+        except ProviderTokenNotInitError:
+            raise ProviderNotInitializeError()
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
+                LLMRateLimitError, LLMAuthorizationError) as e:
+            raise CompletionRequestError(str(e))
+        except ValueError as e:
+            raise e
+        except Exception as e:
+            logging.exception("internal server error.")
+            raise InternalServerError()
+
+
+def compact_response(response: Union[dict, Generator]) -> Response:
+    if isinstance(response, dict):
+        return Response(response=json.dumps(response), status=200, mimetype='application/json')
+    else:
+        def generate() -> Generator:
+            try:
+                for chunk in response:
+                    yield chunk
+            except services.errors.conversation.ConversationNotExistsError:
+                yield "data: " + json.dumps(api.handle_error(NotFound("Conversation Not Exists.")).get_json()) + "\n\n"
+            except services.errors.conversation.ConversationCompletedError:
+                yield "data: " + json.dumps(api.handle_error(ConversationCompletedError()).get_json()) + "\n\n"
+            except services.errors.app_model_config.AppModelConfigBrokenError:
+                logging.exception("App model config broken.")
+                yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except QuotaExceededError:
+                yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
+            except ModelCurrentlyNotSupportError:
+                yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
+            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
+                    LLMRateLimitError, LLMAuthorizationError) as e:
+                yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
+            except ValueError as e:
+                yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
+            except Exception:
+                logging.exception("internal server error.")
+                yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"
+
+        return Response(stream_with_context(generate()), status=200,
+                        mimetype='text/event-stream')
+
+
+class ChatMessageStopApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, app_id, task_id):
+        app_id = str(app_id)
+
+        # get app info
+        _get_app(app_id, 'chat')
+
+        account = flask_login.current_user
+
+        PubHandler.stop(account, task_id)
+
+        return {'result': 'success'}, 200
+
+
+api.add_resource(CompletionMessageApi, '/apps/<uuid:app_id>/completion-messages')
+api.add_resource(CompletionMessageStopApi, '/apps/<uuid:app_id>/completion-messages/<string:task_id>/stop')
+api.add_resource(ChatMessageApi, '/apps/<uuid:app_id>/chat-messages')
+api.add_resource(ChatMessageStopApi, '/apps/<uuid:app_id>/chat-messages/<string:task_id>/stop')
diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
new file mode 100644
index 0000000000..62752deaac
--- /dev/null
+++ b/api/controllers/console/app/conversation.py
@@ -0,0 +1,384 @@
+from datetime import datetime
+
+import pytz
+from flask_login import login_required, current_user
+from flask_restful import Resource, reqparse, fields, marshal_with
+from flask_restful.inputs import int_range
+from sqlalchemy import or_, func
+from sqlalchemy.orm import joinedload
+from werkzeug.exceptions import NotFound
+
+from controllers.console import api
+from controllers.console.app import _get_app
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from libs.helper import TimestampField, datetime_string, uuid_value
+from extensions.ext_database import db
+from models.model import Message, MessageAnnotation, Conversation
+
+account_fields = {
+    'id': fields.String,
+    'name': fields.String,
+    'email': fields.String
+}
+
+feedback_fields = {
+    'rating': fields.String,
+    'content': fields.String,
+    'from_source': fields.String,
+    'from_end_user_id': fields.String,
+    'from_account': fields.Nested(account_fields, allow_null=True),
+}
+
+annotation_fields = {
+    'content': fields.String,
+    'account': fields.Nested(account_fields, allow_null=True),
+    'created_at': TimestampField
+}
+
+message_detail_fields = {
+    'id': fields.String,
+    'conversation_id': fields.String,
+    'inputs': fields.Raw,
+    'query': fields.String,
+    'message': fields.Raw,
+    'message_tokens': fields.Integer,
+    'answer': fields.String,
+    'answer_tokens': fields.Integer,
+    'provider_response_latency': fields.Integer,
+    'from_source': fields.String,
+    'from_end_user_id': fields.String,
+    'from_account_id': fields.String,
+    'feedbacks': fields.List(fields.Nested(feedback_fields)),
+    'annotation': fields.Nested(annotation_fields, allow_null=True),
+    'created_at': TimestampField
+}
+
+feedback_stat_fields = {
+    'like': fields.Integer,
+    'dislike': fields.Integer
+}
+
+model_config_fields = {
+    'opening_statement': fields.String,
+    'suggested_questions': fields.Raw,
+    'model': fields.Raw,
+    'user_input_form': fields.Raw,
+    'pre_prompt': fields.String,
+    'agent_mode': fields.Raw,
+}
+
+
+class CompletionConversationApi(Resource):
+    class MessageTextField(fields.Raw):
+        def format(self, value):
+            return value[0]['text'] if value else ''
+
+    simple_configs_fields = {
+        'prompt_template': fields.String,
+    }
+
+    simple_model_config_fields = {
+        'model': fields.Raw(attribute='model_dict'),
+        'pre_prompt': fields.String,
+    }
+
+    simple_message_detail_fields = {
+        'inputs': fields.Raw,
+        'query': fields.String,
+        'message': MessageTextField,
+        'answer': fields.String,
+    }
+
+    conversation_fields = {
+        'id': fields.String,
+        'status': fields.String,
+        'from_source': fields.String,
+        'from_end_user_id': fields.String,
+        'from_account_id': fields.String,
+        'read_at': TimestampField,
+        'created_at': TimestampField,
+        'annotation': fields.Nested(annotation_fields, allow_null=True),
+        'model_config': fields.Nested(simple_model_config_fields),
+        'user_feedback_stats': fields.Nested(feedback_stat_fields),
+        'admin_feedback_stats': fields.Nested(feedback_stat_fields),
+        'message': fields.Nested(simple_message_detail_fields, attribute='first_message')
+    }
+
+    conversation_pagination_fields = {
+        'page': fields.Integer,
+        'limit':
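For reference, a minimal sketch (not part of this commit) of how a client might consume the streaming endpoint registered above. The host/port, the '/console/api' URL prefix, and cookie-based auth are assumptions here, and <app_id> / <session-cookie> are placeholders:

import json
import requests

with requests.post(
    'http://localhost:5001/console/api/apps/<app_id>/completion-messages',
    json={'inputs': {}, 'query': 'Hello', 'model_config': {}},
    cookies={'session': '<session-cookie>'},
    stream=True,
) as resp:
    # each SSE event arrives as a "data: {...}" line followed by a blank line
    for line in resp.iter_lines():
        if line.startswith(b'data: '):
            print(json.loads(line[len(b'data: '):]))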
fields.Integer(attribute='per_page'), + 'total': fields.Integer, + 'has_more': fields.Boolean(attribute='has_next'), + 'data': fields.List(fields.Nested(conversation_fields), attribute='items') + } + + @setup_required + @login_required + @account_initialization_required + @marshal_with(conversation_pagination_fields) + def get(self, app_id): + app_id = str(app_id) + + parser = reqparse.RequestParser() + parser.add_argument('keyword', type=str, location='args') + parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument('annotation_status', type=str, + choices=['annotated', 'not_annotated', 'all'], default='all', location='args') + parser.add_argument('page', type=int_range(1, 99999), default=1, location='args') + parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') + args = parser.parse_args() + + # get app info + app = _get_app(app_id, 'completion') + + query = db.select(Conversation).where(Conversation.app_id == app.id, Conversation.mode == 'completion') + + if args['keyword']: + query = query.join( + Message, Message.conversation_id == Conversation.id + ).filter( + or_( + Message.query.ilike('%{}%'.format(args['keyword'])), + Message.answer.ilike('%{}%'.format(args['keyword'])) + ) + ) + + account = current_user + timezone = pytz.timezone(account.timezone) + utc_timezone = pytz.utc + + if args['start']: + start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + start_datetime = start_datetime.replace(second=0) + + start_datetime_timezone = timezone.localize(start_datetime) + start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) + + query = query.where(Conversation.created_at >= start_datetime_utc) + + if args['end']: + end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + end_datetime = end_datetime.replace(second=0) + + end_datetime_timezone = timezone.localize(end_datetime) + end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) + + query = query.where(Conversation.created_at < end_datetime_utc) + + if args['annotation_status'] == "annotated": + query = query.options(joinedload(Conversation.message_annotations)).join( + MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + ) + elif args['annotation_status'] == "not_annotated": + query = query.outerjoin( + MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0) + + query = query.order_by(Conversation.created_at.desc()) + + conversations = db.paginate( + query, + page=args['page'], + per_page=args['limit'], + error_out=False + ) + + return conversations + + +class CompletionConversationDetailApi(Resource): + conversation_detail_fields = { + 'id': fields.String, + 'status': fields.String, + 'from_source': fields.String, + 'from_end_user_id': fields.String, + 'from_account_id': fields.String, + 'created_at': TimestampField, + 'model_config': fields.Nested(model_config_fields), + 'message': fields.Nested(message_detail_fields, attribute='first_message'), + } + + @setup_required + @login_required + @account_initialization_required + @marshal_with(conversation_detail_fields) + def get(self, app_id, conversation_id): + app_id = str(app_id) + conversation_id = str(conversation_id) + + return _get_conversation(app_id, conversation_id, 'completion') + + +class ChatConversationApi(Resource): + simple_configs_fields = { + 
'prompt_template': fields.String, + } + + simple_model_config_fields = { + 'model': fields.Raw(attribute='model_dict'), + 'pre_prompt': fields.String, + } + + conversation_fields = { + 'id': fields.String, + 'status': fields.String, + 'from_source': fields.String, + 'from_end_user_id': fields.String, + 'from_account_id': fields.String, + 'summary': fields.String(attribute='summary_or_query'), + 'read_at': TimestampField, + 'created_at': TimestampField, + 'annotated': fields.Boolean, + 'model_config': fields.Nested(simple_model_config_fields), + 'message_count': fields.Integer, + 'user_feedback_stats': fields.Nested(feedback_stat_fields), + 'admin_feedback_stats': fields.Nested(feedback_stat_fields) + } + + conversation_pagination_fields = { + 'page': fields.Integer, + 'limit': fields.Integer(attribute='per_page'), + 'total': fields.Integer, + 'has_more': fields.Boolean(attribute='has_next'), + 'data': fields.List(fields.Nested(conversation_fields), attribute='items') + } + + @setup_required + @login_required + @account_initialization_required + @marshal_with(conversation_pagination_fields) + def get(self, app_id): + app_id = str(app_id) + + parser = reqparse.RequestParser() + parser.add_argument('keyword', type=str, location='args') + parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument('annotation_status', type=str, + choices=['annotated', 'not_annotated', 'all'], default='all', location='args') + parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args') + parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args') + parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + args = parser.parse_args() + + # get app info + app = _get_app(app_id, 'chat') + + query = db.select(Conversation).where(Conversation.app_id == app.id, Conversation.mode == 'chat') + + if args['keyword']: + query = query.join( + Message, Message.conversation_id == Conversation.id + ).filter( + or_( + Message.query.ilike('%{}%'.format(args['keyword'])), + Message.answer.ilike('%{}%'.format(args['keyword'])), + Conversation.name.ilike('%{}%'.format(args['keyword'])), + Conversation.introduction.ilike('%{}%'.format(args['keyword'])), + ), + + ) + + account = current_user + timezone = pytz.timezone(account.timezone) + utc_timezone = pytz.utc + + if args['start']: + start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + start_datetime = start_datetime.replace(second=0) + + start_datetime_timezone = timezone.localize(start_datetime) + start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) + + query = query.where(Conversation.created_at >= start_datetime_utc) + + if args['end']: + end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + end_datetime = end_datetime.replace(second=0) + + end_datetime_timezone = timezone.localize(end_datetime) + end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) + + query = query.where(Conversation.created_at < end_datetime_utc) + + if args['annotation_status'] == "annotated": + query = query.options(joinedload(Conversation.message_annotations)).join( + MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + ) + elif args['annotation_status'] == "not_annotated": + query = query.outerjoin( + MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + 
).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0)
+
+        if args['message_count_gte'] and args['message_count_gte'] >= 1:
+            query = (
+                query.options(joinedload(Conversation.messages))
+                .join(Message, Message.conversation_id == Conversation.id)
+                .group_by(Conversation.id)
+                .having(func.count(Message.id) >= args['message_count_gte'])
+            )
+
+        query = query.order_by(Conversation.created_at.desc())
+
+        conversations = db.paginate(
+            query,
+            page=args['page'],
+            per_page=args['limit'],
+            error_out=False
+        )
+
+        return conversations
+
+
+class ChatConversationDetailApi(Resource):
+    conversation_detail_fields = {
+        'id': fields.String,
+        'status': fields.String,
+        'from_source': fields.String,
+        'from_end_user_id': fields.String,
+        'from_account_id': fields.String,
+        'created_at': TimestampField,
+        'annotated': fields.Boolean,
+        'model_config': fields.Nested(model_config_fields),
+        'message_count': fields.Integer,
+        'user_feedback_stats': fields.Nested(feedback_stat_fields),
+        'admin_feedback_stats': fields.Nested(feedback_stat_fields)
+    }
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(conversation_detail_fields)
+    def get(self, app_id, conversation_id):
+        app_id = str(app_id)
+        conversation_id = str(conversation_id)
+
+        return _get_conversation(app_id, conversation_id, 'chat')
+
+
+api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations')
+api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>')
+api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations')
+api.add_resource(ChatConversationDetailApi, '/apps/<uuid:app_id>/chat-conversations/<uuid:conversation_id>')
+
+
+def _get_conversation(app_id, conversation_id, mode):
+    # get app info
+    app = _get_app(app_id, mode)
+
+    conversation = db.session.query(Conversation) \
+        .filter(Conversation.id == conversation_id, Conversation.app_id == app.id).first()
+
+    if not conversation:
+        raise NotFound("Conversation Not Exists.")
+
+    if not conversation.read_at:
+        conversation.read_at = datetime.utcnow()
+        conversation.read_account_id = current_user.id
+        db.session.commit()
+
+    return conversation
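For reference, a minimal sketch (not part of this commit) of paging through the conversation endpoints registered above. The host/port, the '/console/api' prefix, and cookie auth are assumptions, and <app_id> / <session-cookie> are placeholders:

import requests

resp = requests.get(
    'http://localhost:5001/console/api/apps/<app_id>/chat-conversations',
    params={'page': 1, 'limit': 20, 'annotation_status': 'all'},
    cookies={'session': '<session-cookie>'},
)
page = resp.json()
# 'data' holds the marshalled conversations; 'has_more' drives pagination
print(page['total'], page['has_more'])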
diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py
new file mode 100644
index 0000000000..c19f054be4
--- /dev/null
+++ b/api/controllers/console/app/error.py
@@ -0,0 +1,49 @@
+from libs.exception import BaseHTTPException
+
+
+class AppNotFoundError(BaseHTTPException):
+    error_code = 'app_not_found'
+    description = "App not found."
+    code = 404
+
+
+class ProviderNotInitializeError(BaseHTTPException):
+    error_code = 'provider_not_initialize'
+    description = "Provider token not initialized."
+    code = 400
+
+
+class ProviderQuotaExceededError(BaseHTTPException):
+    error_code = 'provider_quota_exceeded'
+    description = "Provider quota exceeded."
+    code = 400
+
+
+class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
+    error_code = 'model_currently_not_support'
+    description = "GPT-4 is currently not supported."
+    code = 400
+
+
+class ConversationCompletedError(BaseHTTPException):
+    error_code = 'conversation_completed'
+    description = "Conversation was completed."
+    code = 400
+
+
+class AppUnavailableError(BaseHTTPException):
+    error_code = 'app_unavailable'
+    description = "App unavailable."
+    code = 400
+
+
+class CompletionRequestError(BaseHTTPException):
+    error_code = 'completion_request_error'
+    description = "Completion request failed."
+    code = 400
+
+
+class AppMoreLikeThisDisabledError(BaseHTTPException):
+    error_code = 'app_more_like_this_disabled'
+    description = "More like this disabled."
+    code = 403
diff --git a/api/controllers/console/app/explore.py b/api/controllers/console/app/explore.py
new file mode 100644
index 0000000000..eeec2ddc24
--- /dev/null
+++ b/api/controllers/console/app/explore.py
@@ -0,0 +1,209 @@
+# -*- coding:utf-8 -*-
+from datetime import datetime
+
+from flask_login import login_required, current_user
+from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
+from sqlalchemy import and_
+
+from controllers.console import api
+from extensions.ext_database import db
+from models.model import Tenant, App, InstalledApp, RecommendedApp
+from services.account_service import TenantService
+
+app_fields = {
+    'id': fields.String,
+    'name': fields.String,
+    'mode': fields.String,
+    'icon': fields.String,
+    'icon_background': fields.String
+}
+
+installed_app_fields = {
+    'id': fields.String,
+    'app': fields.Nested(app_fields, attribute='app'),
+    'app_owner_tenant_id': fields.String,
+    'is_pinned': fields.Boolean,
+    'last_used_at': fields.DateTime,
+    'editable': fields.Boolean
+}
+
+installed_app_list_fields = {
+    'installed_apps': fields.List(fields.Nested(installed_app_fields))
+}
+
+recommended_app_fields = {
+    'app': fields.Nested(app_fields, attribute='app'),
+    'app_id': fields.String,
+    'description': fields.String(attribute='description'),
+    'copyright': fields.String,
+    'privacy_policy': fields.String,
+    'category': fields.String,
+    'position': fields.Integer,
+    'is_listed': fields.Boolean,
+    'install_count': fields.Integer,
+    'installed': fields.Boolean,
+    'editable': fields.Boolean
+}
+
+recommended_app_list_fields = {
+    'recommended_apps': fields.List(fields.Nested(recommended_app_fields)),
+    'categories': fields.List(fields.String)
+}
+
+
+class InstalledAppsListResource(Resource):
+    @login_required
+    @marshal_with(installed_app_list_fields)
+    def get(self):
+        current_tenant_id = Tenant.query.first().id
+        installed_apps = db.session.query(InstalledApp).filter(
+            InstalledApp.tenant_id == current_tenant_id
+        ).all()
+
+        current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
+        # model instances are not mappings, so build plain dicts for marshalling
+        installed_apps = [
+            {
+                'id': installed_app.id,
+                'app': installed_app.app,
+                'app_owner_tenant_id': installed_app.app_owner_tenant_id,
+                'is_pinned': installed_app.is_pinned,
+                'last_used_at': installed_app.last_used_at,
+                'editable': current_user.role in ['owner', 'admin'],
+            }
+            for installed_app in installed_apps
+        ]
+        # pinned apps first, then by last_used_at ascending
+        installed_apps.sort(key=lambda app: (-app['is_pinned'], app['last_used_at']))
+
+        return {'installed_apps': installed_apps}
+
+    @login_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('app_id', type=str, required=True, help='Invalid app_id')
+        args = parser.parse_args()
+
+        current_tenant_id = Tenant.query.first().id
+        app = App.query.get(args['app_id'])
+        if app is None:
+            abort(404, message='App not found')
+        recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first()
+        if recommended_app is None:
+            abort(404, message='App not found')
+        if not app.is_public:
+            abort(403, message="You can't install a non-public app")
+
+        installed_app = InstalledApp.query.filter(and_(
+            InstalledApp.app_id == args['app_id'],
+            InstalledApp.tenant_id == current_tenant_id
+        )).first()
+
+        if installed_app is None:
+            # todo: position
+            recommended_app.install_count += 1
+
+            new_installed_app = InstalledApp(
+                app_id=args['app_id'],
+                tenant_id=current_tenant_id,
+                is_pinned=False,
+                last_used_at=datetime.utcnow()
+            )
+            db.session.add(new_installed_app)
+            db.session.commit()
+
+        return {'message': 'App installed successfully'}
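+# Re-installing an already-installed app is a no-op: the existing record is
+# kept as-is (install_count is not incremented) and success is still reported.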
+
+
+class InstalledAppResource(Resource):
+
+    @login_required
+    def delete(self, installed_app_id):
+
+        installed_app = InstalledApp.query.filter(and_(
+            InstalledApp.id == str(installed_app_id),
+            InstalledApp.tenant_id == current_user.current_tenant_id
+        )).first()
+
+        if installed_app is None:
+            abort(404, message='App not found')
+
+        if installed_app.app_owner_tenant_id == current_user.current_tenant_id:
+            abort(400, message="You can't uninstall an app owned by the current tenant")
+
+        db.session.delete(installed_app)
+        db.session.commit()
+
+        return {'result': 'success', 'message': 'App uninstalled successfully'}
+
+    @login_required
+    def patch(self, installed_app_id):
+        parser = reqparse.RequestParser()
+        parser.add_argument('is_pinned', type=inputs.boolean)
+        args = parser.parse_args()
+
+        current_tenant_id = Tenant.query.first().id
+        installed_app = InstalledApp.query.filter(and_(
+            InstalledApp.id == str(installed_app_id),
+            InstalledApp.tenant_id == current_tenant_id
+        )).first()
+
+        if installed_app is None:
+            abort(404, message='Installed app not found')
+
+        commit_args = False
+        if 'is_pinned' in args:
+            installed_app.is_pinned = args['is_pinned']
+            commit_args = True
+
+        if commit_args:
+            db.session.commit()
+
+        return {'result': 'success', 'message': 'App info updated successfully'}
+
+
+class RecommendedAppsResource(Resource):
+    @login_required
+    @marshal_with(recommended_app_list_fields)
+    def get(self):
+        recommended_apps = db.session.query(RecommendedApp).filter(
+            RecommendedApp.is_listed == True
+        ).all()
+
+        categories = set()
+        current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
+        recommended_apps_result = []
+        for recommended_app in recommended_apps:
+            installed = db.session.query(InstalledApp).filter(
+                and_(
+                    InstalledApp.app_id == recommended_app.app_id,
+                    InstalledApp.tenant_id == current_user.current_tenant_id
+                )
+            ).first() is not None
+
+            language_prefix = current_user.interface_language.split('-')[0]
+            desc = None
+            if recommended_app.description:
+                if language_prefix in recommended_app.description:
+                    desc = recommended_app.description[language_prefix]
+                elif 'en' in recommended_app.description:
+                    desc = recommended_app.description['en']
+
+            recommended_app_result = {
+                'id': recommended_app.id,
+                'app': recommended_app.app,
+                'app_id': recommended_app.app_id,
+                'description': desc,
+                'copyright': recommended_app.copyright,
+                'privacy_policy': recommended_app.privacy_policy,
+                'category': recommended_app.category,
+                'position': recommended_app.position,
+                'is_listed': recommended_app.is_listed,
+                'install_count': recommended_app.install_count,
+                'installed': installed,
+                'editable': current_user.role in ['owner', 'admin'],
+            }
+            recommended_apps_result.append(recommended_app_result)
+
+            categories.add(recommended_app.category)  # add category to categories
+
+        return {'recommended_apps': recommended_apps_result, 'categories': list(categories)}
+
+
+api.add_resource(InstalledAppsListResource, '/installed-apps')
+api.add_resource(InstalledAppResource, '/installed-apps/<uuid:installed_app_id>')
+api.add_resource(RecommendedAppsResource, '/explore/apps')
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
new file mode 100644
index 0000000000..27698d965c
--- /dev/null
+++ b/api/controllers/console/app/message.py
@@ -0,0 +1,361 @@
+import json
+import logging
+from typing import Union, Generator
+
+from flask import Response, 
stream_with_context +from flask_login import current_user, login_required +from flask_restful import Resource, reqparse, marshal_with, fields +from flask_restful.inputs import int_range +from werkzeug.exceptions import InternalServerError, NotFound + +from controllers.console import api +from controllers.console.app import _get_app +from controllers.console.app.error import CompletionRequestError, ProviderNotInitializeError, \ + AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ + ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError +from libs.helper import uuid_value, TimestampField +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from extensions.ext_database import db +from models.model import MessageAnnotation, Conversation, Message, MessageFeedback +from services.completion_service import CompletionService +from services.errors.app import MoreLikeThisDisabledError +from services.errors.conversation import ConversationNotExistsError +from services.errors.message import MessageNotExistsError +from services.message_service import MessageService + + +class ChatMessageApi(Resource): + account_fields = { + 'id': fields.String, + 'name': fields.String, + 'email': fields.String + } + + feedback_fields = { + 'rating': fields.String, + 'content': fields.String, + 'from_source': fields.String, + 'from_end_user_id': fields.String, + 'from_account': fields.Nested(account_fields, allow_null=True), + } + + annotation_fields = { + 'content': fields.String, + 'account': fields.Nested(account_fields, allow_null=True), + 'created_at': TimestampField + } + + message_detail_fields = { + 'id': fields.String, + 'conversation_id': fields.String, + 'inputs': fields.Raw, + 'query': fields.String, + 'message': fields.Raw, + 'message_tokens': fields.Integer, + 'answer': fields.String, + 'answer_tokens': fields.Integer, + 'provider_response_latency': fields.Integer, + 'from_source': fields.String, + 'from_end_user_id': fields.String, + 'from_account_id': fields.String, + 'feedbacks': fields.List(fields.Nested(feedback_fields)), + 'annotation': fields.Nested(annotation_fields, allow_null=True), + 'created_at': TimestampField + } + + message_infinite_scroll_pagination_fields = { + 'limit': fields.Integer, + 'has_more': fields.Boolean, + 'data': fields.List(fields.Nested(message_detail_fields)) + } + + @setup_required + @login_required + @account_initialization_required + @marshal_with(message_infinite_scroll_pagination_fields) + def get(self, app_id): + app_id = str(app_id) + + # get app info + app = _get_app(app_id, 'chat') + + parser = reqparse.RequestParser() + parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') + parser.add_argument('first_id', type=uuid_value, location='args') + parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + args = parser.parse_args() + + conversation = db.session.query(Conversation).filter( + Conversation.id == args['conversation_id'], + Conversation.app_id == app.id + ).first() + + if not conversation: + raise NotFound("Conversation Not Exists.") + + if args['first_id']: + first_message = db.session.query(Message) \ + .filter(Message.conversation_id == 
conversation.id, Message.id == args['first_id']).first() + + if not first_message: + raise NotFound("First message not found") + + history_messages = db.session.query(Message).filter( + Message.conversation_id == conversation.id, + Message.created_at < first_message.created_at, + Message.id != first_message.id + ) \ + .order_by(Message.created_at.desc()).limit(args['limit']).all() + else: + history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \ + .order_by(Message.created_at.desc()).limit(args['limit']).all() + + has_more = False + if len(history_messages) == args['limit']: + current_page_first_message = history_messages[-1] + rest_count = db.session.query(Message).filter( + Message.conversation_id == conversation.id, + Message.created_at < current_page_first_message.created_at, + Message.id != current_page_first_message.id + ).count() + + if rest_count > 0: + has_more = True + + history_messages = list(reversed(history_messages)) + + return InfiniteScrollPagination( + data=history_messages, + limit=args['limit'], + has_more=has_more + ) + + +class MessageFeedbackApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, app_id): + app_id = str(app_id) + + # get app info + app = _get_app(app_id) + + parser = reqparse.RequestParser() + parser.add_argument('message_id', required=True, type=uuid_value, location='json') + parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + args = parser.parse_args() + + message_id = str(args['message_id']) + + message = db.session.query(Message).filter( + Message.id == message_id, + Message.app_id == app.id + ).first() + + if not message: + raise NotFound("Message Not Exists.") + + feedback = message.admin_feedback + + if not args['rating'] and feedback: + db.session.delete(feedback) + elif args['rating'] and feedback: + feedback.rating = args['rating'] + elif not args['rating'] and not feedback: + raise ValueError('rating cannot be None when feedback not exists') + else: + feedback = MessageFeedback( + app_id=app.id, + conversation_id=message.conversation_id, + message_id=message.id, + rating=args['rating'], + from_source='admin', + from_account_id=current_user.id + ) + db.session.add(feedback) + + db.session.commit() + + return {'result': 'success'} + + +class MessageAnnotationApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, app_id): + app_id = str(app_id) + + # get app info + app = _get_app(app_id) + + parser = reqparse.RequestParser() + parser.add_argument('message_id', required=True, type=uuid_value, location='json') + parser.add_argument('content', type=str, location='json') + args = parser.parse_args() + + message_id = str(args['message_id']) + + message = db.session.query(Message).filter( + Message.id == message_id, + Message.app_id == app.id + ).first() + + if not message: + raise NotFound("Message Not Exists.") + + annotation = message.annotation + + if annotation: + annotation.content = args['content'] + else: + annotation = MessageAnnotation( + app_id=app.id, + conversation_id=message.conversation_id, + message_id=message.id, + content=args['content'], + account_id=current_user.id + ) + db.session.add(annotation) + + db.session.commit() + + return {'result': 'success'} + + +class MessageAnnotationCountApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, app_id): + app_id = str(app_id) + + # get app info + app = 
_get_app(app_id)
+
+        count = db.session.query(MessageAnnotation).filter(
+            MessageAnnotation.app_id == app.id
+        ).count()
+
+        return {'count': count}
+
+
+class MessageMoreLikeThisApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, app_id, message_id):
+        app_id = str(app_id)
+        message_id = str(message_id)
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
+        args = parser.parse_args()
+
+        streaming = args['response_mode'] == 'streaming'
+
+        # get app info
+        app_model = _get_app(app_id, 'completion')
+
+        try:
+            response = CompletionService.generate_more_like_this(app_model, current_user, message_id, streaming)
+            return compact_response(response)
+        except MessageNotExistsError:
+            raise NotFound("Message Not Exists.")
+        except MoreLikeThisDisabledError:
+            raise AppMoreLikeThisDisabledError()
+        except ProviderTokenNotInitError:
+            raise ProviderNotInitializeError()
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
+                LLMRateLimitError, LLMAuthorizationError) as e:
+            raise CompletionRequestError(str(e))
+        except ValueError as e:
+            raise e
+        except Exception as e:
+            logging.exception("internal server error.")
+            raise InternalServerError()
+
+
+def compact_response(response: Union[dict, Generator]) -> Response:
+    if isinstance(response, dict):
+        return Response(response=json.dumps(response), status=200, mimetype='application/json')
+    else:
+        def generate() -> Generator:
+            try:
+                for chunk in response:
+                    yield chunk
+            except MessageNotExistsError:
+                yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
+            except MoreLikeThisDisabledError:
+                yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except QuotaExceededError:
+                yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
+            except ModelCurrentlyNotSupportError:
+                yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
+            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
+                    LLMRateLimitError, LLMAuthorizationError) as e:
+                yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
+            except ValueError as e:
+                yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
+            except Exception:
+                logging.exception("internal server error.")
+                yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"
+
+        return Response(stream_with_context(generate()), status=200,
+                        mimetype='text/event-stream')
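+# Note: once streaming has begun the HTTP status line is already sent, so
+# errors raised mid-stream are emitted as "data: {...}" SSE events instead
+# of HTTP error responses.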
+
+
+class MessageSuggestedQuestionApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, app_id, message_id):
+        app_id = str(app_id)
+        message_id = str(message_id)
+
+        # get app info
+        app_model = _get_app(app_id, 'chat')
+
+        try:
+            questions = MessageService.get_suggested_questions_after_answer(
+                app_model=app_model,
+                user=current_user,
+                message_id=message_id,
+                check_enabled=False
+            )
+        except MessageNotExistsError:
+            raise NotFound("Message not found")
+        except ConversationNotExistsError:
+            raise NotFound("Conversation not found")
+        except ProviderTokenNotInitError:
+            raise ProviderNotInitializeError()
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
+                LLMRateLimitError, LLMAuthorizationError) as e:
+            raise CompletionRequestError(str(e))
+        except Exception:
+            logging.exception("internal server error.")
+            raise InternalServerError()
+
+        return {'data': questions}
+
+
+api.add_resource(MessageMoreLikeThisApi, '/apps/<uuid:app_id>/completion-messages/<uuid:message_id>/more-like-this')
+api.add_resource(MessageSuggestedQuestionApi, '/apps/<uuid:app_id>/chat-messages/<uuid:message_id>/suggested-questions')
+api.add_resource(ChatMessageApi, '/apps/<uuid:app_id>/chat-messages', endpoint='chat_messages')
+api.add_resource(MessageFeedbackApi, '/apps/<uuid:app_id>/feedbacks')
+api.add_resource(MessageAnnotationApi, '/apps/<uuid:app_id>/annotations')
+api.add_resource(MessageAnnotationCountApi, '/apps/<uuid:app_id>/annotations/count')
diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py
new file mode 100644
index 0000000000..adc10d6609
--- /dev/null
+++ b/api/controllers/console/app/model_config.py
@@ -0,0 +1,65 @@
+# -*- coding:utf-8 -*-
+import json
+
+from flask import request
+from flask_restful import Resource
+from flask_login import login_required, current_user
+
+from controllers.console import api
+from controllers.console.app import _get_app
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from events.app_event import app_model_config_was_updated
+from extensions.ext_database import db
+from models.model import AppModelConfig
+from services.app_model_config_service import AppModelConfigService
+
+
+class ModelConfigResource(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, app_id):
+        """Modify app model config"""
+        app_id = str(app_id)
+
+        app_model = _get_app(app_id)
+
+        # validate config
+        model_configuration = AppModelConfigService.validate_configuration(
+            account=current_user,
+            config=request.json,
+            mode=app_model.mode
+        )
+
+        new_app_model_config = AppModelConfig(
+            app_id=app_model.id,
+            provider="",
+            model_id="",
+            configs={},
+            opening_statement=model_configuration['opening_statement'],
+            suggested_questions=json.dumps(model_configuration['suggested_questions']),
+            suggested_questions_after_answer=json.dumps(model_configuration['suggested_questions_after_answer']),
+            more_like_this=json.dumps(model_configuration['more_like_this']),
+            model=json.dumps(model_configuration['model']),
+            user_input_form=json.dumps(model_configuration['user_input_form']),
+            pre_prompt=model_configuration['pre_prompt'],
+            agent_mode=json.dumps(model_configuration['agent_mode']),
+        )
+
+        db.session.add(new_app_model_config)
+        db.session.flush()
+
+        app_model.app_model_config_id = new_app_model_config.id
+        db.session.commit()
+
+        app_model_config_was_updated.send(
+            app_model,
+            app_model_config=new_app_model_config
+        )
+
+        return {'result': 'success'}
+
+
+api.add_resource(ModelConfigResource, '/apps/<uuid:app_id>/model-config')
diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py
new file mode 100644
index 0000000000..2e0e00a881
--- /dev/null
+++ b/api/controllers/console/app/site.py
@@ -0,0 +1,114 @@
+# -*- coding:utf-8 -*-
+from flask_login import login_required, current_user
+from 
flask_restful import Resource, reqparse, fields, marshal_with +from werkzeug.exceptions import NotFound, Forbidden + +from controllers.console import api +from controllers.console.app import _get_app +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from libs.helper import supported_language +from extensions.ext_database import db +from models.model import Site + +app_site_fields = { + 'app_id': fields.String, + 'access_token': fields.String(attribute='code'), + 'code': fields.String, + 'title': fields.String, + 'icon': fields.String, + 'icon_background': fields.String, + 'description': fields.String, + 'default_language': fields.String, + 'customize_domain': fields.String, + 'copyright': fields.String, + 'privacy_policy': fields.String, + 'customize_token_strategy': fields.String, + 'prompt_public': fields.Boolean +} + + +def parse_app_site_args(): + parser = reqparse.RequestParser() + parser.add_argument('title', type=str, required=False, location='json') + parser.add_argument('icon', type=str, required=False, location='json') + parser.add_argument('icon_background', type=str, required=False, location='json') + parser.add_argument('description', type=str, required=False, location='json') + parser.add_argument('default_language', type=supported_language, required=False, location='json') + parser.add_argument('customize_domain', type=str, required=False, location='json') + parser.add_argument('copyright', type=str, required=False, location='json') + parser.add_argument('privacy_policy', type=str, required=False, location='json') + parser.add_argument('customize_token_strategy', type=str, choices=['must', 'allow', 'not_allow'], + required=False, + location='json') + parser.add_argument('prompt_public', type=bool, required=False, location='json') + return parser.parse_args() + + +class AppSite(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(app_site_fields) + def post(self, app_id): + args = parse_app_site_args() + + app_id = str(app_id) + app_model = _get_app(app_id) + + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + site = db.session.query(Site). \ + filter(Site.app_id == app_model.id). 
\ + one_or_404() + + for attr_name in [ + 'title', + 'icon', + 'icon_background', + 'description', + 'default_language', + 'customize_domain', + 'copyright', + 'privacy_policy', + 'customize_token_strategy', + 'prompt_public' + ]: + value = args.get(attr_name) + if value is not None: + setattr(site, attr_name, value) + + db.session.commit() + + return site + + +class AppSiteAccessTokenReset(Resource): + + @setup_required + @login_required + @account_initialization_required + @marshal_with(app_site_fields) + def post(self, app_id): + app_id = str(app_id) + app_model = _get_app(app_id) + + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + site = db.session.query(Site).filter(Site.app_id == app_model.id).first() + + if not site: + raise NotFound + + site.code = Site.generate_code(16) + db.session.commit() + + return site + + +api.add_resource(AppSite, '/apps//site') +api.add_resource(AppSiteAccessTokenReset, '/apps//site/access-token-reset') diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py new file mode 100644 index 0000000000..37c3500448 --- /dev/null +++ b/api/controllers/console/app/statistic.py @@ -0,0 +1,202 @@ +# -*- coding:utf-8 -*- +from datetime import datetime + +import pytz +from flask import jsonify +from flask_login import login_required, current_user +from flask_restful import Resource, reqparse + +from controllers.console import api +from controllers.console.app import _get_app +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from libs.helper import datetime_string +from extensions.ext_database import db + + +class DailyConversationStatistic(Resource): + + @setup_required + @login_required + @account_initialization_required + def get(self, app_id): + account = current_user + app_id = str(app_id) + app_model = _get_app(app_id) + + parser = reqparse.RequestParser() + parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + args = parser.parse_args() + + sql_query = ''' + SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count + FROM messages where app_id = :app_id + ''' + arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + + timezone = pytz.timezone(account.timezone) + utc_timezone = pytz.utc + + if args['start']: + start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + start_datetime = start_datetime.replace(second=0) + + start_datetime_timezone = timezone.localize(start_datetime) + start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) + + sql_query += ' and created_at >= :start' + arg_dict['start'] = start_datetime_utc + + if args['end']: + end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + end_datetime = end_datetime.replace(second=0) + + end_datetime_timezone = timezone.localize(end_datetime) + end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) + + sql_query += ' and created_at < :end' + arg_dict['end'] = end_datetime_utc + + sql_query += ' GROUP BY date order by date' + rs = db.session.execute(sql_query, arg_dict) + + response_date = [] + + for i in rs: + response_date.append({ + 'date': str(i.date), + 'conversation_count': i.conversation_count + }) + + 
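+        # Each row becomes one chart point: the day (bucketed in the account's
+        # local timezone by the query above) and its distinct-conversation
+        # count, e.g. {"date": "2023-05-15", "conversation_count": 12}.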
return jsonify({ + 'data': response_date + }) + + +class DailyTerminalsStatistic(Resource): + + @setup_required + @login_required + @account_initialization_required + def get(self, app_id): + account = current_user + app_id = str(app_id) + app_model = _get_app(app_id) + + parser = reqparse.RequestParser() + parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + args = parser.parse_args() + + sql_query = ''' + SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count + FROM messages where app_id = :app_id + ''' + arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + + timezone = pytz.timezone(account.timezone) + utc_timezone = pytz.utc + + if args['start']: + start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + start_datetime = start_datetime.replace(second=0) + + start_datetime_timezone = timezone.localize(start_datetime) + start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) + + sql_query += ' and created_at >= :start' + arg_dict['start'] = start_datetime_utc + + if args['end']: + end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + end_datetime = end_datetime.replace(second=0) + + end_datetime_timezone = timezone.localize(end_datetime) + end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) + + sql_query += ' and created_at < :end' + arg_dict['end'] = end_datetime_utc + + sql_query += ' GROUP BY date order by date' + rs = db.session.execute(sql_query, arg_dict) + + response_date = [] + + for i in rs: + response_date.append({ + 'date': str(i.date), + 'terminal_count': i.terminal_count + }) + + return jsonify({ + 'data': response_date + }) + + +class DailyTokenCostStatistic(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, app_id): + account = current_user + app_id = str(app_id) + app_model = _get_app(app_id) + + parser = reqparse.RequestParser() + parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + args = parser.parse_args() + + sql_query = ''' + SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + (sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count, + sum(total_price) as total_price + FROM messages where app_id = :app_id + ''' + arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + + timezone = pytz.timezone(account.timezone) + utc_timezone = pytz.utc + + if args['start']: + start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + start_datetime = start_datetime.replace(second=0) + + start_datetime_timezone = timezone.localize(start_datetime) + start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) + + sql_query += ' and created_at >= :start' + arg_dict['start'] = start_datetime_utc + + if args['end']: + end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + end_datetime = end_datetime.replace(second=0) + + end_datetime_timezone = timezone.localize(end_datetime) + end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) + + sql_query += ' and created_at < :end' + arg_dict['end'] = end_datetime_utc + + sql_query += ' GROUP BY date order by date' + rs = db.session.execute(sql_query, arg_dict) + + response_date = [] + + for i in rs: + 
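+            # token_count sums prompt (message) and completion (answer) tokens;
+            # total_price is the day's accumulated spend, reported as USD below.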
response_date.append({ + 'date': str(i.date), + 'token_count': i.token_count, + 'total_price': i.total_price, + 'currency': 'USD' + }) + + return jsonify({ + 'data': response_date + }) + + +api.add_resource(DailyConversationStatistic, '/apps//statistics/daily-conversations') +api.add_resource(DailyTerminalsStatistic, '/apps//statistics/daily-end-users') +api.add_resource(DailyTokenCostStatistic, '/apps//statistics/token-costs') diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py new file mode 100644 index 0000000000..89f5d91789 --- /dev/null +++ b/api/controllers/console/auth/login.py @@ -0,0 +1,109 @@ +# -*- coding:utf-8 -*- +import flask +import flask_login +from flask import request, current_app +from flask_restful import Resource, reqparse + +import services +from controllers.console import api +from controllers.console.error import AccountNotLinkTenantError +from controllers.console.setup import setup_required +from libs.helper import email +from libs.password import valid_password +from services.account_service import AccountService, TenantService + + +class LoginApi(Resource): + """Resource for user login.""" + + @setup_required + def post(self): + """Authenticate user and login.""" + parser = reqparse.RequestParser() + parser.add_argument('email', type=email, required=True, location='json') + parser.add_argument('password', type=valid_password, required=True, location='json') + parser.add_argument('remember_me', type=bool, required=False, default=False, location='json') + args = parser.parse_args() + + # todo: Verify the recaptcha + + try: + account = AccountService.authenticate(args['email'], args['password']) + except services.errors.account.AccountLoginError: + return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401 + + try: + TenantService.switch_tenant(account) + except Exception: + raise AccountNotLinkTenantError("Account not link tenant") + + flask_login.login_user(account, remember=args['remember_me']) + AccountService.update_last_login(account, request) + + # todo: return the user info + + return {'result': 'success'} + + +class LogoutApi(Resource): + + @setup_required + def get(self): + flask.session.pop('workspace_id', None) + flask_login.logout_user() + return {'result': 'success'} + + +class ResetPasswordApi(Resource): + @setup_required + def get(self): + parser = reqparse.RequestParser() + parser.add_argument('email', type=email, required=True, location='json') + args = parser.parse_args() + + # import mailchimp_transactional as MailchimpTransactional + # from mailchimp_transactional.api_client import ApiClientError + + account = {'email': args['email']} + # account = AccountService.get_by_email(args['email']) + # if account is None: + # raise ValueError('Email not found') + # new_password = AccountService.generate_password() + # AccountService.update_password(account, new_password) + + # todo: Send email + MAILCHIMP_API_KEY = current_app.config['MAILCHIMP_TRANSACTIONAL_API_KEY'] + # mailchimp = MailchimpTransactional(MAILCHIMP_API_KEY) + + message = { + 'from_email': 'noreply@example.com', + 'to': [{'email': account.email}], + 'subject': 'Reset your Dify password', + 'html': """ +

+            <p>Dear User,</p>
+            <p>The Dify team has generated a new password for you, details as follows:</p>
+            <p><strong>{new_password}</strong></p>
+            <p>Please change your password to log in as soon as possible.</p>
+            <p>Regards,</p>
+            <p>The Dify Team</p>

+ """ + } + + # response = mailchimp.messages.send({ + # 'message': message, + # # required for transactional email + # ' settings': { + # 'sandbox_mode': current_app.config['MAILCHIMP_SANDBOX_MODE'], + # }, + # }) + + # Check if MSG was sent + # if response.status_code != 200: + # # handle error + # pass + + return {'result': 'success'} + + +api.add_resource(LoginApi, '/login') +api.add_resource(LogoutApi, '/logout') diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py new file mode 100644 index 0000000000..ababc30de9 --- /dev/null +++ b/api/controllers/console/auth/oauth.py @@ -0,0 +1,126 @@ +import logging +from datetime import datetime +from typing import Optional + +import flask_login +import requests +from flask import request, redirect, current_app, session +from flask_restful import Resource + +from libs.oauth import OAuthUserInfo, GitHubOAuth, GoogleOAuth +from extensions.ext_database import db +from models.account import Account, AccountStatus +from services.account_service import AccountService, RegisterService +from .. import api + + +def get_oauth_providers(): + with current_app.app_context(): + github_oauth = GitHubOAuth(client_id=current_app.config.get('GITHUB_CLIENT_ID'), + client_secret=current_app.config.get( + 'GITHUB_CLIENT_SECRET'), + redirect_uri=current_app.config.get( + 'CONSOLE_URL') + '/console/api/oauth/authorize/github') + + google_oauth = GoogleOAuth(client_id=current_app.config.get('GOOGLE_CLIENT_ID'), + client_secret=current_app.config.get( + 'GOOGLE_CLIENT_SECRET'), + redirect_uri=current_app.config.get( + 'CONSOLE_URL') + '/console/api/oauth/authorize/google') + + OAUTH_PROVIDERS = { + 'github': github_oauth, + 'google': google_oauth + } + return OAUTH_PROVIDERS + + +class OAuthLogin(Resource): + def get(self, provider: str): + OAUTH_PROVIDERS = get_oauth_providers() + with current_app.app_context(): + oauth_provider = OAUTH_PROVIDERS.get(provider) + print(vars(oauth_provider)) + if not oauth_provider: + return {'error': 'Invalid provider'}, 400 + + auth_url = oauth_provider.get_authorization_url() + return redirect(auth_url) + + +class OAuthCallback(Resource): + def get(self, provider: str): + OAUTH_PROVIDERS = get_oauth_providers() + with current_app.app_context(): + oauth_provider = OAUTH_PROVIDERS.get(provider) + if not oauth_provider: + return {'error': 'Invalid provider'}, 400 + + code = request.args.get('code') + try: + token = oauth_provider.get_access_token(code) + user_info = oauth_provider.get_user_info(token) + except requests.exceptions.HTTPError as e: + logging.exception( + f"An error occurred during the OAuth process with {provider}: {e.response.text}") + return {'error': 'OAuth process failed'}, 400 + + account = _generate_account(provider, user_info) + # Check account status + if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: + return {'error': 'Account is banned or closed.'}, 403 + + if account.status == AccountStatus.PENDING.value: + account.status = AccountStatus.ACTIVE.value + account.initialized_at = datetime.utcnow() + db.session.commit() + + # login user + session.clear() + flask_login.login_user(account, remember=True) + AccountService.update_last_login(account, request) + + return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_login=success') + + +def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: + account = Account.get_by_openid(provider, user_info.id) + + if not account: + account = 
Account.query.filter_by(email=user_info.email).first() + + return account + + +def _generate_account(provider: str, user_info: OAuthUserInfo): + # Get account by openid or email. + account = _get_account_by_openid_or_email(provider, user_info) + + if not account: + # Create account + account_name = user_info.name if user_info.name else 'Dify' + account = RegisterService.register( + email=user_info.email, + name=account_name, + password=None, + open_id=user_info.id, + provider=provider + ) + + # Set interface language + preferred_lang = request.accept_languages.best_match(['zh', 'en']) + if preferred_lang == 'zh': + interface_language = 'zh-Hans' + else: + interface_language = 'en-US' + account.interface_language = interface_language + db.session.commit() + + # Link account + AccountService.link_account_integrate(provider, user_info.id, account) + + return account + + +api.add_resource(OAuthLogin, '/oauth/login/') +api.add_resource(OAuthCallback, '/oauth/authorize/') diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py new file mode 100644 index 0000000000..04a4a44840 --- /dev/null +++ b/api/controllers/console/datasets/datasets.py @@ -0,0 +1,281 @@ +# -*- coding:utf-8 -*- +from flask import request +from flask_login import login_required, current_user +from flask_restful import Resource, reqparse, fields, marshal, marshal_with +from werkzeug.exceptions import NotFound, Forbidden + +import services +from controllers.console import api +from controllers.console.datasets.error import DatasetNameDuplicateError +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from core.indexing_runner import IndexingRunner +from libs.helper import TimestampField +from extensions.ext_database import db +from models.model import UploadFile +from services.dataset_service import DatasetService + +dataset_detail_fields = { + 'id': fields.String, + 'name': fields.String, + 'description': fields.String, + 'provider': fields.String, + 'permission': fields.String, + 'data_source_type': fields.String, + 'indexing_technique': fields.String, + 'app_count': fields.Integer, + 'document_count': fields.Integer, + 'word_count': fields.Integer, + 'created_by': fields.String, + 'created_at': TimestampField, + 'updated_by': fields.String, + 'updated_at': TimestampField, +} + +dataset_query_detail_fields = { + "id": fields.String, + "content": fields.String, + "source": fields.String, + "source_app_id": fields.String, + "created_by_role": fields.String, + "created_by": fields.String, + "created_at": TimestampField +} + + +def _validate_name(name): + if not name or len(name) < 1 or len(name) > 40: + raise ValueError('Name must be between 1 to 40 characters.') + return name + + +def _validate_description_length(description): + if len(description) > 200: + raise ValueError('Description cannot exceed 200 characters.') + return description + + +class DatasetListApi(Resource): + + @setup_required + @login_required + @account_initialization_required + def get(self): + page = request.args.get('page', default=1, type=int) + limit = request.args.get('limit', default=20, type=int) + ids = request.args.getlist('ids') + provider = request.args.get('provider', default="vendor") + if ids: + datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id) + else: + datasets, total = DatasetService.get_datasets(page, limit, provider, + current_user.current_tenant_id, current_user) + + response = { 
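+            # Pagination heuristic: a page filled to 'limit' is assumed to have
+            # more results; clients can cross-check against 'total'.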
+ 'data': marshal(datasets, dataset_detail_fields), + 'has_more': len(datasets) == limit, + 'limit': limit, + 'total': total, + 'page': page + } + return response, 200 + + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument('name', nullable=False, required=True, + help='type is required. Name must be between 1 to 40 characters.', + type=_validate_name) + parser.add_argument('indexing_technique', type=str, location='json', + choices=('high_quality', 'economy'), + help='Invalid indexing technique.') + args = parser.parse_args() + + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + try: + dataset = DatasetService.create_empty_dataset( + tenant_id=current_user.current_tenant_id, + name=args['name'], + indexing_technique=args['indexing_technique'], + account=current_user + ) + except services.errors.dataset.DatasetNameDuplicateError: + raise DatasetNameDuplicateError() + + return marshal(dataset, dataset_detail_fields), 201 + + +class DatasetApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id): + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + + try: + DatasetService.check_dataset_permission( + dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + return marshal(dataset, dataset_detail_fields), 200 + + @setup_required + @login_required + @account_initialization_required + def patch(self, dataset_id): + dataset_id_str = str(dataset_id) + + parser = reqparse.RequestParser() + parser.add_argument('name', nullable=False, + help='type is required. 
Name must be between 1 to 40 characters.', + type=_validate_name) + parser.add_argument('description', + location='json', store_missing=False, + type=_validate_description_length) + parser.add_argument('indexing_technique', type=str, location='json', + choices=('high_quality', 'economy'), + help='Invalid indexing technique.') + parser.add_argument('permission', type=str, location='json', choices=( + 'only_me', 'all_team_members'), help='Invalid permission.') + args = parser.parse_args() + + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + dataset = DatasetService.update_dataset( + dataset_id_str, args, current_user) + + if dataset is None: + raise NotFound("Dataset not found.") + + return marshal(dataset, dataset_detail_fields), 200 + + @setup_required + @login_required + @account_initialization_required + def delete(self, dataset_id): + dataset_id_str = str(dataset_id) + + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + if DatasetService.delete_dataset(dataset_id_str, current_user): + return {'result': 'success'}, 204 + else: + raise NotFound("Dataset not found.") + + +class DatasetQueryApi(Resource): + + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id): + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + page = request.args.get('page', default=1, type=int) + limit = request.args.get('limit', default=20, type=int) + + dataset_queries, total = DatasetService.get_dataset_queries( + dataset_id=dataset.id, + page=page, + per_page=limit + ) + + response = { + 'data': marshal(dataset_queries, dataset_query_detail_fields), + 'has_more': len(dataset_queries) == limit, + 'limit': limit, + 'total': total, + 'page': page + } + return response, 200 + + +class DatasetIndexingEstimateApi(Resource): + + @setup_required + @login_required + @account_initialization_required + def post(self): + segment_rule = request.get_json() + file_detail = db.session.query(UploadFile).filter( + UploadFile.tenant_id == current_user.current_tenant_id, + UploadFile.id == segment_rule["file_id"] + ).first() + + if file_detail is None: + raise NotFound("File not found.") + + indexing_runner = IndexingRunner() + response = indexing_runner.indexing_estimate(file_detail, segment_rule['process_rule']) + return response, 200 + + +class DatasetRelatedAppListApi(Resource): + app_detail_kernel_fields = { + 'id': fields.String, + 'name': fields.String, + 'mode': fields.String, + 'icon': fields.String, + 'icon_background': fields.String, + } + + related_app_list = { + 'data': fields.List(fields.Nested(app_detail_kernel_fields)), + 'total': fields.Integer, + } + + @setup_required + @login_required + @account_initialization_required + @marshal_with(related_app_list) + def get(self, dataset_id): + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + 
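+        # Resolve the dataset's app joins to live App rows; joins whose app has
+        # been deleted resolve to None and are skipped below.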
app_dataset_joins = DatasetService.get_related_apps(dataset.id) + + related_apps = [] + for app_dataset_join in app_dataset_joins: + app_model = app_dataset_join.app + if app_model: + related_apps.append(app_model) + + return { + 'data': related_apps, + 'total': len(related_apps) + }, 200 + + +api.add_resource(DatasetListApi, '/datasets') +api.add_resource(DatasetApi, '/datasets/') +api.add_resource(DatasetQueryApi, '/datasets//queries') +api.add_resource(DatasetIndexingEstimateApi, '/datasets/file-indexing-estimate') +api.add_resource(DatasetRelatedAppListApi, '/datasets//related-apps') diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py new file mode 100644 index 0000000000..79e52d565b --- /dev/null +++ b/api/controllers/console/datasets/datasets_document.py @@ -0,0 +1,682 @@ +# -*- coding:utf-8 -*- +import random +from datetime import datetime + +from flask import request +from flask_login import login_required, current_user +from flask_restful import Resource, fields, marshal, marshal_with, reqparse +from sqlalchemy import desc, asc +from werkzeug.exceptions import NotFound, Forbidden + +import services +from controllers.console import api +from controllers.console.app.error import ProviderNotInitializeError +from controllers.console.datasets.error import DocumentAlreadyFinishedError, InvalidActionError, DocumentIndexingError, \ + InvalidMetadataError, ArchivedDocumentImmutableError +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from core.indexing_runner import IndexingRunner +from core.llm.error import ProviderTokenNotInitError +from extensions.ext_redis import redis_client +from libs.helper import TimestampField +from extensions.ext_database import db +from models.dataset import DatasetProcessRule, Dataset +from models.dataset import Document, DocumentSegment +from models.model import UploadFile +from services.dataset_service import DocumentService, DatasetService +from tasks.add_document_to_index_task import add_document_to_index_task +from tasks.remove_document_from_index_task import remove_document_from_index_task + +dataset_fields = { + 'id': fields.String, + 'name': fields.String, + 'description': fields.String, + 'permission': fields.String, + 'data_source_type': fields.String, + 'indexing_technique': fields.String, + 'created_by': fields.String, + 'created_at': TimestampField, +} + +document_fields = { + 'id': fields.String, + 'position': fields.Integer, + 'data_source_type': fields.String, + 'data_source_info': fields.Raw(attribute='data_source_info_dict'), + 'dataset_process_rule_id': fields.String, + 'name': fields.String, + 'created_from': fields.String, + 'created_by': fields.String, + 'created_at': TimestampField, + 'tokens': fields.Integer, + 'indexing_status': fields.String, + 'error': fields.String, + 'enabled': fields.Boolean, + 'disabled_at': TimestampField, + 'disabled_by': fields.String, + 'archived': fields.Boolean, + 'display_status': fields.String, + 'word_count': fields.Integer, + 'hit_count': fields.Integer, +} + + +class DocumentResource(Resource): + def get_document(self, dataset_id: str, document_id: str) -> Document: + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound('Dataset not found.') + + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + document = 
DocumentService.get_document(dataset_id, document_id) + + if not document: + raise NotFound('Document not found.') + + if document.tenant_id != current_user.current_tenant_id: + raise Forbidden('No permission.') + + return document + + +class GetProcessRuleApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + req_data = request.args + + document_id = req_data.get('document_id') + if document_id: + # get the latest process rule + document = Document.query.get_or_404(document_id) + + dataset = DatasetService.get_dataset(document.dataset_id) + + if not dataset: + raise NotFound('Dataset not found.') + + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + # get the latest process rule + dataset_process_rule = db.session.query(DatasetProcessRule). \ + filter(DatasetProcessRule.dataset_id == document.dataset_id). \ + order_by(DatasetProcessRule.created_at.desc()). \ + limit(1). \ + one_or_none() + mode = dataset_process_rule.mode + rules = dataset_process_rule.rules_dict + else: + mode = DocumentService.DEFAULT_RULES['mode'] + rules = DocumentService.DEFAULT_RULES['rules'] + + return { + 'mode': mode, + 'rules': rules + } + + +class DatasetDocumentListApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id): + dataset_id = str(dataset_id) + page = request.args.get('page', default=1, type=int) + limit = request.args.get('limit', default=20, type=int) + search = request.args.get('search', default=None, type=str) + sort = request.args.get('sort', default='-created_at', type=str) + + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound('Dataset not found.') + + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + query = Document.query.filter_by( + dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) + + if search: + search = f'%{search}%' + query = query.filter(Document.name.like(search)) + + if sort.startswith('-'): + sort_logic = desc + sort = sort[1:] + else: + sort_logic = asc + + if sort == 'hit_count': + sub_query = db.select(DocumentSegment.document_id, + db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) \ + .group_by(DocumentSegment.document_id) \ + .subquery() + + query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id) \ + .order_by(sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0))) + elif sort == 'created_at': + query = query.order_by(sort_logic(Document.created_at)) + else: + query = query.order_by(desc(Document.created_at)) + + paginated_documents = query.paginate( + page=page, per_page=limit, max_per_page=100, error_out=False) + documents = paginated_documents.items + + response = { + 'data': marshal(documents, document_fields), + 'has_more': len(documents) == limit, + 'limit': limit, + 'total': paginated_documents.total, + 'page': page + } + + return response + + @setup_required + @login_required + @account_initialization_required + @marshal_with(document_fields) + def post(self, dataset_id): + dataset_id = str(dataset_id) + + dataset = DatasetService.get_dataset(dataset_id) + + if not dataset: + raise NotFound('Dataset not found.') + + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + 
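+            # Only workspace admins or owners may create documents in a dataset.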
raise Forbidden() + + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + parser = reqparse.RequestParser() + parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, + location='json') + parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json') + parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') + parser.add_argument('duplicate', type=bool, nullable=False, location='json') + args = parser.parse_args() + + if not dataset.indexing_technique and not args['indexing_technique']: + raise ValueError('indexing_technique is required.') + + # validate args + DocumentService.document_create_args_validate(args) + + try: + document = DocumentService.save_document_with_dataset_id(dataset, args, current_user) + except ProviderTokenNotInitError: + raise ProviderNotInitializeError() + + return document + + +class DatasetInitApi(Resource): + dataset_and_document_fields = { + 'dataset': fields.Nested(dataset_fields), + 'document': fields.Nested(document_fields) + } + + @setup_required + @login_required + @account_initialization_required + @marshal_with(dataset_and_document_fields) + def post(self): + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, required=True, + nullable=False, location='json') + parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json') + parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') + args = parser.parse_args() + + # validate args + DocumentService.document_create_args_validate(args) + + try: + dataset, document = DocumentService.save_document_without_dataset_id( + tenant_id=current_user.current_tenant_id, + document_data=args, + account=current_user + ) + except ProviderTokenNotInitError: + raise ProviderNotInitializeError() + + response = { + 'dataset': dataset, + 'document': document + } + + return response + + +class DocumentIndexingEstimateApi(DocumentResource): + + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id, document_id): + dataset_id = str(dataset_id) + document_id = str(document_id) + document = self.get_document(dataset_id, document_id) + + if document.indexing_status in ['completed', 'error']: + raise DocumentAlreadyFinishedError() + + data_process_rule = document.dataset_process_rule + data_process_rule_dict = data_process_rule.to_dict() + + response = { + "tokens": 0, + "total_price": 0, + "currency": "USD", + "total_segments": 0, + "preview": [] + } + + if document.data_source_type == 'upload_file': + data_source_info = document.data_source_info_dict + if data_source_info and 'upload_file_id' in data_source_info: + file_id = data_source_info['upload_file_id'] + + file = db.session.query(UploadFile).filter( + UploadFile.tenant_id == document.tenant_id, + UploadFile.id == file_id + ).first() + + # raise error if file not found + if not file: + raise NotFound('File not found.') + + indexing_runner = IndexingRunner() + response = indexing_runner.indexing_estimate(file, data_process_rule_dict) + + return response + + +class DocumentIndexingStatusApi(DocumentResource): + 
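+    # Marshalling schema for indexing progress; completed_segments and
+    # total_segments are not stored columns but are computed in get() below.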
document_status_fields = { + 'id': fields.String, + 'indexing_status': fields.String, + 'processing_started_at': TimestampField, + 'parsing_completed_at': TimestampField, + 'cleaning_completed_at': TimestampField, + 'splitting_completed_at': TimestampField, + 'completed_at': TimestampField, + 'paused_at': TimestampField, + 'error': fields.String, + 'stopped_at': TimestampField, + 'completed_segments': fields.Integer, + 'total_segments': fields.Integer, + } + + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id, document_id): + dataset_id = str(dataset_id) + document_id = str(document_id) + document = self.get_document(dataset_id, document_id) + + completed_segments = DocumentSegment.query \ + .filter(DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document_id)) \ + .count() + total_segments = DocumentSegment.query \ + .filter_by(document_id=str(document_id)) \ + .count() + + document.completed_segments = completed_segments + document.total_segments = total_segments + + return marshal(document, self.document_status_fields) + + +class DocumentDetailApi(DocumentResource): + METADATA_CHOICES = {'all', 'only', 'without'} + + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id, document_id): + dataset_id = str(dataset_id) + document_id = str(document_id) + document = self.get_document(dataset_id, document_id) + + metadata = request.args.get('metadata', 'all') + if metadata not in self.METADATA_CHOICES: + raise InvalidMetadataError(f'Invalid metadata value: {metadata}') + + if metadata == 'only': + response = { + 'id': document.id, + 'doc_type': document.doc_type, + 'doc_metadata': document.doc_metadata + } + elif metadata == 'without': + process_rules = DatasetService.get_process_rules(dataset_id) + data_source_info = document.data_source_detail_dict + response = { + 'id': document.id, + 'position': document.position, + 'data_source_type': document.data_source_type, + 'data_source_info': data_source_info, + 'dataset_process_rule_id': document.dataset_process_rule_id, + 'dataset_process_rule': process_rules, + 'name': document.name, + 'created_from': document.created_from, + 'created_by': document.created_by, + 'created_at': document.created_at.timestamp(), + 'tokens': document.tokens, + 'indexing_status': document.indexing_status, + 'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None, + 'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None, + 'indexing_latency': document.indexing_latency, + 'error': document.error, + 'enabled': document.enabled, + 'disabled_at': int(document.disabled_at.timestamp()) if document.disabled_at else None, + 'disabled_by': document.disabled_by, + 'archived': document.archived, + 'segment_count': document.segment_count, + 'average_segment_length': document.average_segment_length, + 'hit_count': document.hit_count, + 'display_status': document.display_status + } + else: + process_rules = DatasetService.get_process_rules(dataset_id) + data_source_info = document.data_source_detail_dict_() + response = { + 'id': document.id, + 'position': document.position, + 'data_source_type': document.data_source_type, + 'data_source_info': data_source_info, + 'dataset_process_rule_id': document.dataset_process_rule_id, + 'dataset_process_rule': process_rules, + 'name': document.name, + 'created_from': document.created_from, + 'created_by': document.created_by, + 'created_at': 
document.created_at.timestamp(), + 'tokens': document.tokens, + 'indexing_status': document.indexing_status, + 'completed_at': int(document.completed_at.timestamp())if document.completed_at else None, + 'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None, + 'indexing_latency': document.indexing_latency, + 'error': document.error, + 'enabled': document.enabled, + 'disabled_at': int(document.disabled_at.timestamp()) if document.disabled_at else None, + 'disabled_by': document.disabled_by, + 'archived': document.archived, + 'doc_type': document.doc_type, + 'doc_metadata': document.doc_metadata, + 'segment_count': document.segment_count, + 'average_segment_length': document.average_segment_length, + 'hit_count': document.hit_count, + 'display_status': document.display_status + } + + return response, 200 + + +class DocumentProcessingApi(DocumentResource): + @setup_required + @login_required + @account_initialization_required + def patch(self, dataset_id, document_id, action): + dataset_id = str(dataset_id) + document_id = str(document_id) + document = self.get_document(dataset_id, document_id) + + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + if action == "pause": + if document.indexing_status != "indexing": + raise InvalidActionError('Document not in indexing state.') + + document.paused_by = current_user.id + document.paused_at = datetime.utcnow() + document.is_paused = True + db.session.commit() + + elif action == "resume": + if document.indexing_status not in ["paused", "error"]: + raise InvalidActionError('Document not in paused or error state.') + + document.paused_by = None + document.paused_at = None + document.is_paused = False + db.session.commit() + else: + raise InvalidActionError() + + return {'result': 'success'}, 200 + + +class DocumentDeleteApi(DocumentResource): + @setup_required + @login_required + @account_initialization_required + def delete(self, dataset_id, document_id): + dataset_id = str(dataset_id) + document_id = str(document_id) + document = self.get_document(dataset_id, document_id) + + try: + DocumentService.delete_document(document) + except services.errors.document.DocumentIndexingError: + raise DocumentIndexingError('Cannot delete document during indexing.') + + return {'result': 'success'}, 204 + + +class DocumentMetadataApi(DocumentResource): + @setup_required + @login_required + @account_initialization_required + def put(self, dataset_id, document_id): + dataset_id = str(dataset_id) + document_id = str(document_id) + document = self.get_document(dataset_id, document_id) + + req_data = request.get_json() + + doc_type = req_data.get('doc_type') + doc_metadata = req_data.get('doc_metadata') + + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + if doc_type is None or doc_metadata is None: + raise ValueError('Both doc_type and doc_metadata must be provided.') + + if doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA: + raise ValueError('Invalid doc_type.') + + if not isinstance(doc_metadata, dict): + raise ValueError('doc_metadata must be a dictionary.') + + metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type] + + document.doc_metadata = {} + + for key, value_type in metadata_schema.items(): + value = doc_metadata.get(key) + if value is not None and isinstance(value, 
value_type): + document.doc_metadata[key] = value + + document.doc_type = doc_type + document.updated_at = datetime.utcnow() + db.session.commit() + + return {'result': 'success', 'message': 'Document metadata updated.'}, 200 + + +class DocumentStatusApi(DocumentResource): + @setup_required + @login_required + @account_initialization_required + def patch(self, dataset_id, document_id, action): + dataset_id = str(dataset_id) + document_id = str(document_id) + document = self.get_document(dataset_id, document_id) + + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + indexing_cache_key = 'document_{}_indexing'.format(document.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + raise InvalidActionError("Document is being indexed, please try again later") + + if action == "enable": + if document.enabled: + raise InvalidActionError('Document already enabled.') + + document.enabled = True + document.disabled_at = None + document.disabled_by = None + document.updated_at = datetime.utcnow() + db.session.commit() + + # Set cache to prevent indexing the same document multiple times + redis_client.setex(indexing_cache_key, 600, 1) + + add_document_to_index_task.delay(document_id) + + return {'result': 'success'}, 200 + + elif action == "disable": + if not document.enabled: + raise InvalidActionError('Document already disabled.') + + document.enabled = False + document.disabled_at = datetime.utcnow() + document.disabled_by = current_user.id + document.updated_at = datetime.utcnow() + db.session.commit() + + # Set cache to prevent indexing the same document multiple times + redis_client.setex(indexing_cache_key, 600, 1) + + remove_document_from_index_task.delay(document_id) + + return {'result': 'success'}, 200 + + elif action == "archive": + if document.archived: + raise InvalidActionError('Document already archived.') + + document.archived = True + document.archived_at = datetime.utcnow() + document.archived_by = current_user.id + document.updated_at = datetime.utcnow() + db.session.commit() + + if document.enabled: + # Set cache to prevent indexing the same document multiple times + redis_client.setex(indexing_cache_key, 600, 1) + + remove_document_from_index_task.delay(document_id) + + return {'result': 'success'}, 200 + else: + raise InvalidActionError() + + +class DocumentPauseApi(DocumentResource): + def patch(self, dataset_id, document_id): + """pause document.""" + dataset_id = str(dataset_id) + document_id = str(document_id) + + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound('Dataset not found.') + + document = DocumentService.get_document(dataset.id, document_id) + + # 404 if document not found + if document is None: + raise NotFound("Document Not Exists.") + + # 403 if document is archived + if DocumentService.check_archived(document): + raise ArchivedDocumentImmutableError() + + try: + # pause document + DocumentService.pause_document(document) + except services.errors.document.DocumentIndexingError: + raise DocumentIndexingError('Cannot pause completed document.') + + return {'result': 'success'}, 204 + + +class DocumentRecoverApi(DocumentResource): + def patch(self, dataset_id, document_id): + """recover document.""" + dataset_id = str(dataset_id) + document_id = str(document_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound('Dataset not found.') + document = 
DocumentService.get_document(dataset.id, document_id) + + # 404 if document not found + if document is None: + raise NotFound("Document Not Exists.") + + # 403 if document is archived + if DocumentService.check_archived(document): + raise ArchivedDocumentImmutableError() + try: + # pause document + DocumentService.recover_document(document) + except services.errors.document.DocumentIndexingError: + raise DocumentIndexingError('Document is not in paused status.') + + return {'result': 'success'}, 204 + + +api.add_resource(GetProcessRuleApi, '/datasets/process-rule') +api.add_resource(DatasetDocumentListApi, + '/datasets//documents') +api.add_resource(DatasetInitApi, + '/datasets/init') +api.add_resource(DocumentIndexingEstimateApi, + '/datasets//documents//indexing-estimate') +api.add_resource(DocumentIndexingStatusApi, + '/datasets//documents//indexing-status') +api.add_resource(DocumentDetailApi, + '/datasets//documents/') +api.add_resource(DocumentProcessingApi, + '/datasets//documents//processing/') +api.add_resource(DocumentDeleteApi, + '/datasets//documents/') +api.add_resource(DocumentMetadataApi, + '/datasets//documents//metadata') +api.add_resource(DocumentStatusApi, + '/datasets//documents//status/') +api.add_resource(DocumentPauseApi, '/datasets//documents//processing/pause') +api.add_resource(DocumentRecoverApi, '/datasets//documents//processing/resume') diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py new file mode 100644 index 0000000000..2c77a44c97 --- /dev/null +++ b/api/controllers/console/datasets/datasets_segments.py @@ -0,0 +1,203 @@ +# -*- coding:utf-8 -*- +from datetime import datetime + +from flask_login import login_required, current_user +from flask_restful import Resource, reqparse, fields, marshal +from werkzeug.exceptions import NotFound, Forbidden + +import services +from controllers.console import api +from controllers.console.datasets.error import InvalidActionError +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import DocumentSegment + +from libs.helper import TimestampField +from services.dataset_service import DatasetService, DocumentService +from tasks.add_segment_to_index_task import add_segment_to_index_task +from tasks.remove_segment_from_index_task import remove_segment_from_index_task + +segment_fields = { + 'id': fields.String, + 'position': fields.Integer, + 'document_id': fields.String, + 'content': fields.String, + 'word_count': fields.Integer, + 'tokens': fields.Integer, + 'keywords': fields.List(fields.String), + 'index_node_id': fields.String, + 'index_node_hash': fields.String, + 'hit_count': fields.Integer, + 'enabled': fields.Boolean, + 'disabled_at': TimestampField, + 'disabled_by': fields.String, + 'status': fields.String, + 'created_by': fields.String, + 'created_at': TimestampField, + 'indexing_at': TimestampField, + 'completed_at': TimestampField, + 'error': fields.String, + 'stopped_at': TimestampField +} + +segment_list_response = { + 'data': fields.List(fields.Nested(segment_fields)), + 'has_more': fields.Boolean, + 'limit': fields.Integer +} + + +class DatasetDocumentSegmentListApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id, document_id): + dataset_id = str(dataset_id) + document_id = str(document_id) + dataset = 
DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound('Dataset not found.') + + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + document = DocumentService.get_document(dataset_id, document_id) + + if not document: + raise NotFound('Document not found.') + + parser = reqparse.RequestParser() + parser.add_argument('last_id', type=str, default=None, location='args') + parser.add_argument('limit', type=int, default=20, location='args') + parser.add_argument('status', type=str, + action='append', default=[], location='args') + parser.add_argument('hit_count_gte', type=int, + default=None, location='args') + parser.add_argument('enabled', type=str, default='all', location='args') + args = parser.parse_args() + + last_id = args['last_id'] + limit = min(args['limit'], 100) + status_list = args['status'] + hit_count_gte = args['hit_count_gte'] + + query = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document_id), + DocumentSegment.tenant_id == current_user.current_tenant_id + ) + + if last_id is not None: + last_segment = DocumentSegment.query.get(str(last_id)) + if last_segment: + query = query.filter( + DocumentSegment.position > last_segment.position) + else: + return {'data': [], 'has_more': False, 'limit': limit}, 200 + + if status_list: + query = query.filter(DocumentSegment.status.in_(status_list)) + + if hit_count_gte is not None: + query = query.filter(DocumentSegment.hit_count >= hit_count_gte) + + if args['enabled'].lower() != 'all': + if args['enabled'].lower() == 'true': + query = query.filter(DocumentSegment.enabled == True) + elif args['enabled'].lower() == 'false': + query = query.filter(DocumentSegment.enabled == False) + + total = query.count() + segments = query.order_by(DocumentSegment.position).limit(limit + 1).all() + + has_more = False + if len(segments) > limit: + has_more = True + segments = segments[:-1] + + return { + 'data': marshal(segments, segment_fields), + 'has_more': has_more, + 'limit': limit, + 'total': total + }, 200 + + +class DatasetDocumentSegmentApi(Resource): + @setup_required + @login_required + @account_initialization_required + def patch(self, dataset_id, segment_id, action): + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound('Dataset not found.') + + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + raise Forbidden() + + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), + DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + + if not segment: + raise NotFound('Segment not found.') + + document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id) + cache_result = redis_client.get(document_indexing_cache_key) + if cache_result is not None: + raise InvalidActionError("Document is being indexed, please try again later") + + indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + raise InvalidActionError("Segment is being indexed, please try again later") + + if action == "enable": + if segment.enabled: + raise InvalidActionError("Segment 
is already enabled.") + + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None + db.session.commit() + + # Set cache to prevent indexing the same segment multiple times + redis_client.setex(indexing_cache_key, 600, 1) + + add_segment_to_index_task.delay(segment.id) + + return {'result': 'success'}, 200 + elif action == "disable": + if not segment.enabled: + raise InvalidActionError("Segment is already disabled.") + + segment.enabled = False + segment.disabled_at = datetime.utcnow() + segment.disabled_by = current_user.id + db.session.commit() + + # Set cache to prevent indexing the same segment multiple times + redis_client.setex(indexing_cache_key, 600, 1) + + remove_segment_from_index_task.delay(segment.id) + + return {'result': 'success'}, 200 + else: + raise InvalidActionError() + + +api.add_resource(DatasetDocumentSegmentListApi, + '/datasets//documents//segments') +api.add_resource(DatasetDocumentSegmentApi, + '/datasets//segments//') diff --git a/api/controllers/console/datasets/error.py b/api/controllers/console/datasets/error.py new file mode 100644 index 0000000000..014822d565 --- /dev/null +++ b/api/controllers/console/datasets/error.py @@ -0,0 +1,73 @@ +from libs.exception import BaseHTTPException + + +class NoFileUploadedError(BaseHTTPException): + error_code = 'no_file_uploaded' + description = "No file uploaded." + code = 400 + + +class TooManyFilesError(BaseHTTPException): + error_code = 'too_many_files' + description = "Only one file is allowed." + code = 400 + + +class FileTooLargeError(BaseHTTPException): + error_code = 'file_too_large' + description = "File size exceeded. {message}" + code = 413 + + +class UnsupportedFileTypeError(BaseHTTPException): + error_code = 'unsupported_file_type' + description = "File type not allowed." + code = 415 + + +class HighQualityDatasetOnlyError(BaseHTTPException): + error_code = 'high_quality_dataset_only' + description = "High quality dataset only." + code = 400 + + +class DatasetNotInitializedError(BaseHTTPException): + error_code = 'dataset_not_initialized' + description = "Dataset not initialized." + code = 400 + + +class ArchivedDocumentImmutableError(BaseHTTPException): + error_code = 'archived_document_immutable' + description = "Cannot process an archived document." + code = 403 + + +class DatasetNameDuplicateError(BaseHTTPException): + error_code = 'dataset_name_duplicate' + description = "Dataset name already exists." + code = 409 + + +class InvalidActionError(BaseHTTPException): + error_code = 'invalid_action' + description = "Invalid action." + code = 400 + + +class DocumentAlreadyFinishedError(BaseHTTPException): + error_code = 'document_already_finished' + description = "Document already finished." + code = 400 + + +class DocumentIndexingError(BaseHTTPException): + error_code = 'document_indexing' + description = "Document indexing." + code = 400 + + +class InvalidMetadataError(BaseHTTPException): + error_code = 'invalid_metadata' + description = "Invalid metadata." 
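+    # Raised when the 'metadata' query value is not one of all/only/without
+    # (see DocumentDetailApi).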
+    code = 400
diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/datasets/file.py
new file mode 100644
index 0000000000..75aba5e9eb
--- /dev/null
+++ b/api/controllers/console/datasets/file.py
@@ -0,0 +1,147 @@
+import datetime
+import hashlib
+import tempfile
+import time
+import uuid
+from pathlib import Path
+
+from cachetools import TTLCache
+from flask import request, current_app
+from flask_login import login_required, current_user
+from flask_restful import Resource, marshal_with, fields
+from werkzeug.exceptions import NotFound
+
+from controllers.console import api
+from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
+    UnsupportedFileTypeError
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from core.index.readers.html_parser import HTMLParser
+from core.index.readers.pdf_parser import PDFParser
+from extensions.ext_storage import storage
+from libs.helper import TimestampField
+from extensions.ext_database import db
+from models.model import UploadFile
+
+cache = TTLCache(maxsize=None, ttl=30)
+
+FILE_SIZE_LIMIT = 15 * 1024 * 1024  # 15MB
+ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm']
+PREVIEW_WORDS_LIMIT = 3000
+
+
+class FileApi(Resource):
+    file_fields = {
+        'id': fields.String,
+        'name': fields.String,
+        'size': fields.Integer,
+        'extension': fields.String,
+        'mime_type': fields.String,
+        'created_by': fields.String,
+        'created_at': TimestampField,
+    }
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(file_fields)
+    def post(self):
+
+        # check the request payload before touching it; reading
+        # request.files['file'] first would raise a plain 400 instead of
+        # NoFileUploadedError when the field is missing
+        if 'file' not in request.files:
+            raise NoFileUploadedError()
+
+        if len(request.files) > 1:
+            raise TooManyFilesError()
+
+        # get file from request
+        file = request.files['file']
+
+        file_content = file.read()
+        file_size = len(file_content)
+
+        if file_size > FILE_SIZE_LIMIT:
+            message = f'({file_size} > {FILE_SIZE_LIMIT})'
+            raise FileTooLargeError(message)
+
+        extension = file.filename.split('.')[-1]
+        if extension not in ALLOWED_EXTENSIONS:
+            raise UnsupportedFileTypeError()
+
+        # use uuid as file name
+        file_uuid = str(uuid.uuid4())
+        file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' 
+ extension + + # save file to storage + storage.save(file_key, file_content) + + # save file to db + config = current_app.config + upload_file = UploadFile( + tenant_id=current_user.current_tenant_id, + storage_type=config['STORAGE_TYPE'], + key=file_key, + name=file.filename, + size=file_size, + extension=extension, + mime_type=file.mimetype, + created_by=current_user.id, + created_at=datetime.datetime.utcnow(), + used=False, + hash=hashlib.sha3_256(file_content).hexdigest() + ) + + db.session.add(upload_file) + db.session.commit() + + return upload_file, 201 + + +class FilePreviewApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, file_id): + file_id = str(file_id) + + key = file_id + request.path + cached_response = cache.get(key) + if cached_response and time.time() - cached_response['timestamp'] < cache.ttl: + return cached_response['response'] + + upload_file = db.session.query(UploadFile) \ + .filter(UploadFile.id == file_id) \ + .first() + + if not upload_file: + raise NotFound("File not found") + + # extract text from file + extension = upload_file.extension + if extension not in ALLOWED_EXTENSIONS: + raise UnsupportedFileTypeError() + + with tempfile.TemporaryDirectory() as temp_dir: + suffix = Path(upload_file.key).suffix + filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" + storage.download(upload_file.key, filepath) + + if extension == 'pdf': + parser = PDFParser({'upload_file': upload_file}) + text = parser.parse_file(Path(filepath)) + elif extension in ['html', 'htm']: + # Use BeautifulSoup to extract text + parser = HTMLParser() + text = parser.parse_file(Path(filepath)) + else: + # ['txt', 'markdown', 'md'] + with open(filepath, "rb") as fp: + data = fp.read() + text = data.decode(encoding='utf-8').strip() if data else '' + + text = text[0:PREVIEW_WORDS_LIMIT] if text else '' + return {'content': text} + + +api.add_resource(FileApi, '/files/upload') +api.add_resource(FilePreviewApi, '/files//preview') diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py new file mode 100644 index 0000000000..16bb571df3 --- /dev/null +++ b/api/controllers/console/datasets/hit_testing.py @@ -0,0 +1,100 @@ +import logging + +from flask_login import login_required, current_user +from flask_restful import Resource, reqparse, marshal, fields +from werkzeug.exceptions import InternalServerError, NotFound, Forbidden + +import services +from controllers.console import api +from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from libs.helper import TimestampField +from services.dataset_service import DatasetService +from services.hit_testing_service import HitTestingService + +document_fields = { + 'id': fields.String, + 'data_source_type': fields.String, + 'name': fields.String, + 'doc_type': fields.String, +} + +segment_fields = { + 'id': fields.String, + 'position': fields.Integer, + 'document_id': fields.String, + 'content': fields.String, + 'word_count': fields.Integer, + 'tokens': fields.Integer, + 'keywords': fields.List(fields.String), + 'index_node_id': fields.String, + 'index_node_hash': fields.String, + 'hit_count': fields.Integer, + 'enabled': fields.Boolean, + 'disabled_at': TimestampField, + 'disabled_by': fields.String, + 'status': fields.String, + 'created_by': fields.String, + 
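+    # Lifecycle timestamps below are serialized as Unix epoch seconds via
+    # TimestampField.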
diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py
new file mode 100644
index 0000000000..16bb571df3
--- /dev/null
+++ b/api/controllers/console/datasets/hit_testing.py
@@ -0,0 +1,100 @@
+import logging
+
+from flask_login import login_required, current_user
+from flask_restful import Resource, reqparse, marshal, fields
+from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
+
+import services
+from controllers.console import api
+from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from libs.helper import TimestampField
+from services.dataset_service import DatasetService
+from services.hit_testing_service import HitTestingService
+
+document_fields = {
+    'id': fields.String,
+    'data_source_type': fields.String,
+    'name': fields.String,
+    'doc_type': fields.String,
+}
+
+segment_fields = {
+    'id': fields.String,
+    'position': fields.Integer,
+    'document_id': fields.String,
+    'content': fields.String,
+    'word_count': fields.Integer,
+    'tokens': fields.Integer,
+    'keywords': fields.List(fields.String),
+    'index_node_id': fields.String,
+    'index_node_hash': fields.String,
+    'hit_count': fields.Integer,
+    'enabled': fields.Boolean,
+    'disabled_at': TimestampField,
+    'disabled_by': fields.String,
+    'status': fields.String,
+    'created_by': fields.String,
+    'created_at': TimestampField,
+    'indexing_at': TimestampField,
+    'completed_at': TimestampField,
+    'error': fields.String,
+    'stopped_at': TimestampField,
+    'document': fields.Nested(document_fields),
+}
+
+hit_testing_record_fields = {
+    'segment': fields.Nested(segment_fields),
+    'score': fields.Float,
+    'tsne_position': fields.Raw
+}
+
+
+class HitTestingApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, dataset_id):
+        dataset_id_str = str(dataset_id)
+
+        dataset = DatasetService.get_dataset(dataset_id_str)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+
+        # only a high-quality dataset can be used for hit testing
+        if dataset.indexing_technique != 'high_quality':
+            raise HighQualityDatasetOnlyError()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('query', type=str, location='json')
+        args = parser.parse_args()
+
+        query = args['query']
+
+        if not query or len(query) > 250:
+            raise ValueError('Query is required and cannot exceed 250 characters')
+
+        try:
+            response = HitTestingService.retrieve(
+                dataset=dataset,
+                query=query,
+                account=current_user,
+                limit=10,
+            )
+
+            return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
+        except services.errors.index.IndexNotInitializedError:
+            raise DatasetNotInitializedError()
+        except Exception as e:
+            logging.exception("Hit testing failed.")
+            raise InternalServerError(str(e))
+
+
+api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing')

diff --git a/api/controllers/console/error.py b/api/controllers/console/error.py
new file mode 100644
index 0000000000..3040423d71
--- /dev/null
+++ b/api/controllers/console/error.py
@@ -0,0 +1,19 @@
+from libs.exception import BaseHTTPException
+
+
+class AlreadySetupError(BaseHTTPException):
+    error_code = 'already_setup'
+    description = "Application already set up."
+    code = 403
+
+
+class NotSetupError(BaseHTTPException):
+    error_code = 'not_setup'
+    description = "Application not set up."
+    code = 401
+
+
+class AccountNotLinkTenantError(BaseHTTPException):
+    error_code = 'account_not_link_tenant'
+    description = "Account is not linked to a tenant."
+    code = 403
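A hedged sketch of calling the hit-testing endpoint above. The host is an assumption, the session must already carry a console login, and the dataset id is a placeholder:

```python
import requests

CONSOLE_API = "http://localhost:5001/console/api"  # assumed local deployment
session = requests.Session()  # assumed to carry a logged-in console session

dataset_id = "00000000-0000-0000-0000-000000000000"  # placeholder dataset id
resp = session.post(
    f"{CONSOLE_API}/datasets/{dataset_id}/hit-testing",
    json={"query": "How do I reset my password?"},  # required, max 250 chars
)

for record in resp.json()["records"]:
    # each record carries the matched segment plus a similarity score
    print(record["score"], record["segment"]["content"][:80])
```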
diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py
new file mode 100644
index 0000000000..4677a2075b
--- /dev/null
+++ b/api/controllers/console/setup.py
@@ -0,0 +1,93 @@
+# -*- coding:utf-8 -*-
+from functools import wraps
+
+import flask_login
+from flask import request, current_app
+from flask_restful import Resource, reqparse
+
+from extensions.ext_database import db
+from models.model import DifySetup
+from services.account_service import AccountService, TenantService, RegisterService
+
+from libs.helper import email, str_len
+from libs.password import valid_password
+
+from . import api
+from .error import AlreadySetupError, NotSetupError
+from .wraps import only_edition_self_hosted
+
+
+class SetupApi(Resource):
+
+    @only_edition_self_hosted
+    def get(self):
+        setup_status = get_setup_status()
+        if setup_status:
+            return {
+                'step': 'finished',
+                'setup_at': setup_status.setup_at.isoformat()
+            }
+        return {'step': 'not_start'}
+
+    @only_edition_self_hosted
+    def post(self):
+        # already set up?
+        if get_setup_status():
+            raise AlreadySetupError()
+
+        # tenant already created?
+        tenant_count = TenantService.get_tenant_count()
+        if tenant_count > 0:
+            raise AlreadySetupError()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('email', type=email,
+                            required=True, location='json')
+        parser.add_argument('name', type=str_len(30),
+                            required=True, location='json')
+        parser.add_argument('password', type=valid_password,
+                            required=True, location='json')
+        args = parser.parse_args()
+
+        # Register
+        account = RegisterService.register(
+            email=args['email'],
+            name=args['name'],
+            password=args['password']
+        )
+
+        setup()
+
+        # Login
+        flask_login.login_user(account)
+        AccountService.update_last_login(account, request)
+
+        return {'result': 'success'}, 201
+
+
+def setup():
+    dify_setup = DifySetup(
+        version=current_app.config['CURRENT_VERSION']
+    )
+    db.session.add(dify_setup)
+    # commit here so the setup record is persisted even if no later call
+    # happens to commit the session
+    db.session.commit()
+
+
+def setup_required(view):
+    @wraps(view)
+    def decorated(*args, **kwargs):
+        # check setup
+        if not get_setup_status():
+            raise NotSetupError()
+
+        return view(*args, **kwargs)
+
+    return decorated
+
+
+def get_setup_status():
+    if current_app.config['EDITION'] == 'SELF_HOSTED':
+        return DifySetup.query.first()
+    else:
+        return True
+
+
+api.add_resource(SetupApi, '/setup')

diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py
new file mode 100644
index 0000000000..0e6e75c361
--- /dev/null
+++ b/api/controllers/console/version.py
@@ -0,0 +1,39 @@
+# -*- coding:utf-8 -*-
+
+import json
+import logging
+
+import requests
+from flask import current_app
+from flask_restful import reqparse, Resource
+from werkzeug.exceptions import InternalServerError
+
+from .
import api + + +class VersionApi(Resource): + + def get(self): + parser = reqparse.RequestParser() + parser.add_argument('current_version', type=str, required=True, location='args') + args = parser.parse_args() + check_update_url = current_app.config['CHECK_UPDATE_URL'] + + try: + response = requests.get(check_update_url, { + 'current_version': args.get('current_version') + }) + except Exception as error: + logging.exception("Check update error.") + raise InternalServerError() + + content = json.loads(response.content) + return { + 'version': content['version'], + 'release_date': content['releaseDate'], + 'release_notes': content['releaseNotes'], + 'can_auto_update': content['canAutoUpdate'] + } + + +api.add_resource(VersionApi, '/version') diff --git a/api/controllers/console/workspace/__init__.py b/api/controllers/console/workspace/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py new file mode 100644 index 0000000000..0890cd0468 --- /dev/null +++ b/api/controllers/console/workspace/account.py @@ -0,0 +1,263 @@ +# -*- coding:utf-8 -*- +from datetime import datetime + +import pytz +from flask import current_app, request +from flask_login import login_required, current_user +from flask_restful import Resource, reqparse, fields, marshal_with + +from controllers.console import api +from controllers.console.setup import setup_required +from controllers.console.workspace.error import AccountAlreadyInitedError, InvalidInvitationCodeError, \ + RepeatPasswordNotMatchError +from controllers.console.wraps import account_initialization_required +from libs.helper import TimestampField, supported_language, timezone +from extensions.ext_database import db +from models.account import InvitationCode, AccountIntegrate +from services.account_service import AccountService + + +account_fields = { + 'id': fields.String, + 'name': fields.String, + 'avatar': fields.String, + 'email': fields.String, + 'interface_language': fields.String, + 'interface_theme': fields.String, + 'timezone': fields.String, + 'last_login_at': TimestampField, + 'last_login_ip': fields.String, + 'created_at': TimestampField +} + + +class AccountInitApi(Resource): + + @setup_required + @login_required + def post(self): + account = current_user + + if account.status == 'active': + raise AccountAlreadyInitedError() + + parser = reqparse.RequestParser() + + if current_app.config['EDITION'] == 'CLOUD': + parser.add_argument('invitation_code', type=str, location='json') + + parser.add_argument( + 'interface_language', type=supported_language, required=True, location='json') + parser.add_argument('timezone', type=timezone, + required=True, location='json') + args = parser.parse_args() + + if current_app.config['EDITION'] == 'CLOUD': + if not args['invitation_code']: + raise ValueError('invitation_code is required') + + # check invitation code + invitation_code = db.session.query(InvitationCode).filter( + InvitationCode.code == args['invitation_code'], + InvitationCode.status == 'unused', + ).first() + + if not invitation_code: + raise InvalidInvitationCodeError() + + invitation_code.status = 'used' + invitation_code.used_at = datetime.utcnow() + invitation_code.used_by_tenant_id = account.current_tenant_id + invitation_code.used_by_account_id = account.id + + account.interface_language = args['interface_language'] + account.timezone = args['timezone'] + account.interface_theme = 'light' + account.status = 'active' + 
account.initialized_at = datetime.utcnow() + db.session.commit() + + return {'result': 'success'} + + +class AccountProfileApi(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(account_fields) + def get(self): + return current_user + + +class AccountNameApi(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(account_fields) + def post(self): + parser = reqparse.RequestParser() + parser.add_argument('name', type=str, required=True, location='json') + args = parser.parse_args() + + # Validate account name length + if len(args['name']) < 3 or len(args['name']) > 30: + raise ValueError( + "Account name must be between 3 and 30 characters.") + + updated_account = AccountService.update_account(current_user, name=args['name']) + + return updated_account + + +class AccountAvatarApi(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(account_fields) + def post(self): + parser = reqparse.RequestParser() + parser.add_argument('avatar', type=str, required=True, location='json') + args = parser.parse_args() + + updated_account = AccountService.update_account(current_user, avatar=args['avatar']) + + return updated_account + + +class AccountInterfaceLanguageApi(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(account_fields) + def post(self): + parser = reqparse.RequestParser() + parser.add_argument( + 'interface_language', type=supported_language, required=True, location='json') + args = parser.parse_args() + + updated_account = AccountService.update_account(current_user, interface_language=args['interface_language']) + + return updated_account + + +class AccountInterfaceThemeApi(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(account_fields) + def post(self): + parser = reqparse.RequestParser() + parser.add_argument('interface_theme', type=str, choices=[ + 'light', 'dark'], required=True, location='json') + args = parser.parse_args() + + updated_account = AccountService.update_account(current_user, interface_theme=args['interface_theme']) + + return updated_account + + +class AccountTimezoneApi(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(account_fields) + def post(self): + parser = reqparse.RequestParser() + parser.add_argument('timezone', type=str, + required=True, location='json') + args = parser.parse_args() + + # Validate timezone string, e.g. 
America/New_York, Asia/Shanghai
+        if args['timezone'] not in pytz.all_timezones:
+            raise ValueError("Invalid timezone string.")
+
+        updated_account = AccountService.update_account(current_user, timezone=args['timezone'])
+
+        return updated_account
+
+
+class AccountPasswordApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(account_fields)
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('password', type=str,
+                            required=False, location='json')
+        parser.add_argument('new_password', type=str,
+                            required=True, location='json')
+        parser.add_argument('repeat_new_password', type=str,
+                            required=True, location='json')
+        args = parser.parse_args()
+
+        if args['new_password'] != args['repeat_new_password']:
+            raise RepeatPasswordNotMatchError()
+
+        AccountService.update_account_password(
+            current_user, args['password'], args['new_password'])
+
+        return {"result": "success"}
+
+
+class AccountIntegrateApi(Resource):
+    integrate_fields = {
+        'provider': fields.String,
+        'created_at': TimestampField,
+        'is_bound': fields.Boolean,
+        'link': fields.String
+    }
+
+    integrate_list_fields = {
+        'data': fields.List(fields.Nested(integrate_fields)),
+    }
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(integrate_list_fields)
+    def get(self):
+        account = current_user
+
+        account_integrates = db.session.query(AccountIntegrate).filter(
+            AccountIntegrate.account_id == account.id).all()
+
+        base_url = request.url_root.rstrip('/')
+        oauth_base_path = "/console/api/oauth/login"
+        providers = ["github", "google"]
+
+        integrate_data = []
+        for provider in providers:
+            existing_integrate = next((ai for ai in account_integrates if ai.provider == provider), None)
+            if existing_integrate:
+                integrate_data.append({
+                    'id': existing_integrate.id,
+                    'provider': provider,
+                    'created_at': existing_integrate.created_at,
+                    'is_bound': True,
+                    'link': None
+                })
+            else:
+                integrate_data.append({
+                    'id': None,
+                    'provider': provider,
+                    'created_at': None,
+                    'is_bound': False,
+                    'link': f'{base_url}{oauth_base_path}/{provider}'
+                })
+
+        return {'data': integrate_data}
+
+
+# Register API resources
+api.add_resource(AccountInitApi, '/account/init')
+api.add_resource(AccountProfileApi, '/account/profile')
+api.add_resource(AccountNameApi, '/account/name')
+api.add_resource(AccountAvatarApi, '/account/avatar')
+api.add_resource(AccountInterfaceLanguageApi, '/account/interface-language')
+api.add_resource(AccountInterfaceThemeApi, '/account/interface-theme')
+api.add_resource(AccountTimezoneApi, '/account/timezone')
+api.add_resource(AccountPasswordApi, '/account/password')
+api.add_resource(AccountIntegrateApi, '/account/integrates')
+# api.add_resource(AccountEmailApi, '/account/email')
+# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')
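The timezone endpoint above validates against pytz's canonical list. A standalone sketch of the same check, for illustration:

```python
import pytz

def validate_timezone(tz_name: str) -> str:
    # pytz.all_timezones is the same whitelist the endpoint checks against
    if tz_name not in pytz.all_timezones:
        raise ValueError("Invalid timezone string.")
    return tz_name

print(validate_timezone("Asia/Shanghai"))  # ok
# validate_timezone("Mars/Olympus")        # would raise ValueError
```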
diff --git a/api/controllers/console/workspace/error.py b/api/controllers/console/workspace/error.py
new file mode 100644
index 0000000000..c5e3a3fb6a
--- /dev/null
+++ b/api/controllers/console/workspace/error.py
@@ -0,0 +1,31 @@
+from libs.exception import BaseHTTPException
+
+
+class RepeatPasswordNotMatchError(BaseHTTPException):
+    error_code = 'repeat_password_not_match'
+    description = "New password and repeat password do not match."
+    code = 400
+
+
+class ProviderRequestFailedError(BaseHTTPException):
+    error_code = 'provider_request_failed'
+    description = None
+    code = 400
+
+
+class InvalidInvitationCodeError(BaseHTTPException):
+    error_code = 'invalid_invitation_code'
+    description = "Invalid invitation code."
+    code = 400
+
+
+class AccountAlreadyInitedError(BaseHTTPException):
+    error_code = 'account_already_inited'
+    description = "Account already initialized."
+    code = 400
+
+
+class AccountNotInitializedError(BaseHTTPException):
+    error_code = 'account_not_initialized'
+    description = "Account not initialized."
+    code = 400
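`libs.exception.BaseHTTPException` is not part of this diff; purely as an illustration of how the `error_code`/`description`/`code` triplets above could surface as structured JSON errors, a class of roughly this shape would work (this is an assumption about its design, not the actual implementation):

```python
import json
from werkzeug.exceptions import HTTPException

class BaseHTTPException(HTTPException):
    # subclasses override these three attributes, as the error classes above do
    error_code = 'unknown'
    description = None
    code = 500

    def get_body(self, environ=None, scope=None):
        # hypothetical serialization to a {code, message, status} JSON payload
        return json.dumps({
            'code': self.error_code,
            'message': self.description,
            'status': self.code,
        })

    def get_headers(self, environ=None, scope=None):
        return [('Content-Type', 'application/json')]
```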
diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py
new file mode 100644
index 0000000000..e0fc2bc19f
--- /dev/null
+++ b/api/controllers/console/workspace/members.py
@@ -0,0 +1,141 @@
+# -*- coding:utf-8 -*-
+
+from flask_login import login_required, current_user
+from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal
+
+import services
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from libs.helper import TimestampField
+from extensions.ext_database import db
+from models.account import Account, TenantAccountJoin
+from services.account_service import TenantService, RegisterService
+
+account_fields = {
+    'id': fields.String,
+    'name': fields.String,
+    'avatar': fields.String,
+    'email': fields.String,
+    'last_login_at': TimestampField,
+    'created_at': TimestampField,
+    'role': fields.String,
+    'status': fields.String,
+}
+
+account_list_fields = {
+    'accounts': fields.List(fields.Nested(account_fields))
+}
+
+
+class MemberListApi(Resource):
+    """List all members of current tenant."""
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(account_list_fields)
+    def get(self):
+        members = TenantService.get_tenant_members(current_user.current_tenant)
+        return {'result': 'success', 'accounts': members}, 200
+
+
+class MemberInviteEmailApi(Resource):
+    """Invite a new member by email."""
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('email', type=str, required=True, location='json')
+        parser.add_argument('role', type=str, required=True, default='admin', location='json')
+        args = parser.parse_args()
+
+        invitee_email = args['email']
+        invitee_role = args['role']
+        if invitee_role not in ['admin', 'normal']:
+            return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
+
+        inviter = current_user
+
+        try:
+            RegisterService.invite_new_member(inviter.current_tenant, invitee_email, role=invitee_role, inviter=inviter)
+            # the query returns an (Account, role) tuple; unpack it explicitly
+            account_with_role = db.session.query(Account, TenantAccountJoin.role).join(
+                TenantAccountJoin, Account.id == TenantAccountJoin.account_id
+            ).filter(Account.email == args['email']).first()
+            account, role = account_with_role
+            account = marshal(account, account_fields)
+            account['role'] = role
+        except services.errors.account.CannotOperateSelfError as e:
+            return {'code': 'cannot-operate-self', 'message': str(e)}, 400
+        except services.errors.account.NoPermissionError as e:
+            return {'code': 'forbidden', 'message': str(e)}, 403
+        except services.errors.account.AccountAlreadyInTenantError as e:
+            return {'code': 'email-taken', 'message': str(e)}, 409
+        except Exception as e:
+            return {'code': 'unexpected-error', 'message': str(e)}, 500
+
+        # todo:413
+
+        return {'result': 'success', 'account': account}, 201
+
+
+class MemberCancelInviteApi(Resource):
+    """Cancel an invitation by member id."""
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, member_id):
+        member = Account.query.get(str(member_id))
+        if not member:
+            abort(404)
+
+        try:
+            TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user)
+        except services.errors.account.CannotOperateSelfError as e:
+            return {'code': 'cannot-operate-self', 'message': str(e)}, 400
+        except services.errors.account.NoPermissionError as e:
+            return {'code': 'forbidden', 'message': str(e)}, 403
+        except services.errors.account.MemberNotInTenantError as e:
+            return {'code': 'member-not-found', 'message': str(e)}, 404
+        except Exception as e:
+            raise ValueError(str(e))
+
+        return {'result': 'success'}, 204
+
+
+class MemberUpdateRoleApi(Resource):
+    """Update member role."""
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def put(self, member_id):
+        parser = reqparse.RequestParser()
+        parser.add_argument('role', type=str, required=True, location='json')
+        args = parser.parse_args()
+        new_role = args['role']
+
+        if new_role not in ['admin', 'normal', 'owner']:
+            return {'code': 'invalid-role', 'message': 'Invalid role'}, 400
+
+        member = Account.query.get(str(member_id))
+        if not member:
+            abort(404)
+
+        try:
+            TenantService.update_member_role(current_user.current_tenant, member, new_role, current_user)
+        except Exception as e:
+            raise ValueError(str(e))
+
+        # todo: 403
+
+        return {'result': 'success'}
+
+
+api.add_resource(MemberListApi, '/workspaces/current/members')
+api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email')
+api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/<uuid:member_id>')
+api.add_resource(MemberUpdateRoleApi, '/workspaces/current/members/<uuid:member_id>/update-role')

diff --git a/api/controllers/console/workspace/providers.py b/api/controllers/console/workspace/providers.py
new file mode 100644
index 0000000000..bc6b8320af
--- /dev/null
+++ b/api/controllers/console/workspace/providers.py
@@ -0,0 +1,246 @@
+# -*- coding:utf-8 -*-
+import base64
+import json
+import logging
+
+from flask_login import login_required, current_user
+from flask_restful import Resource, reqparse, abort
+from werkzeug.exceptions import Forbidden
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from core.llm.provider.errors import ValidateFailedError
+from extensions.ext_database import db
+from libs import rsa
+from models.provider import Provider, ProviderType, ProviderName
+from services.provider_service import ProviderService
+
+
+class ProviderListApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        tenant_id = current_user.current_tenant_id
+
+        """
+        If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version,
+        azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 characters in
+        plaintext, the rest is replaced by * and the last two characters are displayed in plaintext
+
+        If the type is other, decode and return the Token field directly; the field displays the first 6
+        characters in plaintext, the rest is replaced by * and the last two characters are displayed in plaintext
+        """
+
+        ProviderService.init_supported_provider(current_user.current_tenant, "cloud")
+        providers =
Provider.query.filter_by(tenant_id=tenant_id).all() + + provider_list = [ + { + 'provider_name': p.provider_name, + 'provider_type': p.provider_type, + 'is_valid': p.is_valid, + 'last_used': p.last_used, + 'is_enabled': p.is_enabled, + **({ + 'quota_type': p.quota_type, + 'quota_limit': p.quota_limit, + 'quota_used': p.quota_used + } if p.provider_type == ProviderType.SYSTEM.value else {}), + 'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant, + ProviderName(p.provider_name)) + } + for p in providers + ] + + return provider_list + + +class ProviderTokenApi(Resource): + + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + if provider not in [p.value for p in ProviderName]: + abort(404) + + # The role of the current user in the ta table must be admin or owner + if current_user.current_tenant.current_role not in ['admin', 'owner']: + logging.log(logging.ERROR, + f'User {current_user.id} is not authorized to update provider token, current_role is {current_user.current_tenant.current_role}') + raise Forbidden() + + parser = reqparse.RequestParser() + + parser.add_argument('token', type=ProviderService.get_token_type( + tenant=current_user.current_tenant, + provider_name=ProviderName(provider) + ), required=True, nullable=False, location='json') + + args = parser.parse_args() + + if not args['token']: + raise ValueError('Token is empty') + + try: + ProviderService.validate_provider_configs( + tenant=current_user.current_tenant, + provider_name=ProviderName(provider), + configs=args['token'] + ) + token_is_valid = True + except ValidateFailedError: + token_is_valid = False + + tenant = current_user.current_tenant + + base64_encrypted_token = ProviderService.get_encrypted_token( + tenant=current_user.current_tenant, + provider_name=ProviderName(provider), + configs=args['token'] + ) + + provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider, + provider_type=ProviderType.CUSTOM.value).first() + + # Only allow updating token for CUSTOM provider type + if provider_model: + provider_model.encrypted_config = base64_encrypted_token + provider_model.is_valid = token_is_valid + else: + provider_model = Provider(tenant_id=tenant.id, provider_name=provider, + provider_type=ProviderType.CUSTOM.value, + encrypted_config=base64_encrypted_token, + is_valid=token_is_valid) + db.session.add(provider_model) + + db.session.commit() + + if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value, + ProviderName.HUGGINGFACEHUB.value]: + return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201 + + return {'result': 'success'}, 201 + + +class ProviderTokenValidateApi(Resource): + + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + if provider not in [p.value for p in ProviderName]: + abort(404) + + parser = reqparse.RequestParser() + parser.add_argument('token', type=ProviderService.get_token_type( + tenant=current_user.current_tenant, + provider_name=ProviderName(provider) + ), required=True, nullable=False, location='json') + args = parser.parse_args() + + # todo: remove this when the provider is supported + if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value, + ProviderName.HUGGINGFACEHUB.value]: + return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'} + + result = True + error = None + + try: + 
ProviderService.validate_provider_configs(
+                tenant=current_user.current_tenant,
+                provider_name=ProviderName(provider),
+                configs=args['token']
+            )
+        except ValidateFailedError as e:
+            result = False
+            error = str(e)
+
+        response = {'result': 'success' if result else 'error'}
+
+        if not result:
+            response['error'] = error
+
+        return response
+
+
+class ProviderSystemApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def put(self, provider):
+        if provider not in [p.value for p in ProviderName]:
+            abort(404)
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('is_enabled', type=bool, required=True, location='json')
+        args = parser.parse_args()
+
+        # use the tenant object, not the tenant id, so tenant.id below works
+        tenant = current_user.current_tenant
+
+        provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider).first()
+
+        if provider_model and provider_model.provider_type == ProviderType.SYSTEM.value:
+            provider_model.is_valid = args['is_enabled']
+            db.session.commit()
+        elif not provider_model:
+            ProviderService.create_system_provider(tenant, provider, args['is_enabled'])
+        else:
+            abort(403)
+
+        return {'result': 'success'}
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider):
+        if provider not in [p.value for p in ProviderName]:
+            abort(404)
+
+        # The role of the current user in the ta table must be admin or owner
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        provider_model = db.session.query(Provider).filter(Provider.tenant_id == current_user.current_tenant_id,
+                                                           Provider.provider_name == provider,
+                                                           Provider.provider_type == ProviderType.SYSTEM.value).first()
+
+        system_model = None
+        if provider_model:
+            system_model = {
+                'result': 'success',
+                'provider': {
+                    'provider_name': provider_model.provider_name,
+                    'provider_type': provider_model.provider_type,
+                    'is_valid': provider_model.is_valid,
+                    'last_used': provider_model.last_used,
+                    'is_enabled': provider_model.is_enabled,
+                    'quota_type': provider_model.quota_type,
+                    'quota_limit': provider_model.quota_limit,
+                    'quota_used': provider_model.quota_used
+                }
+            }
+        else:
+            abort(404)
+
+        return system_model
+
+
+api.add_resource(ProviderTokenApi, '/providers/<provider>/token',
+                 endpoint='current_providers_token')  # Deprecated
+api.add_resource(ProviderTokenValidateApi, '/providers/<provider>/token-validate',
+                 endpoint='current_providers_token_validate')  # Deprecated
+
+api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
+                 endpoint='workspaces_current_providers_token')  # POST for updating provider token
+api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
+                 endpoint='workspaces_current_providers_token_validate')  # POST for validating provider token
+
+api.add_resource(ProviderListApi, '/workspaces/current/providers')  # GET for getting providers list
+api.add_resource(ProviderSystemApi, '/workspaces/current/providers/<provider>/system',
+                 endpoint='workspaces_current_providers_system')  # GET for getting provider quota, PUT for updating provider status

diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py
new file mode 100644
index 0000000000..2ad457c79b
--- /dev/null
+++ b/api/controllers/console/workspace/workspace.py
@@ -0,0 +1,97 @@
+# -*- coding:utf-8 -*-
+import logging
+
+from flask import request
+from flask_login import login_required, current_user
+from flask_restful import Resource, fields, marshal_with, reqparse, marshal
+
+from
controllers.console import api +from controllers.console.setup import setup_required +from controllers.console.error import AccountNotLinkTenantError +from controllers.console.wraps import account_initialization_required +from libs.helper import TimestampField +from extensions.ext_database import db +from models.account import Tenant +from services.account_service import TenantService +from services.workspace_service import WorkspaceService + +provider_fields = { + 'provider_name': fields.String, + 'provider_type': fields.String, + 'is_valid': fields.Boolean, + 'token_is_set': fields.Boolean, +} + +tenant_fields = { + 'id': fields.String, + 'name': fields.String, + 'plan': fields.String, + 'status': fields.String, + 'created_at': TimestampField, + 'role': fields.String, + 'providers': fields.List(fields.Nested(provider_fields)), + 'in_trail': fields.Boolean, + 'trial_end_reason': fields.String, +} + +tenants_fields = { + 'id': fields.String, + 'name': fields.String, + 'plan': fields.String, + 'status': fields.String, + 'created_at': TimestampField, + 'current': fields.Boolean +} + + +class TenantListApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + tenants = TenantService.get_join_tenants(current_user) + + for tenant in tenants: + if tenant.id == current_user.current_tenant_id: + tenant.current = True # Set current=True for current tenant + return {'workspaces': marshal(tenants, tenants_fields)}, 200 + + +class TenantApi(Resource): + @setup_required + @login_required + @account_initialization_required + @marshal_with(tenant_fields) + def get(self): + if request.path == '/info': + logging.warning('Deprecated URL /info was used.') + + tenant = current_user.current_tenant + + return WorkspaceService.get_tenant_info(tenant), 200 + + +class SwitchWorkspaceApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument('tenant_id', type=str, required=True, location='json') + args = parser.parse_args() + + # check if tenant_id is valid, 403 if not + try: + TenantService.switch_tenant(current_user, args['tenant_id']) + except Exception: + raise AccountNotLinkTenantError("Account not link tenant") + + new_tenant = db.session.query(Tenant).get(args['tenant_id']) # Get new tenant + + return {'result': 'success', 'new_tenant': marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} + + +api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants +api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info +api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated +api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py new file mode 100644 index 0000000000..41ce4f200b --- /dev/null +++ b/api/controllers/console/wraps.py @@ -0,0 +1,43 @@ +# -*- coding:utf-8 -*- +from functools import wraps + +from flask import current_app, abort +from flask_login import current_user + +from controllers.console.workspace.error import AccountNotInitializedError + + +def account_initialization_required(view): + @wraps(view) + def decorated(*args, **kwargs): + # check account initialization + account = current_user + + if account.status == 'uninitialized': + raise AccountNotInitializedError() + + return view(*args, **kwargs) + + return decorated + + +def 
only_edition_cloud(view): + @wraps(view) + def decorated(*args, **kwargs): + if current_app.config['EDITION'] != 'CLOUD': + abort(404) + + return view(*args, **kwargs) + + return decorated + + +def only_edition_self_hosted(view): + @wraps(view) + def decorated(*args, **kwargs): + if current_app.config['EDITION'] != 'SELF_HOSTED': + abort(404) + + return view(*args, **kwargs) + + return decorated diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py new file mode 100644 index 0000000000..05318a2076 --- /dev/null +++ b/api/controllers/service_api/__init__.py @@ -0,0 +1,12 @@ +# -*- coding:utf-8 -*- +from flask import Blueprint + +from libs.external_api import ExternalApi + +bp = Blueprint('service_api', __name__, url_prefix='/v1') +api = ExternalApi(bp) + + +from .app import completion, app, conversation, message + +from .dataset import document diff --git a/api/controllers/service_api/app/__init__.py b/api/controllers/service_api/app/__init__.py new file mode 100644 index 0000000000..d8018ee385 --- /dev/null +++ b/api/controllers/service_api/app/__init__.py @@ -0,0 +1,27 @@ +from extensions.ext_database import db +from models.model import EndUser + + +def create_or_update_end_user_for_user_id(app_model, user_id): + """ + Create or update session terminal based on user ID. + """ + end_user = db.session.query(EndUser) \ + .filter( + EndUser.tenant_id == app_model.tenant_id, + EndUser.session_id == user_id, + EndUser.type == 'service_api' + ).first() + + if end_user is None: + end_user = EndUser( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type='service_api', + is_anonymous=True, + session_id=user_id + ) + db.session.add(end_user) + db.session.commit() + + return end_user diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py new file mode 100644 index 0000000000..08532441c8 --- /dev/null +++ b/api/controllers/service_api/app/app.py @@ -0,0 +1,43 @@ +# -*- coding:utf-8 -*- +from flask_restful import fields, marshal_with + +from controllers.service_api import api +from controllers.service_api.wraps import AppApiResource + + +class AppParameterApi(AppApiResource): + """Resource for app variables.""" + + variable_fields = { + 'key': fields.String, + 'name': fields.String, + 'description': fields.String, + 'type': fields.String, + 'default': fields.String, + 'max_length': fields.Integer, + 'options': fields.List(fields.String) + } + + parameters_fields = { + 'opening_statement': fields.String, + 'suggested_questions': fields.Raw, + 'suggested_questions_after_answer': fields.Raw, + 'more_like_this': fields.Raw, + 'user_input_form': fields.Raw, + } + + @marshal_with(parameters_fields) + def get(self, app_model, end_user): + """Retrieve app parameters.""" + app_model_config = app_model.app_model_config + + return { + 'opening_statement': app_model_config.opening_statement, + 'suggested_questions': app_model_config.suggested_questions_list, + 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, + 'more_like_this': app_model_config.more_like_this_dict, + 'user_input_form': app_model_config.user_input_form_list + } + + +api.add_resource(AppParameterApi, '/parameters') diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py new file mode 100644 index 0000000000..e5eb4153aa --- /dev/null +++ b/api/controllers/service_api/app/completion.py @@ -0,0 +1,182 @@ +import json +import logging +from typing import Union, Generator + 
+from flask import stream_with_context, Response +from flask_restful import reqparse +from werkzeug.exceptions import NotFound, InternalServerError + +import services +from controllers.service_api import api +from controllers.service_api.app import create_or_update_end_user_for_user_id +from controllers.service_api.app.error import AppUnavailableError, ProviderNotInitializeError, NotChatAppError, \ + ConversationCompletedError, CompletionRequestError, ProviderQuotaExceededError, \ + ProviderModelCurrentlyNotSupportError +from controllers.service_api.wraps import AppApiResource +from core.conversation_message_task import PubHandler +from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \ + LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from libs.helper import uuid_value +from services.completion_service import CompletionService + + +class CompletionApi(AppApiResource): + def post(self, app_model, end_user): + if app_model.mode != 'completion': + raise AppUnavailableError() + + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, required=True, location='json') + parser.add_argument('query', type=str, location='json') + parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') + parser.add_argument('user', type=str, location='json') + args = parser.parse_args() + + streaming = args['response_mode'] == 'streaming' + + if end_user is None and args['user'] is not None: + end_user = create_or_update_end_user_for_user_id(app_model, args['user']) + + try: + response = CompletionService.completion( + app_model=app_model, + user=end_user, + args=args, + from_source='api', + streaming=streaming + ) + + return compact_response(response) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except services.errors.app_model_config.AppModelConfigBrokenError: + logging.exception("App model config broken.") + raise AppUnavailableError() + except ProviderTokenNotInitError: + raise ProviderNotInitializeError() + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, + LLMRateLimitError, LLMAuthorizationError) as e: + raise CompletionRequestError(str(e)) + except ValueError as e: + raise e + except Exception as e: + logging.exception("internal server error.") + raise InternalServerError() + + +class CompletionStopApi(AppApiResource): + def post(self, app_model, end_user, task_id): + if app_model.mode != 'completion': + raise AppUnavailableError() + + PubHandler.stop(end_user, task_id) + + return {'result': 'success'}, 200 + + +class ChatApi(AppApiResource): + def post(self, app_model, end_user): + if app_model.mode != 'chat': + raise NotChatAppError() + + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, required=True, location='json') + parser.add_argument('query', type=str, required=True, location='json') + parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') + parser.add_argument('conversation_id', type=uuid_value, location='json') + parser.add_argument('user', type=str, location='json') + args = parser.parse_args() + + streaming 
= args['response_mode'] == 'streaming'
+
+        if end_user is None and args['user'] is not None:
+            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
+
+        try:
+            response = CompletionService.completion(
+                app_model=app_model,
+                user=end_user,
+                args=args,
+                from_source='api',
+                streaming=streaming
+            )
+
+            return compact_response(response)
+        except services.errors.conversation.ConversationNotExistsError:
+            raise NotFound("Conversation Not Exists.")
+        except services.errors.conversation.ConversationCompletedError:
+            raise ConversationCompletedError()
+        except services.errors.app_model_config.AppModelConfigBrokenError:
+            logging.exception("App model config broken.")
+            raise AppUnavailableError()
+        except ProviderTokenNotInitError:
+            raise ProviderNotInitializeError()
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+        except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
+                LLMRateLimitError, LLMAuthorizationError) as e:
+            raise CompletionRequestError(str(e))
+        except ValueError as e:
+            raise e
+        except Exception as e:
+            logging.exception("internal server error.")
+            raise InternalServerError()
+
+
+class ChatStopApi(AppApiResource):
+    def post(self, app_model, end_user, task_id):
+        if app_model.mode != 'chat':
+            raise NotChatAppError()
+
+        PubHandler.stop(end_user, task_id)
+
+        return {'result': 'success'}, 200
+
+
+def compact_response(response: Union[dict, Generator]) -> Response:
+    if isinstance(response, dict):
+        return Response(response=json.dumps(response), status=200, mimetype='application/json')
+    else:
+        def generate() -> Generator:
+            try:
+                for chunk in response:
+                    yield chunk
+            except services.errors.conversation.ConversationNotExistsError:
+                yield "data: " + json.dumps(api.handle_error(NotFound("Conversation Not Exists.")).get_json()) + "\n\n"
+            except services.errors.conversation.ConversationCompletedError:
+                yield "data: " + json.dumps(api.handle_error(ConversationCompletedError()).get_json()) + "\n\n"
+            except services.errors.app_model_config.AppModelConfigBrokenError:
+                logging.exception("App model config broken.")
+                yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except QuotaExceededError:
+                yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
+            except ModelCurrentlyNotSupportError:
+                yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n"
+            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
+                    LLMRateLimitError, LLMAuthorizationError) as e:
+                yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n"
+            except ValueError as e:
+                yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n"
+            except Exception:
+                logging.exception("internal server error.")
+                yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n"
+
+        return Response(stream_with_context(generate()), status=200,
+                        mimetype='text/event-stream')
+
+
+api.add_resource(CompletionApi, '/completion-messages')
+api.add_resource(CompletionStopApi, '/completion-messages/<string:task_id>/stop')
+api.add_resource(ChatApi, '/chat-messages')
+api.add_resource(ChatStopApi, '/chat-messages/<string:task_id>/stop')
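The streaming branch above returns a `text/event-stream` response whose error events are written as `data: {json}` lines. A minimal consumer sketch; the host and bearer token are placeholders, and the token must be an app-scoped API token (see the service_api wraps module later in this diff):

```python
import json
import requests

API_BASE = "http://localhost:5001/v1"  # assumed service API host
API_TOKEN = "app-..."                  # placeholder app API token

resp = requests.post(
    f"{API_BASE}/chat-messages",
    headers={"Authorization": f"Bearer {API_TOKEN}"},
    json={
        "inputs": {},
        "query": "Hello!",
        "response_mode": "streaming",
        "user": "end-user-123",
    },
    stream=True,  # keep the SSE connection open
)

for line in resp.iter_lines(decode_unicode=True):
    if line and line.startswith("data: "):
        event = json.loads(line[len("data: "):])
        print(event)
```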
diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py
new file mode 100644
index 0000000000..602ac8d785
--- /dev/null
+++ b/api/controllers/service_api/app/conversation.py
@@ -0,0 +1,76 @@
+# -*- coding:utf-8 -*-
+from flask_restful import fields, marshal_with, reqparse
+from flask_restful.inputs import int_range
+from werkzeug.exceptions import NotFound
+
+from controllers.service_api import api
+from controllers.service_api.app import create_or_update_end_user_for_user_id
+from controllers.service_api.app.error import NotChatAppError
+from controllers.service_api.wraps import AppApiResource
+from libs.helper import TimestampField, uuid_value
+import services
+from services.conversation_service import ConversationService
+
+conversation_fields = {
+    'id': fields.String,
+    'name': fields.String,
+    'inputs': fields.Raw,
+    'status': fields.String,
+    'introduction': fields.String,
+    'created_at': TimestampField
+}
+
+conversation_infinite_scroll_pagination_fields = {
+    'limit': fields.Integer,
+    'has_more': fields.Boolean,
+    'data': fields.List(fields.Nested(conversation_fields))
+}
+
+
+class ConversationApi(AppApiResource):
+
+    @marshal_with(conversation_infinite_scroll_pagination_fields)
+    def get(self, app_model, end_user):
+        if app_model.mode != 'chat':
+            raise NotChatAppError()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('last_id', type=uuid_value, location='args')
+        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
+        parser.add_argument('user', type=str, location='args')
+        args = parser.parse_args()
+
+        if end_user is None and args['user'] is not None:
+            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
+
+        try:
+            return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
+        except services.errors.conversation.LastConversationNotExistsError:
+            raise NotFound("Last Conversation Not Exists.")
+
+
+class ConversationRenameApi(AppApiResource):
+
+    @marshal_with(conversation_fields)
+    def post(self, app_model, end_user, c_id):
+        if app_model.mode != 'chat':
+            raise NotChatAppError()
+
+        conversation_id = str(c_id)
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('user', type=str, location='json')
+        args = parser.parse_args()
+
+        if end_user is None and args['user'] is not None:
+            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
+
+        try:
+            return ConversationService.rename(app_model, conversation_id, end_user, args['name'])
+        except services.errors.conversation.ConversationNotExistsError:
+            raise NotFound("Conversation Not Exists.")
+
+
+api.add_resource(ConversationRenameApi, '/conversations/<uuid:c_id>/name', endpoint='conversation_name')
+api.add_resource(ConversationApi, '/conversations')
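A sketch of paging through conversations with the `last_id` cursor defined above; host and token are placeholders:

```python
import requests

API_BASE = "http://localhost:5001/v1"          # assumed service API host
HEADERS = {"Authorization": "Bearer app-..."}  # placeholder app token

last_id = None
while True:
    params = {"user": "end-user-123", "limit": 20}
    if last_id:
        params["last_id"] = last_id  # cursor = id of the last item already seen
    page = requests.get(f"{API_BASE}/conversations", headers=HEADERS, params=params).json()
    for conv in page["data"]:
        print(conv["id"], conv["name"])
    if not page["has_more"]:
        break
    last_id = page["data"][-1]["id"]
```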
diff --git a/api/controllers/service_api/app/error.py b/api/controllers/service_api/app/error.py
new file mode 100644
index 0000000000..c59f570efd
--- /dev/null
+++ b/api/controllers/service_api/app/error.py
@@ -0,0 +1,51 @@
+# -*- coding:utf-8 -*-
+from libs.exception import BaseHTTPException
+
+
+class AppUnavailableError(BaseHTTPException):
+    error_code = 'app_unavailable'
+    description = "App unavailable."
+    code = 400
+
+
+class NotCompletionAppError(BaseHTTPException):
+    error_code = 'not_completion_app'
+    description = "Not a Completion App."
+    code = 400
+
+
+class NotChatAppError(BaseHTTPException):
+    error_code = 'not_chat_app'
+    description = "Not a Chat App."
+    code = 400
+
+
+class ConversationCompletedError(BaseHTTPException):
+    error_code = 'conversation_completed'
+    description = "Conversation completed."
+    code = 400
+
+
+class ProviderNotInitializeError(BaseHTTPException):
+    error_code = 'provider_not_initialize'
+    description = "Provider token not initialized."
+    code = 400
+
+
+class ProviderQuotaExceededError(BaseHTTPException):
+    error_code = 'provider_quota_exceeded'
+    description = "Provider quota exceeded."
+    code = 400
+
+
+class ProviderModelCurrentlyNotSupportError(BaseHTTPException):
+    error_code = 'model_currently_not_support'
+    description = "GPT-4 is currently not supported."
+    code = 400
+
+
+class CompletionRequestError(BaseHTTPException):
+    error_code = 'completion_request_error'
+    description = "Completion request failed."
+    code = 400
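These error classes are raised by the message endpoints that follow; for instance, the `not_chat_app` payload comes back when a chat-only endpoint is pointed at a completion app. A feedback-call sketch against the endpoint defined below (host, token, and message id are placeholders):

```python
import requests

API_BASE = "http://localhost:5001/v1"          # assumed service API host
HEADERS = {"Authorization": "Bearer app-..."}  # placeholder app token

message_id = "00000000-0000-0000-0000-000000000000"  # placeholder message id
resp = requests.post(
    f"{API_BASE}/messages/{message_id}/feedbacks",
    headers=HEADERS,
    json={"rating": "like", "user": "end-user-123"},  # rating: like/dislike/None
)
print(resp.status_code, resp.json())  # {'result': 'success'} on success
```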
diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py
new file mode 100644
index 0000000000..ef020891ff
--- /dev/null
+++ b/api/controllers/service_api/app/message.py
@@ -0,0 +1,81 @@
+# -*- coding:utf-8 -*-
+from flask_restful import fields, marshal_with, reqparse
+from flask_restful.inputs import int_range
+from werkzeug.exceptions import NotFound
+
+import services
+from controllers.service_api import api
+from controllers.service_api.app import create_or_update_end_user_for_user_id
+from controllers.service_api.app.error import NotChatAppError
+from controllers.service_api.wraps import AppApiResource
+from libs.helper import TimestampField, uuid_value
+from services.message_service import MessageService
+
+
+class MessageListApi(AppApiResource):
+    feedback_fields = {
+        'rating': fields.String
+    }
+
+    message_fields = {
+        'id': fields.String,
+        'conversation_id': fields.String,
+        'inputs': fields.Raw,
+        'query': fields.String,
+        'answer': fields.String,
+        'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
+        'created_at': TimestampField
+    }
+
+    message_infinite_scroll_pagination_fields = {
+        'limit': fields.Integer,
+        'has_more': fields.Boolean,
+        'data': fields.List(fields.Nested(message_fields))
+    }
+
+    @marshal_with(message_infinite_scroll_pagination_fields)
+    def get(self, app_model, end_user):
+        if app_model.mode != 'chat':
+            raise NotChatAppError()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
+        parser.add_argument('first_id', type=uuid_value, location='args')
+        parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
+        parser.add_argument('user', type=str, location='args')
+        args = parser.parse_args()
+
+        if end_user is None and args['user'] is not None:
+            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
+
+        try:
+            return MessageService.pagination_by_first_id(app_model, end_user,
+                                                         args['conversation_id'], args['first_id'], args['limit'])
+        except services.errors.conversation.ConversationNotExistsError:
+            raise NotFound("Conversation Not Exists.")
+        except services.errors.message.FirstMessageNotExistsError:
+            raise NotFound("First Message Not Exists.")
+
+
+class MessageFeedbackApi(AppApiResource):
+    def post(self, app_model, end_user, message_id):
+        message_id = str(message_id)
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
+        parser.add_argument('user', type=str, location='json')
+        args = parser.parse_args()
+
+        if end_user is None and args['user'] is not None:
+            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
+
+        try:
+            MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
+        except services.errors.message.MessageNotExistsError:
+            raise NotFound("Message Not Exists.")
+
+        return {'result': 'success'}
+
+
+api.add_resource(MessageListApi, '/messages')
+api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')

diff --git a/api/controllers/service_api/dataset/__init__.py b/api/controllers/service_api/dataset/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
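A hedged sketch of creating a document from raw text via the dataset endpoint defined below. It authenticates with a dataset-scoped token (see the wraps module at the end of this diff); host and token are placeholders:

```python
import requests

API_BASE = "http://localhost:5001/v1"              # assumed service API host
HEADERS = {"Authorization": "Bearer dataset-..."}  # placeholder dataset token

resp = requests.post(
    f"{API_BASE}/documents",
    headers=HEADERS,
    json={
        "name": "faq",             # stored server-side as 'faq.txt'
        "text": "Q: ...\nA: ...",  # saved to storage, then indexed
    },
)
print(resp.json())  # {'id': '<document id>'}
```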
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
new file mode 100644
index 0000000000..47a90756db
--- /dev/null
+++ b/api/controllers/service_api/dataset/document.py
@@ -0,0 +1,129 @@
+import datetime
+import uuid
+
+from flask import current_app
+from flask_restful import reqparse
+from werkzeug.exceptions import NotFound
+
+import services.dataset_service
+from controllers.service_api import api
+from controllers.service_api.app.error import ProviderNotInitializeError
+from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
+    DatasetNotInitedError
+from controllers.service_api.wraps import DatasetApiResource
+from core.llm.error import ProviderTokenNotInitError
+from extensions.ext_database import db
+from extensions.ext_storage import storage
+from models.model import UploadFile
+from services.dataset_service import DocumentService
+
+
+class DocumentListApi(DatasetApiResource):
+    """Resource for documents."""
+
+    def post(self, dataset):
+        """Create document."""
+        parser = reqparse.RequestParser()
+        parser.add_argument('name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('text', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('doc_type', type=str, location='json')
+        parser.add_argument('doc_metadata', type=dict, location='json')
+        args = parser.parse_args()
+
+        if not dataset.indexing_technique:
+            raise DatasetNotInitedError("Dataset indexing technique must be set.")
+
+        doc_type = args.get('doc_type')
+        doc_metadata = args.get('doc_metadata')
+
+        if doc_type and doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA:
+            raise ValueError('Invalid doc_type.')
+
+        # use uuid as file name
+        file_uuid = str(uuid.uuid4())
+        file_key = 'upload_files/' + dataset.tenant_id + '/' + file_uuid + '.txt'
+
+        # save file to storage
+        storage.save(file_key, args.get('text'))
+
+        # save file to db
+        config = current_app.config
+        upload_file = UploadFile(
+            tenant_id=dataset.tenant_id,
+            storage_type=config['STORAGE_TYPE'],
+            key=file_key,
+            name=args.get('name') + '.txt',
+            size=len(args.get('text')),
+            extension='txt',
+            mime_type='text/plain',
+            created_by=dataset.created_by,
+            created_at=datetime.datetime.utcnow(),
+            used=True,
+            used_by=dataset.created_by,
+            used_at=datetime.datetime.utcnow()
+        )
+
+        db.session.add(upload_file)
+        db.session.commit()
+
+        document_data = {
+            'data_source': {
+                'type': 'upload_file',
+                'info': upload_file.id
+            }
+        }
+
+        try:
+            document = DocumentService.save_document_with_dataset_id(
+                dataset=dataset,
+                document_data=document_data,
+                account=dataset.created_by_account,
+                dataset_process_rule=dataset.latest_process_rule,
+                created_from='api'
+            )
+        except ProviderTokenNotInitError:
+            raise ProviderNotInitializeError()
+
+        if doc_type and doc_metadata:
+            metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
+
+            document.doc_metadata = {}
+
+            # keep only metadata values whose types match the schema
+            for key, value_type in metadata_schema.items():
+                value = doc_metadata.get(key)
+                if value is not None and isinstance(value, value_type):
+                    document.doc_metadata[key] = value
+
+            document.doc_type = doc_type
+            document.updated_at = datetime.datetime.utcnow()
+            db.session.commit()
+
+        return {'id': document.id}
+
+
+class DocumentApi(DatasetApiResource):
+    def delete(self, dataset, document_id):
+        """Delete document."""
+        document_id = str(document_id)
+
+        document = DocumentService.get_document(dataset.id, document_id)
+
+        # 404 if document not found
+        if document is None:
+            raise NotFound("Document Not Exists.")
+
+        # 403 if document is archived
+        if DocumentService.check_archived(document):
+            raise ArchivedDocumentImmutableError()
+
+        try:
+            # delete document
+            DocumentService.delete_document(document)
+        except services.errors.document.DocumentIndexingError:
+            raise DocumentIndexingError('Cannot delete document during indexing.')
+
+        return {'result': 'success'}, 204
+
+
+api.add_resource(DocumentListApi, '/documents')
+api.add_resource(DocumentApi, '/documents/<uuid:document_id>')

diff --git a/api/controllers/service_api/dataset/error.py b/api/controllers/service_api/dataset/error.py
new file mode 100644
index 0000000000..d231e0b40a
--- /dev/null
+++ b/api/controllers/service_api/dataset/error.py
@@ -0,0 +1,20 @@
+# -*- coding:utf-8 -*-
+from libs.exception import BaseHTTPException
+
+
+class ArchivedDocumentImmutableError(BaseHTTPException):
+    error_code = 'archived_document_immutable'
+    description = "Cannot operate when the document is archived."
+    code = 403
+
+
+class DocumentIndexingError(BaseHTTPException):
+    error_code = 'document_indexing'
+    description = "Cannot operate on the document during indexing."
+    code = 403
+
+
+class DatasetNotInitedError(BaseHTTPException):
+    error_code = 'dataset_not_inited'
+    description = "Dataset not initialized."
+    code = 403
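The wraps module below expects a standard bearer scheme in the Authorization header; a short sketch of the accepted and rejected header shapes (token values and host are placeholders):

```python
import requests

# validate_and_get_api_token() splits the Authorization header on whitespace
# and requires the scheme to be 'bearer' (case-insensitive).
good = {"Authorization": "Bearer app-0123456789"}       # accepted
bad_scheme = {"Authorization": "Token app-0123456789"}  # rejected: Unauthorized
missing = {}                                            # rejected: Unauthorized

resp = requests.get("http://localhost:5001/v1/parameters", headers=good)  # assumed host
print(resp.status_code)
```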
diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py
new file mode 100644
index 0000000000..cb64a3b158
--- /dev/null
+++ b/api/controllers/service_api/wraps.py
@@ -0,0 +1,95 @@
+# -*- coding:utf-8 -*-
+from datetime import datetime
+from functools import wraps
+
+from flask import request
+from flask_restful import Resource
+from werkzeug.exceptions import NotFound, Unauthorized
+
+from extensions.ext_database import db
+from models.dataset import Dataset
+from models.model import ApiToken, App
+
+
+def validate_app_token(view=None):
+    def decorator(view):
+        @wraps(view)
+        def decorated(*args, **kwargs):
+            api_token = validate_and_get_api_token('app')
+
+            app_model = db.session.query(App).get(api_token.app_id)
+            if not app_model:
+                raise NotFound()
+
+            if app_model.status != 'normal':
+                raise NotFound()
+
+            if not app_model.enable_api:
+                raise NotFound()
+
+            return view(app_model, None, *args, **kwargs)
+        return decorated
+
+    if view:
+        return decorator(view)
+
+    # if view is None, it means that the decorator is used without parentheses
+    # use the decorator as a function for method_decorators
+    return decorator
+
+
+def validate_dataset_token(view=None):
+    def decorator(view):
+        @wraps(view)
+        def decorated(*args, **kwargs):
+            api_token = validate_and_get_api_token('dataset')
+
+            dataset = db.session.query(Dataset).get(api_token.dataset_id)
+            if not dataset:
+                raise NotFound()
+
+            return view(dataset, *args, **kwargs)
+        return decorated
+
+    if view:
+        return decorator(view)
+
+    # if view is None, it means that the decorator is used without parentheses
+    # use the decorator as a function for method_decorators
+    return decorator
+
+
+def validate_and_get_api_token(scope=None):
+    """
+    Validate and get API token.
+    """
+    auth_header = request.headers.get('Authorization')
+    if auth_header is None:
+        raise Unauthorized()
+
+    parts = auth_header.split(None, 1)
+    if len(parts) != 2:
+        # a header without a token part would otherwise raise ValueError below
+        raise Unauthorized()
+
+    auth_scheme, auth_token = parts
+    auth_scheme = auth_scheme.lower()
+
+    if auth_scheme != 'bearer':
+        raise Unauthorized()
+
+    api_token = db.session.query(ApiToken).filter(
+        ApiToken.token == auth_token,
+        ApiToken.type == scope,
+    ).first()
+
+    if not api_token:
+        raise Unauthorized()
+
+    api_token.last_used_at = datetime.utcnow()
+    db.session.commit()
+
+    return api_token
+
+
+class AppApiResource(Resource):
+    method_decorators = [validate_app_token]
+
+
+class DatasetApiResource(Resource):
+    method_decorators = [validate_dataset_token]

diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py
new file mode 100644
index 0000000000..b793f11014
--- /dev/null
+++ b/api/controllers/web/__init__.py
@@ -0,0 +1,10 @@
+# -*- coding:utf-8 -*-
+from flask import Blueprint
+
+from libs.external_api import ExternalApi
+
+bp = Blueprint('web', __name__, url_prefix='/api')
+api = ExternalApi(bp)
+
+
+from .
import completion, app, conversation, message, site, saved_message diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py new file mode 100644 index 0000000000..1396531111 --- /dev/null +++ b/api/controllers/web/app.py @@ -0,0 +1,42 @@ +# -*- coding:utf-8 -*- +from flask_restful import marshal_with, fields + +from controllers.web import api +from controllers.web.wraps import WebApiResource + + +class AppParameterApi(WebApiResource): + """Resource for app variables.""" + variable_fields = { + 'key': fields.String, + 'name': fields.String, + 'description': fields.String, + 'type': fields.String, + 'default': fields.String, + 'max_length': fields.Integer, + 'options': fields.List(fields.String) + } + + parameters_fields = { + 'opening_statement': fields.String, + 'suggested_questions': fields.Raw, + 'suggested_questions_after_answer': fields.Raw, + 'more_like_this': fields.Raw, + 'user_input_form': fields.Raw, + } + + @marshal_with(parameters_fields) + def get(self, app_model, end_user): + """Retrieve app parameters.""" + app_model_config = app_model.app_model_config + + return { + 'opening_statement': app_model_config.opening_statement, + 'suggested_questions': app_model_config.suggested_questions_list, + 'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict, + 'more_like_this': app_model_config.more_like_this_dict, + 'user_input_form': app_model_config.user_input_form_list + } + + +api.add_resource(AppParameterApi, '/parameters') diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py new file mode 100644 index 0000000000..532bcfaa8d --- /dev/null +++ b/api/controllers/web/completion.py @@ -0,0 +1,175 @@ +# -*- coding:utf-8 -*- +import json +import logging +from typing import Generator, Union + +from flask import Response, stream_with_context +from flask_restful import reqparse +from werkzeug.exceptions import InternalServerError, NotFound + +import services +from controllers.web import api +from controllers.web.error import AppUnavailableError, ConversationCompletedError, \ + ProviderNotInitializeError, NotChatAppError, NotCompletionAppError, CompletionRequestError, \ + ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError +from controllers.web.wraps import WebApiResource +from core.conversation_message_task import PubHandler +from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \ + LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from libs.helper import uuid_value +from services.completion_service import CompletionService + + +# define completion api for user +class CompletionApi(WebApiResource): + + def post(self, app_model, end_user): + if app_model.mode != 'completion': + raise NotCompletionAppError() + + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, required=True, location='json') + parser.add_argument('query', type=str, location='json') + parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') + args = parser.parse_args() + + streaming = args['response_mode'] == 'streaming' + + try: + response = CompletionService.completion( + app_model=app_model, + user=end_user, + args=args, + from_source='api', + streaming=streaming + ) + + return compact_response(response) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except 
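A minimal sketch of calling the completion endpoint in blocking mode. The web API authenticates with the app's `Site.code` as a bearer token (see `validate_and_get_site` later in this diff) and identifies the end user through a session cookie, so a `requests.Session` is used; host and port are assumptions:

```python
import requests

session = requests.Session()            # keeps the end-user session cookie
SITE_CODE = "your-site-code"            # Site.code of the published app (placeholder)

resp = session.post(
    "http://localhost:5001/api/completion-messages",
    headers={"Authorization": f"Bearer {SITE_CODE}"},
    json={
        "inputs": {},                   # required; keys match the app's user_input_form
        "query": "Write a tagline for a coffee shop.",
        "response_mode": "blocking",    # whole answer in a single JSON body
    },
)
print(resp.json())
```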
services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except services.errors.app_model_config.AppModelConfigBrokenError: + logging.exception("App model config broken.") + raise AppUnavailableError() + except ProviderTokenNotInitError: + raise ProviderNotInitializeError() + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, + LLMRateLimitError, LLMAuthorizationError) as e: + raise CompletionRequestError(str(e)) + except ValueError as e: + raise e + except Exception as e: + logging.exception("internal server error.") + raise InternalServerError() + + +class CompletionStopApi(WebApiResource): + def post(self, app_model, end_user, task_id): + if app_model.mode != 'completion': + raise NotCompletionAppError() + + PubHandler.stop(end_user, task_id) + + return {'result': 'success'}, 200 + + +class ChatApi(WebApiResource): + def post(self, app_model, end_user): + if app_model.mode != 'chat': + raise NotChatAppError() + + parser = reqparse.RequestParser() + parser.add_argument('inputs', type=dict, required=True, location='json') + parser.add_argument('query', type=str, required=True, location='json') + parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') + parser.add_argument('conversation_id', type=uuid_value, location='json') + args = parser.parse_args() + + streaming = args['response_mode'] == 'streaming' + + try: + response = CompletionService.completion( + app_model=app_model, + user=end_user, + args=args, + from_source='api', + streaming=streaming + ) + + return compact_response(response) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.conversation.ConversationCompletedError: + raise ConversationCompletedError() + except services.errors.app_model_config.AppModelConfigBrokenError: + logging.exception("App model config broken.") + raise AppUnavailableError() + except ProviderTokenNotInitError: + raise ProviderNotInitializeError() + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, + LLMRateLimitError, LLMAuthorizationError) as e: + raise CompletionRequestError(str(e)) + except ValueError as e: + raise e + except Exception as e: + logging.exception("internal server error.") + raise InternalServerError() + + +class ChatStopApi(WebApiResource): + def post(self, app_model, end_user, task_id): + if app_model.mode != 'chat': + raise NotChatAppError() + + PubHandler.stop(end_user, task_id) + + return {'result': 'success'}, 200 + + +def compact_response(response: Union[dict | Generator]) -> Response: + if isinstance(response, dict): + return Response(response=json.dumps(response), status=200, mimetype='application/json') + else: + def generate() -> Generator: + try: + for chunk in response: + yield chunk + except services.errors.conversation.ConversationNotExistsError: + yield "data: " + json.dumps(api.handle_error(NotFound("Conversation Not Exists.")).get_json()) + "\n\n" + except services.errors.conversation.ConversationCompletedError: + yield "data: " + json.dumps(api.handle_error(ConversationCompletedError()).get_json()) + "\n\n" + except 
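In streaming mode, `compact_response` returns a `text/event-stream` response in which each event is a single `data: {...}` line followed by a blank line. A sketch of a client consuming that stream, under the same host and token assumptions as above:

```python
import json
import requests

session = requests.Session()

with session.post(
    "http://localhost:5001/api/chat-messages",
    headers={"Authorization": "Bearer your-site-code"},
    json={"inputs": {}, "query": "Hello!", "response_mode": "streaming"},
    stream=True,
) as resp:
    for raw_line in resp.iter_lines(decode_unicode=True):
        # SSE framing: payload lines start with "data: ".
        if raw_line and raw_line.startswith("data: "):
            print(json.loads(raw_line[len("data: "):]))
```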
services.errors.app_model_config.AppModelConfigBrokenError: + logging.exception("App model config broken.") + yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n" + except ProviderTokenNotInitError: + yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n" + except QuotaExceededError: + yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" + except ModelCurrentlyNotSupportError: + yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n" + except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, + LLMRateLimitError, LLMAuthorizationError) as e: + yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n" + except ValueError as e: + yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n" + except Exception: + logging.exception("internal server error.") + yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n" + + return Response(stream_with_context(generate()), status=200, + mimetype='text/event-stream') + + +api.add_resource(CompletionApi, '/completion-messages') +api.add_resource(CompletionStopApi, '/completion-messages//stop') +api.add_resource(ChatApi, '/chat-messages') +api.add_resource(ChatStopApi, '/chat-messages//stop') diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py new file mode 100644 index 0000000000..53ba382051 --- /dev/null +++ b/api/controllers/web/conversation.py @@ -0,0 +1,121 @@ +# -*- coding:utf-8 -*- +from flask_restful import fields, reqparse, marshal_with +from flask_restful.inputs import int_range +from werkzeug.exceptions import NotFound + +from controllers.web import api +from controllers.web.error import NotChatAppError +from controllers.web.wraps import WebApiResource +from libs.helper import TimestampField, uuid_value +from services.conversation_service import ConversationService +from services.errors.conversation import LastConversationNotExistsError, ConversationNotExistsError +from services.web_conversation_service import WebConversationService + +conversation_fields = { + 'id': fields.String, + 'name': fields.String, + 'inputs': fields.Raw, + 'status': fields.String, + 'introduction': fields.String, + 'created_at': TimestampField +} + +conversation_infinite_scroll_pagination_fields = { + 'limit': fields.Integer, + 'has_more': fields.Boolean, + 'data': fields.List(fields.Nested(conversation_fields)) +} + + +class ConversationListApi(WebApiResource): + + @marshal_with(conversation_infinite_scroll_pagination_fields) + def get(self, app_model, end_user): + if app_model.mode != 'chat': + raise NotChatAppError() + + parser = reqparse.RequestParser() + parser.add_argument('last_id', type=uuid_value, location='args') + parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args') + args = parser.parse_args() + + pinned = None + if 'pinned' in args and args['pinned'] is not None: + pinned = True if args['pinned'] == 'true' else False + + try: + return WebConversationService.pagination_by_last_id( + app_model=app_model, + end_user=end_user, + last_id=args['last_id'], + limit=args['limit'], + pinned=pinned + ) + except LastConversationNotExistsError: + raise NotFound("Last Conversation Not Exists.") + + +class 
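`ConversationListApi` uses cursor pagination: pass the id of the last item you received as `last_id` and keep fetching while `has_more` is true. A sketch of walking all pages (host and token are placeholders):

```python
import requests

session = requests.Session()
headers = {"Authorization": "Bearer your-site-code"}
params = {"limit": 20}
conversations = []

while True:
    page = session.get("http://localhost:5001/api/conversations",
                       headers=headers, params=params).json()
    conversations.extend(page["data"])
    if not page["has_more"]:
        break
    params["last_id"] = page["data"][-1]["id"]   # cursor for the next page

print(f"fetched {len(conversations)} conversations")
```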
ConversationApi(WebApiResource): + def delete(self, app_model, end_user, c_id): + if app_model.mode != 'chat': + raise NotChatAppError() + + conversation_id = str(c_id) + ConversationService.delete(app_model, conversation_id, end_user) + WebConversationService.unpin(app_model, conversation_id, end_user) + + return {"result": "success"}, 204 + + +class ConversationRenameApi(WebApiResource): + + @marshal_with(conversation_fields) + def post(self, app_model, end_user, c_id): + if app_model.mode != 'chat': + raise NotChatAppError() + + conversation_id = str(c_id) + + parser = reqparse.RequestParser() + parser.add_argument('name', type=str, required=True, location='json') + args = parser.parse_args() + + try: + return ConversationService.rename(app_model, conversation_id, end_user, args['name']) + except ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + + +class ConversationPinApi(WebApiResource): + + def patch(self, app_model, end_user, c_id): + if app_model.mode != 'chat': + raise NotChatAppError() + + conversation_id = str(c_id) + + try: + WebConversationService.pin(app_model, conversation_id, end_user) + except ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + + return {"result": "success"} + + +class ConversationUnPinApi(WebApiResource): + def patch(self, app_model, end_user, c_id): + if app_model.mode != 'chat': + raise NotChatAppError() + + conversation_id = str(c_id) + WebConversationService.unpin(app_model, conversation_id, end_user) + + return {"result": "success"} + + +api.add_resource(ConversationRenameApi, '/conversations//name', endpoint='web_conversation_name') +api.add_resource(ConversationListApi, '/conversations') +api.add_resource(ConversationApi, '/conversations/') +api.add_resource(ConversationPinApi, '/conversations//pin') +api.add_resource(ConversationUnPinApi, '/conversations//unpin') diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py new file mode 100644 index 0000000000..ea72422a1b --- /dev/null +++ b/api/controllers/web/error.py @@ -0,0 +1,62 @@ +# -*- coding:utf-8 -*- +from libs.exception import BaseHTTPException + + +class AppUnavailableError(BaseHTTPException): + error_code = 'app_unavailable' + description = "App unavailable." + code = 400 + + +class NotCompletionAppError(BaseHTTPException): + error_code = 'not_completion_app' + description = "Not Completion App" + code = 400 + + +class NotChatAppError(BaseHTTPException): + error_code = 'not_chat_app' + description = "Not Chat App" + code = 400 + + +class ConversationCompletedError(BaseHTTPException): + error_code = 'conversation_completed' + description = "Conversation Completed." + code = 400 + + +class ProviderNotInitializeError(BaseHTTPException): + error_code = 'provider_not_initialize' + description = "Provider Token not initialize." + code = 400 + + +class ProviderQuotaExceededError(BaseHTTPException): + error_code = 'provider_quota_exceeded' + description = "Provider quota exceeded." + code = 400 + + +class ProviderModelCurrentlyNotSupportError(BaseHTTPException): + error_code = 'model_currently_not_support' + description = "GPT-4 currently not support." + code = 400 + + +class CompletionRequestError(BaseHTTPException): + error_code = 'completion_request_error' + description = "Completion request failed." + code = 400 + + +class AppMoreLikeThisDisabledError(BaseHTTPException): + error_code = 'app_more_like_this_disabled' + description = "More like this disabled." 
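Rename returns the updated conversation marshalled with `conversation_fields`, while pin and unpin are `PATCH` requests that answer `{"result": "success"}`. A short sketch (id, host, and token are placeholders):

```python
import requests

session = requests.Session()
headers = {"Authorization": "Bearer your-site-code"}
conversation_id = "a1b2c3d4-0000-4000-8000-000000000000"   # placeholder

# Rename: the response body is the updated conversation object.
resp = session.post(
    f"http://localhost:5001/api/conversations/{conversation_id}/name",
    headers=headers,
    json={"name": "Trip planning"},
)
print(resp.json()["name"])

# Pin the conversation for the current end user.
session.patch(
    f"http://localhost:5001/api/conversations/{conversation_id}/pin",
    headers=headers,
)
```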
+ code = 403 + + +class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException): + error_code = 'app_suggested_questions_after_answer_disabled' + description = "Function Suggested questions after answer disabled." + code = 403 diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py new file mode 100644 index 0000000000..0d519eac06 --- /dev/null +++ b/api/controllers/web/message.py @@ -0,0 +1,189 @@ +# -*- coding:utf-8 -*- +import json +import logging +from typing import Generator, Union + +from flask import stream_with_context, Response +from flask_restful import reqparse, fields, marshal_with +from flask_restful.inputs import int_range +from werkzeug.exceptions import NotFound, InternalServerError + +import services +from controllers.web import api +from controllers.web.error import NotChatAppError, CompletionRequestError, ProviderNotInitializeError, \ + AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \ + ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError +from controllers.web.wraps import WebApiResource +from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \ + ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError +from libs.helper import uuid_value, TimestampField +from services.completion_service import CompletionService +from services.errors.app import MoreLikeThisDisabledError +from services.errors.conversation import ConversationNotExistsError +from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError +from services.message_service import MessageService + + +class MessageListApi(WebApiResource): + feedback_fields = { + 'rating': fields.String + } + + message_fields = { + 'id': fields.String, + 'conversation_id': fields.String, + 'inputs': fields.Raw, + 'query': fields.String, + 'answer': fields.String, + 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), + 'created_at': TimestampField + } + + message_infinite_scroll_pagination_fields = { + 'limit': fields.Integer, + 'has_more': fields.Boolean, + 'data': fields.List(fields.Nested(message_fields)) + } + + @marshal_with(message_infinite_scroll_pagination_fields) + def get(self, app_model, end_user): + if app_model.mode != 'chat': + raise NotChatAppError() + + parser = reqparse.RequestParser() + parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') + parser.add_argument('first_id', type=uuid_value, location='args') + parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + args = parser.parse_args() + + try: + return MessageService.pagination_by_first_id(app_model, end_user, + args['conversation_id'], args['first_id'], args['limit']) + except services.errors.conversation.ConversationNotExistsError: + raise NotFound("Conversation Not Exists.") + except services.errors.message.FirstMessageNotExistsError: + raise NotFound("First Message Not Exists.") + + +class MessageFeedbackApi(WebApiResource): + def post(self, app_model, end_user, message_id): + message_id = str(message_id) + + parser = reqparse.RequestParser() + parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + args = parser.parse_args() + + try: + MessageService.create_feedback(app_model, message_id, end_user, args['rating']) + except services.errors.message.MessageNotExistsError: + raise 
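`MessageFeedbackApi` accepts `like`, `dislike`, or `null` (to clear previously given feedback). A minimal sketch, with the usual placeholder host and token:

```python
import requests

session = requests.Session()
headers = {"Authorization": "Bearer your-site-code"}
message_id = "a1b2c3d4-0000-4000-8000-000000000000"   # placeholder

resp = session.post(
    f"http://localhost:5001/api/messages/{message_id}/feedbacks",
    headers=headers,
    json={"rating": "like"},   # "like" | "dislike" | None
)
assert resp.json() == {"result": "success"}
```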
NotFound("Message Not Exists.") + + return {'result': 'success'} + + +class MessageMoreLikeThisApi(WebApiResource): + def get(self, app_model, end_user, message_id): + if app_model.mode != 'completion': + raise NotCompletionAppError() + + message_id = str(message_id) + + parser = reqparse.RequestParser() + parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') + args = parser.parse_args() + + streaming = args['response_mode'] == 'streaming' + + try: + response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming) + return compact_response(response) + except MessageNotExistsError: + raise NotFound("Message Not Exists.") + except MoreLikeThisDisabledError: + raise AppMoreLikeThisDisabledError() + except ProviderTokenNotInitError: + raise ProviderNotInitializeError() + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, + LLMRateLimitError, LLMAuthorizationError) as e: + raise CompletionRequestError(str(e)) + except ValueError as e: + raise e + except Exception: + logging.exception("internal server error.") + raise InternalServerError() + + +def compact_response(response: Union[dict | Generator]) -> Response: + if isinstance(response, dict): + return Response(response=json.dumps(response), status=200, mimetype='application/json') + else: + def generate() -> Generator: + try: + for chunk in response: + yield chunk + except MessageNotExistsError: + yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n" + except MoreLikeThisDisabledError: + yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n" + except ProviderTokenNotInitError: + yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n" + except QuotaExceededError: + yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" + except ModelCurrentlyNotSupportError: + yield "data: " + json.dumps(api.handle_error(ProviderModelCurrentlyNotSupportError()).get_json()) + "\n\n" + except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, + LLMRateLimitError, LLMAuthorizationError) as e: + yield "data: " + json.dumps(api.handle_error(CompletionRequestError(str(e))).get_json()) + "\n\n" + except ValueError as e: + yield "data: " + json.dumps(api.handle_error(e).get_json()) + "\n\n" + except Exception: + logging.exception("internal server error.") + yield "data: " + json.dumps(api.handle_error(InternalServerError()).get_json()) + "\n\n" + + return Response(stream_with_context(generate()), status=200, + mimetype='text/event-stream') + + +class MessageSuggestedQuestionApi(WebApiResource): + def get(self, app_model, end_user, message_id): + if app_model.mode != 'chat': + raise NotCompletionAppError() + + message_id = str(message_id) + + try: + questions = MessageService.get_suggested_questions_after_answer( + app_model=app_model, + user=end_user, + message_id=message_id + ) + except MessageNotExistsError: + raise NotFound("Message not found") + except ConversationNotExistsError: + raise NotFound("Conversation not found") + except SuggestedQuestionsAfterAnswerDisabledError: + raise AppSuggestedQuestionsAfterAnswerDisabledError() + except ProviderTokenNotInitError: + raise ProviderNotInitializeError() + except 
QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, + LLMRateLimitError, LLMAuthorizationError) as e: + raise CompletionRequestError(str(e)) + except Exception: + logging.exception("internal server error.") + raise InternalServerError() + + return {'data': questions} + + +api.add_resource(MessageListApi, '/messages') +api.add_resource(MessageFeedbackApi, '/messages//feedbacks') +api.add_resource(MessageMoreLikeThisApi, '/messages//more-like-this') +api.add_resource(MessageSuggestedQuestionApi, '/messages//suggested-questions') diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py new file mode 100644 index 0000000000..7f6f4249c9 --- /dev/null +++ b/api/controllers/web/saved_message.py @@ -0,0 +1,74 @@ +from flask_restful import reqparse, marshal_with, fields +from flask_restful.inputs import int_range +from werkzeug.exceptions import NotFound + +from controllers.web import api +from controllers.web.error import NotCompletionAppError +from controllers.web.wraps import WebApiResource +from libs.helper import uuid_value, TimestampField +from services.errors.message import MessageNotExistsError +from services.saved_message_service import SavedMessageService + +feedback_fields = { + 'rating': fields.String +} + +message_fields = { + 'id': fields.String, + 'inputs': fields.Raw, + 'query': fields.String, + 'answer': fields.String, + 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), + 'created_at': TimestampField +} + + +class SavedMessageListApi(WebApiResource): + saved_message_infinite_scroll_pagination_fields = { + 'limit': fields.Integer, + 'has_more': fields.Boolean, + 'data': fields.List(fields.Nested(message_fields)) + } + + @marshal_with(saved_message_infinite_scroll_pagination_fields) + def get(self, app_model, end_user): + if app_model.mode != 'completion': + raise NotCompletionAppError() + + parser = reqparse.RequestParser() + parser.add_argument('last_id', type=uuid_value, location='args') + parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + args = parser.parse_args() + + return SavedMessageService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit']) + + def post(self, app_model, end_user): + if app_model.mode != 'completion': + raise NotCompletionAppError() + + parser = reqparse.RequestParser() + parser.add_argument('message_id', type=uuid_value, required=True, location='json') + args = parser.parse_args() + + try: + SavedMessageService.save(app_model, end_user, args['message_id']) + except MessageNotExistsError: + raise NotFound("Message Not Exists.") + + return {'result': 'success'} + + +class SavedMessageApi(WebApiResource): + def delete(self, app_model, end_user, message_id): + message_id = str(message_id) + + if app_model.mode != 'completion': + raise NotCompletionAppError() + + SavedMessageService.delete(app_model, end_user, message_id) + + return {'result': 'success'} + + +api.add_resource(SavedMessageListApi, '/saved-messages') +api.add_resource(SavedMessageApi, '/saved-messages/') diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py new file mode 100644 index 0000000000..de7a38b6df --- /dev/null +++ b/api/controllers/web/site.py @@ -0,0 +1,73 @@ +# -*- coding:utf-8 -*- +from flask_restful import fields, marshal_with +from werkzeug.exceptions import 
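Saved messages are a completion-app feature: save by `message_id`, list with the same `last_id`/`limit` cursor used elsewhere, delete by id. A short sketch (placeholders as before):

```python
import requests

session = requests.Session()
headers = {"Authorization": "Bearer your-site-code"}
base = "http://localhost:5001/api"
message_id = "a1b2c3d4-0000-4000-8000-000000000000"   # placeholder

# Save a completion message for later.
session.post(f"{base}/saved-messages", headers=headers,
             json={"message_id": message_id})

# Page through saved messages.
page = session.get(f"{base}/saved-messages", headers=headers,
                   params={"limit": 20}).json()
for message in page["data"]:
    print(message["id"], message["query"])
```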
Forbidden + +from controllers.web import api +from controllers.web.wraps import WebApiResource +from extensions.ext_database import db +from models.model import Site + + +class AppSiteApi(WebApiResource): + """Resource for app sites.""" + + model_config_fields = { + 'opening_statement': fields.String, + 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), + 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), + 'more_like_this': fields.Raw(attribute='more_like_this_dict'), + 'model': fields.Raw(attribute='model_dict'), + 'user_input_form': fields.Raw(attribute='user_input_form_list'), + 'pre_prompt': fields.String, + } + + site_fields = { + 'title': fields.String, + 'icon': fields.String, + 'icon_background': fields.String, + 'description': fields.String, + 'copyright': fields.String, + 'privacy_policy': fields.String, + 'default_language': fields.String, + 'prompt_public': fields.Boolean + } + + app_fields = { + 'app_id': fields.String, + 'end_user_id': fields.String, + 'enable_site': fields.Boolean, + 'site': fields.Nested(site_fields), + 'model_config': fields.Nested(model_config_fields, allow_null=True), + 'plan': fields.String, + } + + @marshal_with(app_fields) + def get(self, app_model, end_user): + """Retrieve app site info.""" + # get site + site = db.session.query(Site).filter(Site.app_id == app_model.id).first() + + if not site: + raise Forbidden() + + return AppSiteInfo(app_model.tenant, app_model, site, end_user.id) + + +api.add_resource(AppSiteApi, '/site') + + +class AppSiteInfo: + """Class to store site information.""" + + def __init__(self, tenant, app, site, end_user): + """Initialize AppSiteInfo instance.""" + self.app_id = app.id + self.end_user_id = end_user + self.enable_site = app.enable_site + self.site = site + self.model_config = None + self.plan = tenant.plan + + if app.enable_site and site.prompt_public: + app_model_config = app.app_model_config + self.model_config = app_model_config diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py new file mode 100644 index 0000000000..d227a9659e --- /dev/null +++ b/api/controllers/web/wraps.py @@ -0,0 +1,107 @@ +# -*- coding:utf-8 -*- +import uuid +from functools import wraps + +from flask import request, session +from flask_restful import Resource +from werkzeug.exceptions import NotFound, Unauthorized + +from extensions.ext_database import db +from models.model import App, Site, EndUser + + +def validate_token(view=None): + def decorator(view): + @wraps(view) + def decorated(*args, **kwargs): + site = validate_and_get_site() + + app_model = db.session.query(App).get(site.app_id) + if not app_model: + raise NotFound() + + if app_model.status != 'normal': + raise NotFound() + + if not app_model.enable_site: + raise NotFound() + + end_user = create_or_update_end_user_for_session(app_model) + + return view(app_model, end_user, *args, **kwargs) + return decorated + + if view: + return decorator(view) + return decorator + + +def validate_and_get_site(): + """ + Validate and get API token. 
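Note how `model_config_fields` uses `attribute=` so the marshalled key can differ from the attribute name on the source object. A self-contained illustration of that mapping:

```python
from flask_restful import fields, marshal


class FakeConfig:
    # Source attribute name differs from the key exposed in the response.
    model_dict = {"provider": "openai", "name": "gpt-3.5-turbo"}


print(dict(marshal(FakeConfig(), {"model": fields.Raw(attribute="model_dict")})))
# {'model': {'provider': 'openai', 'name': 'gpt-3.5-turbo'}}
```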
+ """ + auth_header = request.headers.get('Authorization') + if auth_header is None: + raise Unauthorized() + + auth_scheme, auth_token = auth_header.split(None, 1) + auth_scheme = auth_scheme.lower() + + if auth_scheme != 'bearer': + raise Unauthorized() + + site = db.session.query(Site).filter( + Site.code == auth_token, + Site.status == 'normal' + ).first() + + if not site: + raise NotFound() + + return site + + +def create_or_update_end_user_for_session(app_model): + """ + Create or update session terminal based on session ID. + """ + if 'session_id' not in session: + session['session_id'] = generate_session_id() + + session_id = session.get('session_id') + end_user = db.session.query(EndUser) \ + .filter( + EndUser.session_id == session_id, + EndUser.type == 'browser' + ).first() + + if end_user is None: + end_user = EndUser( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type='browser', + is_anonymous=True, + session_id=session_id + ) + db.session.add(end_user) + db.session.commit() + + return end_user + + +def generate_session_id(): + """ + Generate a unique session ID. + """ + count = 1 + session_id = '' + while count != 0: + session_id = str(uuid.uuid4()) + count = db.session.query(EndUser) \ + .filter(EndUser.session_id == session_id).count() + + return session_id + + +class WebApiResource(Resource): + method_decorators = [validate_token] diff --git a/api/core/__init__.py b/api/core/__init__.py new file mode 100644 index 0000000000..f6257d8b36 --- /dev/null +++ b/api/core/__init__.py @@ -0,0 +1,52 @@ +import os +from typing import Optional + +import langchain +from flask import Flask +from jieba.analyse import default_tfidf +from langchain import set_handler +from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING +from llama_index import IndexStructType, QueryMode +from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP +from pydantic import BaseModel + +from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler +from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex +from core.index.keyword_table.stopwords import STOPWORDS +from core.prompt.prompt_template import OneLineFormatter +from core.vector_store.vector_store import VectorStore +from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery + + +class HostedOpenAICredential(BaseModel): + api_key: str + + +class HostedLLMCredentials(BaseModel): + openai: Optional[HostedOpenAICredential] = None + + +hosted_llm_credentials = HostedLLMCredentials() + + +def init_app(app: Flask): + formatter = OneLineFormatter() + DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format + INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map() + INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = { + QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery, + QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery, + } + INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = { + QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery, + QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery, + } + + default_tfidf.stop_words = STOPWORDS + + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + langchain.verbose = True + set_handler(DifyStdOutCallbackHandler()) + + if app.config.get("OPENAI_API_KEY"): + hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY")) diff --git a/api/core/agent/agent_builder.py b/api/core/agent/agent_builder.py 
new file mode 100644 index 0000000000..b1d6948467 --- /dev/null +++ b/api/core/agent/agent_builder.py @@ -0,0 +1,89 @@ +from typing import Optional + +from langchain import LLMChain +from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent +from langchain.callbacks import CallbackManager +from langchain.memory.chat_memory import BaseChatMemory + +from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler +from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler +from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler +from core.llm.llm_builder import LLMBuilder + + +class AgentBuilder: + @classmethod + def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory], + dataset_tool_callback_handler: DatasetToolCallbackHandler, + agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler): + llm_callback_manager = CallbackManager([agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]) + llm = LLMBuilder.to_llm( + tenant_id=tenant_id, + model_name=agent_loop_gather_callback_handler.model_name, + temperature=0, + max_tokens=1024, + callback_manager=llm_callback_manager + ) + + tool_callback_manager = CallbackManager([ + agent_loop_gather_callback_handler, + dataset_tool_callback_handler, + DifyStdOutCallbackHandler() + ]) + + for tool in tools: + tool.callback_manager = tool_callback_manager + + prompt = cls.build_agent_prompt_template( + tools=tools, + memory=memory, + ) + + agent_llm_chain = LLMChain( + llm=llm, + prompt=prompt, + ) + + agent = cls.build_agent(agent_llm_chain=agent_llm_chain, memory=memory) + + agent_callback_manager = CallbackManager( + [agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()] + ) + + agent_chain = AgentExecutor.from_agent_and_tools( + tools=tools, + agent=agent, + memory=memory, + callback_manager=agent_callback_manager, + max_iterations=6, + early_stopping_method="generate", + # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit + ) + + return agent_chain + + @classmethod + def build_agent_prompt_template(cls, tools, memory: Optional[BaseChatMemory]): + if memory: + prompt = ConversationalAgent.create_prompt( + tools=tools, + ) + else: + prompt = ZeroShotAgent.create_prompt( + tools=tools, + ) + + return prompt + + @classmethod + def build_agent(cls, agent_llm_chain: LLMChain, memory: Optional[BaseChatMemory]): + if memory: + agent = ConversationalAgent( + llm_chain=agent_llm_chain + ) + else: + agent = ZeroShotAgent( + llm_chain=agent_llm_chain + ) + + return agent diff --git a/api/core/callback_handler/agent_loop_gather_callback_handler.py b/api/core/callback_handler/agent_loop_gather_callback_handler.py new file mode 100644 index 0000000000..f37411cacc --- /dev/null +++ b/api/core/callback_handler/agent_loop_gather_callback_handler.py @@ -0,0 +1,178 @@ +import logging +import time + +from typing import Any, Dict, List, Union, Optional + +from langchain.callbacks.base import BaseCallbackHandler +from langchain.schema import AgentAction, AgentFinish, LLMResult + +from core.callback_handler.entity.agent_loop import AgentLoop +from core.conversation_message_task import ConversationMessageTask + + +class AgentLoopGatherCallbackHandler(BaseCallbackHandler): + """Callback Handler that prints to std out.""" + + def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None: + """Initialize callback 
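For orientation, an illustrative call into `AgentBuilder.to_agent_chain`. In this codebase the call is made from `MainChainBuilder` with per-request callback handlers, so every right-hand side below is a placeholder, not a working default:

```python
# Illustrative wiring only; all right-hand sides are placeholders.
from core.agent.agent_builder import AgentBuilder

agent_chain = AgentBuilder.to_agent_chain(
    tenant_id=tenant_id,                              # tenant owning provider credentials
    tools=agent_tools,                                # e.g. dataset tools built elsewhere
    memory=memory,                                    # BaseChatMemory, or None for one-shot
    dataset_tool_callback_handler=dataset_handler,
    agent_loop_gather_callback_handler=loop_handler,
)
answer = agent_chain.run(input="What does the Q1 report say about revenue?")
```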
handler.""" + self.model_name = model_name + self.conversation_message_task = conversation_message_task + self._agent_loops = [] + self._current_loop = None + self.current_chain = None + + @property + def agent_loops(self) -> List[AgentLoop]: + return self._agent_loops + + def clear_agent_loops(self) -> None: + self._agent_loops = [] + self._current_loop = None + + @property + def always_verbose(self) -> bool: + """Whether to call verbose callbacks even if verbose is False.""" + return True + + @property + def ignore_chain(self) -> bool: + """Whether to ignore chain callbacks.""" + return True + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Print out the prompts.""" + # serialized={'name': 'OpenAI'} + # prompts=['Answer the following questions...\nThought:'] + # kwargs={} + if not self._current_loop: + # Agent start with a LLM query + self._current_loop = AgentLoop( + position=len(self._agent_loops) + 1, + prompt=prompts[0], + status='llm_started', + started_at=time.perf_counter() + ) + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Do nothing.""" + # kwargs={} + if self._current_loop and self._current_loop.status == 'llm_started': + self._current_loop.status = 'llm_end' + self._current_loop.prompt_tokens = response.llm_output['token_usage']['prompt_tokens'] + self._current_loop.completion = response.generations[0][0].text + self._current_loop.completion_tokens = response.llm_output['token_usage']['completion_tokens'] + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + logging.error(error) + self._agent_loops = [] + self._current_loop = None + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Print out that we are entering a chain.""" + pass + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Print out that we finished a chain.""" + pass + + def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + logging.error(error) + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + **kwargs: Any, + ) -> None: + """Do nothing.""" + # kwargs={'color': 'green', 'llm_prefix': 'Thought:', 'observation_prefix': 'Observation: '} + # input_str='action-input' + # serialized={'description': 'A search engine. Useful for when you need to answer questions about current events. 
Input should be a search query.', 'name': 'Search'} + pass + + def on_agent_action( + self, action: AgentAction, color: Optional[str] = None, **kwargs: Any + ) -> Any: + """Run on agent action.""" + tool = action.tool + tool_input = action.tool_input + action_name_position = action.log.index("\nAction:") + 1 if action.log else -1 + thought = action.log[:action_name_position].strip() if action.log else '' + + if self._current_loop and self._current_loop.status == 'llm_end': + self._current_loop.status = 'agent_action' + self._current_loop.thought = thought + self._current_loop.tool_name = tool + self._current_loop.tool_input = tool_input + + def on_tool_end( + self, + output: str, + color: Optional[str] = None, + observation_prefix: Optional[str] = None, + llm_prefix: Optional[str] = None, + **kwargs: Any, + ) -> None: + """If not the final action, print out observation.""" + # kwargs={'name': 'Search'} + # llm_prefix='Thought:' + # observation_prefix='Observation: ' + # output='53 years' + + if self._current_loop and self._current_loop.status == 'agent_action' and output and output != 'None': + self._current_loop.status = 'tool_end' + self._current_loop.tool_output = output + self._current_loop.completed = True + self._current_loop.completed_at = time.perf_counter() + self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at + + self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop) + + self._agent_loops.append(self._current_loop) + self._current_loop = None + + def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Do nothing.""" + logging.error(error) + self._agent_loops = [] + self._current_loop = None + + def on_text( + self, + text: str, + color: Optional[str] = None, + end: str = "", + **kwargs: Optional[str], + ) -> None: + """Run on additional input from chains and agents.""" + pass + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: + """Run on agent end.""" + # Final Answer + if self._current_loop and (self._current_loop.status == 'llm_end' or self._current_loop.status == 'agent_action'): + self._current_loop.status = 'agent_finish' + self._current_loop.completed = True + self._current_loop.completed_at = time.perf_counter() + self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at + + self.conversation_message_task.on_agent_end(self.current_chain, self.model_name, self._current_loop) + + self._agent_loops.append(self._current_loop) + self._current_loop = None + elif not self._current_loop and self._agent_loops: + self._agent_loops[-1].status = 'agent_finish' diff --git a/api/core/callback_handler/dataset_tool_callback_handler.py b/api/core/callback_handler/dataset_tool_callback_handler.py new file mode 100644 index 0000000000..e3fce66511 --- /dev/null +++ b/api/core/callback_handler/dataset_tool_callback_handler.py @@ -0,0 +1,117 @@ +import logging + +from typing import Any, Dict, List, Union, Optional + +from langchain.callbacks.base import BaseCallbackHandler +from langchain.schema import AgentAction, AgentFinish, LLMResult + +from core.callback_handler.entity.dataset_query import DatasetQueryObj +from core.conversation_message_task import ConversationMessageTask + + +class DatasetToolCallbackHandler(BaseCallbackHandler): + """Callback Handler that prints to std out.""" + + def __init__(self, conversation_message_task: ConversationMessageTask) -> None: + """Initialize callback handler.""" + 
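`on_agent_action` recovers the model's thought by slicing `action.log` at the `"\nAction:"` marker. A self-contained reproduction of that parsing on a typical ReAct-style log:

```python
# Toy reproduction of the log slicing in on_agent_action.
log = ("I should look this up.\n"
       "Action: Search\n"
       "Action Input: dify github")

action_name_position = log.index("\nAction:") + 1 if log else -1
thought = log[:action_name_position].strip() if log else ''

print(repr(thought))  # 'I should look this up.'
```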
self.queries = [] + self.conversation_message_task = conversation_message_task + + @property + def always_verbose(self) -> bool: + """Whether to call verbose callbacks even if verbose is False.""" + return True + + @property + def ignore_llm(self) -> bool: + """Whether to ignore LLM callbacks.""" + return True + + @property + def ignore_chain(self) -> bool: + """Whether to ignore chain callbacks.""" + return True + + @property + def ignore_agent(self) -> bool: + """Whether to ignore agent callbacks.""" + return False + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + **kwargs: Any, + ) -> None: + tool_name = serialized.get('name') + dataset_id = tool_name[len("dataset-"):] + self.conversation_message_task.on_dataset_query_end(DatasetQueryObj(dataset_id=dataset_id, query=input_str)) + + def on_tool_end( + self, + output: str, + color: Optional[str] = None, + observation_prefix: Optional[str] = None, + llm_prefix: Optional[str] = None, + **kwargs: Any, + ) -> None: + # kwargs={'name': 'Search'} + # llm_prefix='Thought:' + # observation_prefix='Observation: ' + # output='53 years' + pass + + def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Do nothing.""" + logging.error(error) + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + pass + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + pass + + def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + pass + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + pass + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + pass + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + logging.error(error) + + def on_agent_action( + self, action: AgentAction, color: Optional[str] = None, **kwargs: Any + ) -> Any: + pass + + def on_text( + self, + text: str, + color: Optional[str] = None, + end: str = "", + **kwargs: Optional[str], + ) -> None: + """Run on additional input from chains and agents.""" + pass + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: + """Run on agent end.""" + pass diff --git a/api/core/callback_handler/entity/agent_loop.py b/api/core/callback_handler/entity/agent_loop.py new file mode 100644 index 0000000000..13ed4caa7f --- /dev/null +++ b/api/core/callback_handler/entity/agent_loop.py @@ -0,0 +1,23 @@ +from pydantic import BaseModel + + +class AgentLoop(BaseModel): + position: int = 1 + + thought: str = None + tool_name: str = None + tool_input: str = None + tool_output: str = None + + prompt: str = None + prompt_tokens: int = None + completion: str = None + completion_tokens: int = None + + latency: float = None + + status: str = 'llm_started' + completed: bool = False + + started_at: float = None + completed_at: float = None \ No newline at end of file diff --git a/api/core/callback_handler/entity/chain_result.py b/api/core/callback_handler/entity/chain_result.py new file mode 100644 index 0000000000..596486cdb0 --- /dev/null +++ b/api/core/callback_handler/entity/chain_result.py @@ -0,0 +1,16 @@ +from pydantic import BaseModel + + +class ChainResult(BaseModel): + type: str = None + prompt: dict = None + completion: dict = None + + status: str = 'chain_started' + completed: bool = False 
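`on_tool_start` relies on the convention that dataset tools are named `dataset-<dataset_id>`, so stripping the prefix recovers the id to log the query against. Illustration with a made-up id:

```python
tool_name = "dataset-4d5e6f70-0000-4000-8000-000000000000"   # placeholder id
dataset_id = tool_name[len("dataset-"):]
print(dataset_id)  # 4d5e6f70-0000-4000-8000-000000000000
```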
+ + started_at: float = None + completed_at: float = None + + agent_result: dict = None + """only when type is 'AgentExecutor'""" diff --git a/api/core/callback_handler/entity/dataset_query.py b/api/core/callback_handler/entity/dataset_query.py new file mode 100644 index 0000000000..23705e55ac --- /dev/null +++ b/api/core/callback_handler/entity/dataset_query.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class DatasetQueryObj(BaseModel): + dataset_id: str = None + query: str = None diff --git a/api/core/callback_handler/entity/llm_message.py b/api/core/callback_handler/entity/llm_message.py new file mode 100644 index 0000000000..0f53295ae9 --- /dev/null +++ b/api/core/callback_handler/entity/llm_message.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + + +class LLMMessage(BaseModel): + prompt: str = '' + prompt_tokens: int = 0 + completion: str = '' + completion_tokens: int = 0 + latency: float = 0.0 diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py new file mode 100644 index 0000000000..f0c9379413 --- /dev/null +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -0,0 +1,38 @@ +from llama_index import Response + +from extensions.ext_database import db +from models.dataset import DocumentSegment + + +class IndexToolCallbackHandler: + + def __init__(self) -> None: + self._response = None + + @property + def response(self) -> Response: + return self._response + + def on_tool_end(self, response: Response) -> None: + """Handle tool end.""" + self._response = response + + +class DatasetIndexToolCallbackHandler(IndexToolCallbackHandler): + """Callback handler for dataset tool.""" + + def __init__(self, dataset_id: str) -> None: + super().__init__() + self.dataset_id = dataset_id + + def on_tool_end(self, response: Response) -> None: + """Handle tool end.""" + for node in response.source_nodes: + index_node_id = node.node.doc_id + + # add hit count to document segment + db.session.query(DocumentSegment).filter( + DocumentSegment.dataset_id == self.dataset_id, + DocumentSegment.index_node_id == index_node_id + ).update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) + diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py new file mode 100644 index 0000000000..b6f7ef2f54 --- /dev/null +++ b/api/core/callback_handler/llm_callback_handler.py @@ -0,0 +1,147 @@ +import logging +import time +from typing import Any, Dict, List, Union, Optional + +from langchain.callbacks.base import BaseCallbackHandler +from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage + +from core.callback_handler.entity.llm_message import LLMMessage +from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException +from core.llm.streamable_chat_open_ai import StreamableChatOpenAI +from core.llm.streamable_open_ai import StreamableOpenAI + + +class LLMCallbackHandler(BaseCallbackHandler): + + def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI], + conversation_message_task: ConversationMessageTask): + self.llm = llm + self.llm_message = LLMMessage() + self.start_at = None + self.conversation_message_task = conversation_message_task + + @property + def always_verbose(self) -> bool: + """Whether to call verbose callbacks even if verbose is False.""" + return True + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: 
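`DatasetIndexToolCallbackHandler.on_tool_end` bumps `hit_count` with a query-level `UPDATE` using a column expression, so the increment happens in SQL rather than in Python; `synchronize_session=False` skips reconciling objects already loaded in the session. A self-contained demonstration of the same pattern against in-memory SQLite:

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Segment(Base):
    __tablename__ = "segments"
    id = Column(Integer, primary_key=True)
    index_node_id = Column(String)
    hit_count = Column(Integer, default=0)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Segment(index_node_id="node-1", hit_count=0))
    session.commit()

    # Server-side increment, mirroring the handler's update() call.
    session.query(Segment) \
        .filter(Segment.index_node_id == "node-1") \
        .update({Segment.hit_count: Segment.hit_count + 1},
                synchronize_session=False)
    session.commit()

    print(session.query(Segment).one().hit_count)  # 1
```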
List[str], **kwargs: Any + ) -> None: + self.start_at = time.perf_counter() + + if 'Chat' in serialized['name']: + real_prompts = [] + messages = [] + for prompt in prompts: + role, content = prompt.split(': ', maxsplit=1) + if role == 'human': + role = 'user' + message = HumanMessage(content=content) + elif role == 'ai': + role = 'assistant' + message = AIMessage(content=content) + else: + message = SystemMessage(content=content) + + real_prompt = { + "role": role, + "text": content + } + real_prompts.append(real_prompt) + messages.append(message) + + self.llm_message.prompt = real_prompts + self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages) + else: + self.llm_message.prompt = [{ + "role": 'user', + "text": prompts[0] + }] + + self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0]) + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + end_at = time.perf_counter() + self.llm_message.latency = end_at - self.start_at + + if not self.conversation_message_task.streaming: + self.conversation_message_task.append_message_text(response.generations[0][0].text) + self.llm_message.completion = response.generations[0][0].text + self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens'] + else: + self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion) + + self.conversation_message_task.save_message(self.llm_message) + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + self.conversation_message_task.append_message_text(token) + self.llm_message.completion += token + + def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Do nothing.""" + if isinstance(error, ConversationTaskStoppedException): + if self.conversation_message_task.streaming: + end_at = time.perf_counter() + self.llm_message.latency = end_at - self.start_at + self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion) + self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True) + else: + logging.error(error) + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + pass + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + pass + + def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + pass + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + **kwargs: Any, + ) -> None: + pass + + def on_agent_action( + self, action: AgentAction, color: Optional[str] = None, **kwargs: Any + ) -> Any: + pass + + def on_tool_end( + self, + output: str, + color: Optional[str] = None, + observation_prefix: Optional[str] = None, + llm_prefix: Optional[str] = None, + **kwargs: Any, + ) -> None: + pass + + def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + pass + + def on_text( + self, + text: str, + color: Optional[str] = None, + end: str = "", + **kwargs: Optional[str], + ) -> None: + pass + + def on_agent_finish( + self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any + ) -> None: + pass diff --git a/api/core/callback_handler/main_chain_gather_callback_handler.py b/api/core/callback_handler/main_chain_gather_callback_handler.py new file mode 100644 index 0000000000..1bd41edd6c --- /dev/null +++ b/api/core/callback_handler/main_chain_gather_callback_handler.py @@ -0,0 +1,137 @@ +import logging +import time 
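For chat models, `on_llm_start` receives each prompt as a `"<role>: <content>"` string and maps the roles back to OpenAI-style names. A toy version of that parsing:

```python
prompts = ["human: What is Dify?"]

for prompt in prompts:
    role, content = prompt.split(': ', maxsplit=1)
    if role == 'human':
        role = 'user'
    elif role == 'ai':
        role = 'assistant'
    print({"role": role, "text": content})
# {'role': 'user', 'text': 'What is Dify?'}
```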
+ +from typing import Any, Dict, List, Union, Optional + +from langchain.callbacks.base import BaseCallbackHandler +from langchain.schema import AgentAction, AgentFinish, LLMResult + +from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler +from core.callback_handler.entity.chain_result import ChainResult +from core.constant import llm_constant +from core.conversation_message_task import ConversationMessageTask + + +class MainChainGatherCallbackHandler(BaseCallbackHandler): + """Callback Handler that prints to std out.""" + + def __init__(self, conversation_message_task: ConversationMessageTask) -> None: + """Initialize callback handler.""" + self._current_chain_result = None + self._current_chain_message = None + self.conversation_message_task = conversation_message_task + self.agent_loop_gather_callback_handler = AgentLoopGatherCallbackHandler( + llm_constant.agent_model_name, + conversation_message_task + ) + + def clear_chain_results(self) -> None: + self._current_chain_result = None + self._current_chain_message = None + self.agent_loop_gather_callback_handler.current_chain = None + + @property + def always_verbose(self) -> bool: + """Whether to call verbose callbacks even if verbose is False.""" + return True + + @property + def ignore_llm(self) -> bool: + """Whether to ignore LLM callbacks.""" + return True + + @property + def ignore_agent(self) -> bool: + """Whether to ignore agent callbacks.""" + return True + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Print out that we are entering a chain.""" + if not self._current_chain_result: + self._current_chain_result = ChainResult( + type=serialized['name'], + prompt=inputs, + started_at=time.perf_counter() + ) + self._current_chain_message = self.conversation_message_task.init_chain(self._current_chain_result) + self.agent_loop_gather_callback_handler.current_chain = self._current_chain_message + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Print out that we finished a chain.""" + if self._current_chain_result and self._current_chain_result.status == 'chain_started': + self._current_chain_result.status = 'chain_ended' + self._current_chain_result.completion = outputs + self._current_chain_result.completed = True + self._current_chain_result.completed_at = time.perf_counter() + + self.conversation_message_task.on_chain_end(self._current_chain_message, self._current_chain_result) + + self.clear_chain_results() + + def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + logging.error(error) + self.clear_chain_results() + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + pass + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + pass + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + logging.error(error) + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + **kwargs: Any, + ) -> None: + pass + + def on_agent_action( + self, action: AgentAction, color: Optional[str] = None, **kwargs: Any + ) -> Any: + pass + + def on_tool_end( + self, + output: str, + color: Optional[str] = None, + observation_prefix: Optional[str] = None, + llm_prefix: Optional[str] = None, + **kwargs: Any, + ) -> None: + pass + + def 
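Both `ChainResult` and `AgentLoop` timestamps come from `time.perf_counter()`, a monotonic clock suited to elapsed-time measurement; latency is simply the difference of two readings:

```python
import time

started_at = time.perf_counter()
time.sleep(0.05)                         # stand-in for running the chain
latency = time.perf_counter() - started_at
print(f"{latency:.3f}s")
```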
on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Do nothing.""" + logging.error(error) + + def on_text( + self, + text: str, + color: Optional[str] = None, + end: str = "", + **kwargs: Optional[str], + ) -> None: + """Run on additional input from chains and agents.""" + pass + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: + """Run on agent end.""" + pass diff --git a/api/core/callback_handler/std_out_callback_handler.py b/api/core/callback_handler/std_out_callback_handler.py new file mode 100644 index 0000000000..352e6cb4d8 --- /dev/null +++ b/api/core/callback_handler/std_out_callback_handler.py @@ -0,0 +1,127 @@ +import sys +from typing import Any, Dict, List, Optional, Union + +from langchain.callbacks.base import BaseCallbackHandler +from langchain.input import print_text +from langchain.schema import AgentAction, AgentFinish, LLMResult + + +class DifyStdOutCallbackHandler(BaseCallbackHandler): + """Callback Handler that prints to std out.""" + + def __init__(self, color: Optional[str] = None) -> None: + """Initialize callback handler.""" + self.color = color + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Print out the prompts.""" + print_text("\n[on_llm_start]\n", color='blue') + + if 'Chat' in serialized['name']: + for prompt in prompts: + print_text(prompt + "\n", color='blue') + else: + print_text(prompts[0] + "\n", color='blue') + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Do nothing.""" + print_text("\n[on_llm_end]\nOutput: " + str(response.generations[0][0].text) + "\nllm_output: " + str( + response.llm_output) + "\n", color='blue') + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Do nothing.""" + pass + + def on_llm_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Do nothing.""" + print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue') + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Print out that we are entering a chain.""" + class_name = serialized["name"] + print_text("\n[on_chain_start]\nChain: " + class_name + "\nInputs: " + str(inputs) + "\n", color='pink') + + def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + """Print out that we finished a chain.""" + print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink') + + def on_chain_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Do nothing.""" + print_text("\n[on_chain_error]\nError: " + str(error) + "\n", color='pink') + + def on_tool_start( + self, + serialized: Dict[str, Any], + input_str: str, + **kwargs: Any, + ) -> None: + """Do nothing.""" + print_text("\n[on_tool_start] " + str(serialized), color='yellow') + + def on_agent_action( + self, action: AgentAction, color: Optional[str] = None, **kwargs: Any + ) -> Any: + """Run on agent action.""" + tool = action.tool + tool_input = action.tool_input + action_name_position = action.log.index("\nAction:") + 1 if action.log else -1 + thought = action.log[:action_name_position].strip() if action.log else '' + + log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}" + print_text("\n[on_agent_action]\n" + log + "\n", color='green') + + def on_tool_end( + self, + output: str, + color: Optional[str] = None, + observation_prefix: Optional[str] = None, + llm_prefix: 
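The handler can be exercised directly, without running a chain, to see the colour-coded trace it writes to stdout; this assumes the module paths of this diff are importable:

```python
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler

handler = DifyStdOutCallbackHandler()
handler.on_chain_start(serialized={'name': 'SequentialChain'},
                       inputs={'input': 'hello'})
handler.on_chain_end(outputs={'output': 'world'})
```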
Optional[str] = None, + **kwargs: Any, + ) -> None: + """If not the final action, print out observation.""" + print_text("\n[on_tool_end]\n", color='yellow') + if observation_prefix: + print_text(f"\n{observation_prefix}") + print_text(output, color='yellow') + if llm_prefix: + print_text(f"\n{llm_prefix}") + print_text("\n") + + def on_tool_error( + self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any + ) -> None: + """Print the tool error.""" + print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='yellow') + + def on_text( + self, + text: str, + color: Optional[str] = None, + end: str = "", + **kwargs: Any, + ) -> None: + """Print arbitrary text output.""" + print_text("\n[on_text] " + text + "\n", color=color if color else self.color, end=end) + + def on_agent_finish( + self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any + ) -> None: + """Run on agent end.""" + print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n") + + +class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler): + """Callback handler for streaming. Only works with LLMs that support streaming.""" + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run on new LLM token. Only available when streaming is enabled.""" + sys.stdout.write(token) + sys.stdout.flush() diff --git a/api/core/chain/chain_builder.py b/api/core/chain/chain_builder.py new file mode 100644 index 0000000000..b7583ed890 --- /dev/null +++ b/api/core/chain/chain_builder.py @@ -0,0 +1,34 @@ +from typing import Optional + +from langchain.callbacks import CallbackManager + +from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler +from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain +from core.chain.tool_chain import ToolChain + + +class ChainBuilder: + @classmethod + def to_tool_chain(cls, tool, **kwargs) -> ToolChain: + return ToolChain( + tool=tool, + input_key=kwargs.get('input_key', 'input'), + output_key=kwargs.get('output_key', 'tool_output'), + callback_manager=CallbackManager([DifyStdOutCallbackHandler()]) + ) + + @classmethod + def to_sensitive_word_avoidance_chain(cls, tool_config: dict, **kwargs) -> Optional[ + SensitiveWordAvoidanceChain]: + sensitive_words = tool_config.get("words", "") + if tool_config.get("enabled", False) \ + and sensitive_words: + return SensitiveWordAvoidanceChain( + sensitive_words=sensitive_words.split(","), + canned_response=tool_config.get("canned_response", ''), + output_key="sensitive_word_avoidance_output", + callback_manager=CallbackManager([DifyStdOutCallbackHandler()]), + **kwargs + ) + + return None
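A hedged usage sketch of ChainBuilder.to_sensitive_word_avoidance_chain, assuming a tool_config shaped like the keys the method reads above ("enabled", "words", "canned_response"); the concrete values are placeholders:

tool_config = {
    "enabled": True,
    "words": "foo,bar",  # comma-separated; split on "," above
    "canned_response": "I cannot discuss that topic.",
}
chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
# chain is None when "enabled" is falsy or "words" is empty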
diff --git a/api/core/chain/main_chain_builder.py b/api/core/chain/main_chain_builder.py new file mode 100644 index 0000000000..5a4ab2214d --- /dev/null +++ b/api/core/chain/main_chain_builder.py @@ -0,0 +1,116 @@ +from typing import Optional, List + +from langchain.callbacks import SharedCallbackManager +from langchain.chains import SequentialChain +from langchain.chains.base import Chain +from langchain.memory.chat_memory import BaseChatMemory + +from core.agent.agent_builder import AgentBuilder +from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler +from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler +from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler +from core.chain.chain_builder import ChainBuilder +from core.constant import llm_constant +from core.conversation_message_task import ConversationMessageTask +from core.tool.dataset_tool_builder import DatasetToolBuilder + + +class MainChainBuilder: + @classmethod + def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory], + conversation_message_task: ConversationMessageTask): + first_input_key = "input" + final_output_key = "output" + + chains = [] + + chain_callback_handler = MainChainGatherCallbackHandler(conversation_message_task) + + # agent mode + tool_chains, chains_output_key = cls.get_agent_chains( + tenant_id=tenant_id, + agent_mode=agent_mode, + memory=memory, + dataset_tool_callback_handler=DatasetToolCallbackHandler(conversation_message_task), + agent_loop_gather_callback_handler=chain_callback_handler.agent_loop_gather_callback_handler + ) + chains += tool_chains + + if chains_output_key: + final_output_key = chains_output_key + + if len(chains) == 0: + return None + + for chain in chains: + # do not add handler into singleton callback manager + if not isinstance(chain.callback_manager, SharedCallbackManager): + chain.callback_manager.add_handler(chain_callback_handler) + + # build main chain + overall_chain = SequentialChain( + chains=chains, + input_variables=[first_input_key], + output_variables=[final_output_key], + memory=memory, # only used for the memory prompt input key + ) + + return overall_chain + + @classmethod + def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory], + dataset_tool_callback_handler: DatasetToolCallbackHandler, + agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler): + # agent mode + chains = [] + if agent_mode and agent_mode.get('enabled'): + tools = agent_mode.get('tools', []) + + pre_fixed_chains = [] + agent_tools = [] + for tool in tools: + tool_type = list(tool.keys())[0] + tool_config = list(tool.values())[0] + if tool_type == 'sensitive-word-avoidance': + chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config) + if chain: + pre_fixed_chains.append(chain) + elif tool_type == "dataset": + dataset_tool = DatasetToolBuilder.build_dataset_tool( + tenant_id=tenant_id, + dataset_id=tool_config.get("id"), + response_mode='no_synthesizer', # "compact" + callback_handler=dataset_tool_callback_handler + ) + + if dataset_tool: + agent_tools.append(dataset_tool) + + # add pre-fixed chains + chains += pre_fixed_chains + + if len(agent_tools) == 1: + # tool to chain + tool_chain = ChainBuilder.to_tool_chain(tool=agent_tools[0], output_key='tool_output') + chains.append(tool_chain) + elif len(agent_tools) > 1: + # build agent config + agent_chain = AgentBuilder.to_agent_chain( + tenant_id=tenant_id, + tools=agent_tools, + memory=memory, + dataset_tool_callback_handler=dataset_tool_callback_handler, + agent_loop_gather_callback_handler=agent_loop_gather_callback_handler + ) + + chains.append(agent_chain) + + final_output_key = cls.get_chains_output_key(chains) + + return chains, final_output_key + + @classmethod + def get_chains_output_key(cls, chains: List[Chain]): + if len(chains) > 0: + return chains[-1].output_keys[0] + return None
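For orientation, a sketch of the agent_mode dict that MainChainBuilder.get_agent_chains expects, reconstructed from the keys it reads above (one single-key dict per tool); the dataset id is a placeholder:

agent_mode = {
    "enabled": True,
    "tools": [
        {"sensitive-word-avoidance": {"enabled": True, "words": "foo,bar", "canned_response": "..."}},
        {"dataset": {"id": "a-dataset-uuid"}},
    ],
}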
"input" #: :meta private: + output_key: str = "output" #: :meta private: + + sensitive_words: List[str] = [] + canned_response: str = None + + @property + def _chain_type(self) -> str: + return "sensitive_word_avoidance_chain" + + @property + def input_keys(self) -> List[str]: + """Expect input key. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return output key. + + :meta private: + """ + return [self.output_key] + + def _check_sensitive_word(self, text: str) -> str: + for word in self.sensitive_words: + if word in text: + return self.canned_response + return text + + def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + text = inputs[self.input_key] + output = self._check_sensitive_word(text) + return {self.output_key: output} diff --git a/api/core/chain/tool_chain.py b/api/core/chain/tool_chain.py new file mode 100644 index 0000000000..458a35eb82 --- /dev/null +++ b/api/core/chain/tool_chain.py @@ -0,0 +1,42 @@ +from typing import List, Dict + +from langchain.chains.base import Chain +from langchain.tools import BaseTool + + +class ToolChain(Chain): + input_key: str = "input" #: :meta private: + output_key: str = "output" #: :meta private: + + tool: BaseTool + + @property + def _chain_type(self) -> str: + return "tool_chain" + + @property + def input_keys(self) -> List[str]: + """Expect input key. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return output key. + + :meta private: + """ + return [self.output_key] + + def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: + input = inputs[self.input_key] + output = self.tool.run(input, self.verbose) + return {self.output_key: output} + + async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: + """Run the logic of this chain and return the output.""" + input = inputs[self.input_key] + output = await self.tool.arun(input, self.verbose) + return {self.output_key: output} diff --git a/api/core/completion.py b/api/core/completion.py new file mode 100644 index 0000000000..f215bd0ee5 --- /dev/null +++ b/api/core/completion.py @@ -0,0 +1,326 @@ +from typing import Optional, List, Union + +from langchain.callbacks import CallbackManager +from langchain.chat_models.base import BaseChatModel +from langchain.llms import BaseLLM +from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage +from core.constant import llm_constant +from core.callback_handler.llm_callback_handler import LLMCallbackHandler +from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \ + DifyStdOutCallbackHandler +from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException +from core.llm.error import LLMBadRequestError +from core.llm.llm_builder import LLMBuilder +from core.chain.main_chain_builder import MainChainBuilder +from core.llm.streamable_chat_open_ai import StreamableChatOpenAI +from core.llm.streamable_open_ai import StreamableOpenAI +from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ + ReadOnlyConversationTokenDBBufferSharedMemory +from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \ + ReadOnlyConversationTokenDBStringBufferSharedMemory +from core.prompt.prompt_builder import PromptBuilder +from core.prompt.prompt_template import OutLinePromptTemplate +from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT +from models.model import App, AppModelConfig, Account, 
diff --git a/api/core/completion.py b/api/core/completion.py new file mode 100644 index 0000000000..f215bd0ee5 --- /dev/null +++ b/api/core/completion.py @@ -0,0 +1,326 @@ +from typing import Optional, List, Union + +from langchain.callbacks import CallbackManager +from langchain.chat_models.base import BaseChatModel +from langchain.llms import BaseLLM +from langchain.schema import BaseMessage, BaseLanguageModel, HumanMessage +from core.constant import llm_constant +from core.callback_handler.llm_callback_handler import LLMCallbackHandler +from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \ + DifyStdOutCallbackHandler +from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException +from core.llm.error import LLMBadRequestError +from core.llm.llm_builder import LLMBuilder +from core.chain.main_chain_builder import MainChainBuilder +from core.llm.streamable_chat_open_ai import StreamableChatOpenAI +from core.llm.streamable_open_ai import StreamableOpenAI +from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ + ReadOnlyConversationTokenDBBufferSharedMemory +from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \ + ReadOnlyConversationTokenDBStringBufferSharedMemory +from core.prompt.prompt_builder import PromptBuilder +from core.prompt.prompt_template import OutLinePromptTemplate +from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT +from models.model import App, AppModelConfig, Account, Conversation, Message + + +class Completion: + @classmethod + def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict, + user: Account, conversation: Optional[Conversation], streaming: bool, is_override: bool = False): + """ + errors: ProviderTokenNotInitError + """ + cls.validate_query_tokens(app.tenant_id, app_model_config, query) + + memory = None + if conversation: + # get memory of conversation (read-only) + memory = cls.get_memory_from_conversation( + tenant_id=app.tenant_id, + app_model_config=app_model_config, + conversation=conversation + ) + + inputs = conversation.inputs + + conversation_message_task = ConversationMessageTask( + task_id=task_id, + app=app, + app_model_config=app_model_config, + user=user, + conversation=conversation, + is_override=is_override, + inputs=inputs, + query=query, + streaming=streaming + ) + + # build main chain, including agent + main_chain = MainChainBuilder.to_langchain_components( + tenant_id=app.tenant_id, + agent_mode=app_model_config.agent_mode_dict, + memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None, + conversation_message_task=conversation_message_task + ) + + chain_output = '' + if main_chain: + chain_output = main_chain.run(query) + + # run the final llm + try: + cls.run_final_llm( + tenant_id=app.tenant_id, + mode=app.mode, + app_model_config=app_model_config, + query=query, + inputs=inputs, + chain_output=chain_output, + conversation_message_task=conversation_message_task, + memory=memory, + streaming=streaming + ) + except ConversationTaskStoppedException: + return + + @classmethod + def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict, + chain_output: str, + conversation_message_task: ConversationMessageTask, + memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool): + final_llm = LLMBuilder.to_llm_from_model( + tenant_id=tenant_id, + model=app_model_config.model_dict, + streaming=streaming + ) + + # get llm prompt + prompt = cls.get_main_llm_prompt( + mode=mode, + llm=final_llm, + pre_prompt=app_model_config.pre_prompt, + query=query, + inputs=inputs, + chain_output=chain_output, + memory=memory + ) + + final_llm.callback_manager = cls.get_llm_callback_manager(final_llm, streaming, conversation_message_task) + + cls.recalc_llm_max_tokens( + final_llm=final_llm, + prompt=prompt, + mode=mode + ) + + response = final_llm.generate([prompt]) + + return response + + @classmethod + def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict, chain_output: Optional[str], + memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \ + Union[str, List[BaseMessage]]: + pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt + if mode == 'completion': + prompt_template = OutLinePromptTemplate.from_template( + template=("Use the following pieces of [CONTEXT] to answer the question at the end. " + "If you don't know the answer, " + "just say that you don't know, don't try to make up an answer. 
\n" + "```\n" + "[CONTEXT]\n" + "{context}\n" + "```\n" if chain_output else "") + + (pre_prompt + "\n" if pre_prompt else "") + + "{query}\n" + ) + + if chain_output: + inputs['context'] = chain_output + + prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs} + prompt_content = prompt_template.format( + query=query, + **prompt_inputs + ) + + if isinstance(llm, BaseChatModel): + # use chat llm as completion model + return [HumanMessage(content=prompt_content)] + else: + return prompt_content + else: + messages: List[BaseMessage] = [] + + system_message = None + if pre_prompt: + # append pre prompt as system message + system_message = PromptBuilder.to_system_message(pre_prompt, inputs) + + if chain_output: + # append context as system message, currently only use simple stuff prompt + context_message = PromptBuilder.to_system_message( + """Use the following pieces of [CONTEXT] to answer the users question. +If you don't know the answer, just say that you don't know, don't try to make up an answer. +``` +[CONTEXT] +{context} +```""", + {'context': chain_output} + ) + + if not system_message: + system_message = context_message + else: + system_message.content = context_message.content + "\n\n" + system_message.content + + if system_message: + messages.append(system_message) + + human_inputs = { + "query": query + } + + # construct main prompt + human_message = PromptBuilder.to_human_message( + prompt_content="{query}", + inputs=human_inputs + ) + + if memory: + # append chat histories + tmp_messages = messages.copy() + [human_message] + curr_message_tokens = memory.llm.get_messages_tokens(tmp_messages) + rest_tokens = llm_constant.max_context_token_length[ + memory.llm.model_name] - memory.llm.max_tokens - curr_message_tokens + rest_tokens = max(rest_tokens, 0) + history_messages = cls.get_history_messages_from_memory(memory, rest_tokens) + messages += history_messages + + messages.append(human_message) + + return messages + + @classmethod + def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI], + streaming: bool, conversation_message_task: ConversationMessageTask) -> CallbackManager: + llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task) + if streaming: + callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()] + else: + callback_handlers = [llm_callback_handler, DifyStdOutCallbackHandler()] + + return CallbackManager(callback_handlers) + + @classmethod + def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory, + max_token_limit: int) -> \ + List[BaseMessage]: + """Get memory messages.""" + memory.max_token_limit = max_token_limit + memory_key = memory.memory_variables[0] + external_context = memory.load_memory_variables({}) + return external_context[memory_key] + + @classmethod + def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig, + conversation: Conversation, + **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory: + # only for calc token in memory + memory_llm = LLMBuilder.to_llm_from_model( + tenant_id=tenant_id, + model=app_model_config.model_dict + ) + + # use llm config from conversation + memory = ReadOnlyConversationTokenDBBufferSharedMemory( + conversation=conversation, + llm=memory_llm, + max_token_limit=kwargs.get("max_token_limit", 2048), + memory_key=kwargs.get("memory_key", "chat_history"), + return_messages=kwargs.get("return_messages", True), + input_key=kwargs.get("input_key", "input"), + 
output_key=kwargs.get("output_key", "output"), + message_limit=kwargs.get("message_limit", 10), + ) + + return memory + + @classmethod + def validate_query_tokens(cls, tenant_id: str, app_model_config: AppModelConfig, query: str): + llm = LLMBuilder.to_llm_from_model( + tenant_id=tenant_id, + model=app_model_config.model_dict + ) + + model_limited_tokens = llm_constant.max_context_token_length[llm.model_name] + max_tokens = llm.max_tokens + + if model_limited_tokens - max_tokens - llm.get_num_tokens(query) < 0: + raise LLMBadRequestError("Query is too long") + + @classmethod + def recalc_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI], + prompt: Union[str, List[BaseMessage]], mode: str): + # recalculate max_tokens if prompt_tokens + max_tokens exceeds the model token limit + model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name] + max_tokens = final_llm.max_tokens + + if mode == 'completion' and isinstance(final_llm, BaseLLM): + prompt_tokens = final_llm.get_num_tokens(prompt) + else: + prompt_tokens = final_llm.get_messages_tokens(prompt) + + if prompt_tokens + max_tokens > model_limited_tokens: + max_tokens = max(model_limited_tokens - prompt_tokens, 16) + final_llm.max_tokens = max_tokens + + @classmethod + def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str, + app_model_config: AppModelConfig, user: Account, streaming: bool): + llm: StreamableOpenAI = LLMBuilder.to_llm( + tenant_id=app.tenant_id, + model_name='gpt-3.5-turbo', + streaming=streaming + ) + + # get llm prompt + original_prompt = cls.get_main_llm_prompt( + mode="completion", + llm=llm, + pre_prompt=pre_prompt, + query=message.query, + inputs=message.inputs, + chain_output=None, + memory=None + ) + + original_completion = message.answer.strip() + + prompt = MORE_LIKE_THIS_GENERATE_PROMPT + prompt = prompt.format(prompt=original_prompt, original_completion=original_completion) + + if isinstance(llm, BaseChatModel): + prompt = [HumanMessage(content=prompt)] + + conversation_message_task = ConversationMessageTask( + task_id=task_id, + app=app, + app_model_config=app_model_config, + user=user, + inputs=message.inputs, + query=message.query, + is_override=bool(message.override_model_configs), + streaming=streaming + ) + + llm.callback_manager = cls.get_llm_callback_manager(llm, streaming, conversation_message_task) + + cls.recalc_llm_max_tokens( + final_llm=llm, + prompt=prompt, + mode='completion' + ) + + llm.generate([prompt])
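To make the clamping in recalc_llm_max_tokens concrete with illustrative numbers: for gpt-3.5-turbo (4,096-token context, per llm_constant below), a 3,000-token prompt combined with max_tokens=1,500 would overflow (4,500 > 4,096), so max_tokens is lowered to max(4096 - 3000, 16) = 1,096 before generate is called.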
diff --git a/api/core/constant/llm_constant.py b/api/core/constant/llm_constant.py new file mode 100644 index 0000000000..6879ec5b06 --- /dev/null +++ b/api/core/constant/llm_constant.py @@ -0,0 +1,84 @@ +from decimal import Decimal + +models = { + 'gpt-4': 'openai', # 8,192 tokens + 'gpt-4-32k': 'openai', # 32,768 tokens + 'gpt-3.5-turbo': 'openai', # 4,096 tokens + 'text-davinci-003': 'openai', # 4,097 tokens + 'text-davinci-002': 'openai', # 4,097 tokens + 'text-curie-001': 'openai', # 2,049 tokens + 'text-babbage-001': 'openai', # 2,049 tokens + 'text-ada-001': 'openai', # 2,049 tokens + 'text-embedding-ada-002': 'openai' # 8191 tokens, 1536 dimensions +} + +max_context_token_length = { + 'gpt-4': 8192, + 'gpt-4-32k': 32768, + 'gpt-3.5-turbo': 4096, + 'text-davinci-003': 4097, + 'text-davinci-002': 4097, + 'text-curie-001': 2049, + 'text-babbage-001': 2049, + 'text-ada-001': 2049, + 'text-embedding-ada-002': 8191 +} + +models_by_mode = { + 'chat': [ + 'gpt-4', # 8,192 tokens + 'gpt-4-32k', # 32,768 tokens + 'gpt-3.5-turbo', # 4,096 tokens + ], + 'completion': [ + 'gpt-4', # 8,192 tokens + 'gpt-4-32k', # 32,768 tokens + 'gpt-3.5-turbo', # 4,096 tokens + 'text-davinci-003', # 4,097 tokens + 'text-davinci-002', # 4,097 tokens + 'text-curie-001', # 2,049 tokens + 'text-babbage-001', # 2,049 tokens + 'text-ada-001' # 2,049 tokens + ], + 'embedding': [ + 'text-embedding-ada-002' # 8191 tokens, 1536 dimensions + ] +} + +model_currency = 'USD' + +model_prices = { + 'gpt-4': { + 'prompt': Decimal('0.03'), + 'completion': Decimal('0.06'), + }, + 'gpt-4-32k': { + 'prompt': Decimal('0.06'), + 'completion': Decimal('0.12') + }, + 'gpt-3.5-turbo': { + 'prompt': Decimal('0.002'), + 'completion': Decimal('0.002') + }, + 'text-davinci-003': { + 'prompt': Decimal('0.02'), + 'completion': Decimal('0.02') + }, + 'text-curie-001': { + 'prompt': Decimal('0.002'), + 'completion': Decimal('0.002') + }, + 'text-babbage-001': { + 'prompt': Decimal('0.0005'), + 'completion': Decimal('0.0005') + }, + 'text-ada-001': { + 'prompt': Decimal('0.0004'), + 'completion': Decimal('0.0004') + }, + 'text-embedding-ada-002': { + 'usage': Decimal('0.0004'), + } +} + +agent_model_name = 'text-davinci-003' diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py new file mode 100644 index 0000000000..0df26637a3 --- /dev/null +++ b/api/core/conversation_message_task.py @@ -0,0 +1,388 @@ +import decimal +import json +from typing import Optional, Union + +from gunicorn.config import User + +from core.callback_handler.entity.agent_loop import AgentLoop +from core.callback_handler.entity.dataset_query import DatasetQueryObj +from core.callback_handler.entity.llm_message import LLMMessage +from core.callback_handler.entity.chain_result import ChainResult +from core.constant import llm_constant +from core.llm.llm_builder import LLMBuilder +from core.llm.provider.llm_provider_service import LLMProviderService +from core.prompt.prompt_builder import PromptBuilder +from core.prompt.prompt_template import OutLinePromptTemplate +from events.message_event import message_was_created +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import DatasetQuery +from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain +from models.provider import ProviderType, Provider + + +class ConversationMessageTask: + def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account, + inputs: dict, query: str, streaming: bool, + conversation: Optional[Conversation] = None, is_override: bool = False): + self.task_id = task_id + + self.app = app + self.tenant_id = app.tenant_id + self.app_model_config = app_model_config + self.is_override = is_override + + self.user = user + self.inputs = inputs + self.query = query + self.streaming = streaming + + self.conversation = conversation + self.is_new_conversation = False + + self.message = None + + self.model_dict = self.app_model_config.model_dict + self.model_name = self.model_dict.get('name') + self.mode = app.mode + + self.init() + + self._pub_handler = PubHandler( + user=self.user, + task_id=self.task_id, + message=self.message, + conversation=self.conversation, + chain_pub=False, # disabled currently + agent_thought_pub=False # disabled currently + ) + + def init(self): + override_model_configs = None + if self.is_override: + override_model_configs = { + "model": self.app_model_config.model_dict, + "pre_prompt": 
self.app_model_config.pre_prompt, + "agent_mode": self.app_model_config.agent_mode_dict, + "opening_statement": self.app_model_config.opening_statement, + "suggested_questions": self.app_model_config.suggested_questions_list, + "suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict, + "more_like_this": self.app_model_config.more_like_this_dict, + "user_input_form": self.app_model_config.user_input_form_list, + } + + introduction = '' + system_instruction = '' + system_instruction_tokens = 0 + if self.mode == 'chat': + introduction = self.app_model_config.opening_statement + if introduction: + prompt_template = OutLinePromptTemplate.from_template(template=PromptBuilder.process_template(introduction)) + prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs} + introduction = prompt_template.format(**prompt_inputs) + + if self.app_model_config.pre_prompt: + pre_prompt = PromptBuilder.process_template(self.app_model_config.pre_prompt) + system_message = PromptBuilder.to_system_message(pre_prompt, self.inputs) + system_instruction = system_message.content + llm = LLMBuilder.to_llm(self.tenant_id, self.model_name) + system_instruction_tokens = llm.get_messages_tokens([system_message]) + + if not self.conversation: + self.is_new_conversation = True + self.conversation = Conversation( + app_id=self.app_model_config.app_id, + app_model_config_id=self.app_model_config.id, + model_provider=self.model_dict.get('provider'), + model_id=self.model_name, + override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, + mode=self.mode, + name='', + inputs=self.inputs, + introduction=introduction, + system_instruction=system_instruction, + system_instruction_tokens=system_instruction_tokens, + status='normal', + from_source=('console' if isinstance(self.user, Account) else 'api'), + from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None), + from_account_id=(self.user.id if isinstance(self.user, Account) else None), + ) + + db.session.add(self.conversation) + db.session.flush() + + self.message = Message( + app_id=self.app_model_config.app_id, + model_provider=self.model_dict.get('provider'), + model_id=self.model_name, + override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, + conversation_id=self.conversation.id, + inputs=self.inputs, + query=self.query, + message="", + message_tokens=0, + message_unit_price=0, + answer="", + answer_tokens=0, + answer_unit_price=0, + provider_response_latency=0, + total_price=0, + currency=llm_constant.model_currency, + from_source=('console' if isinstance(self.user, Account) else 'api'), + from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None), + from_account_id=(self.user.id if isinstance(self.user, Account) else None), + agent_based=self.app_model_config.agent_mode_dict.get('enabled'), + ) + + db.session.add(self.message) + db.session.flush() + + def append_message_text(self, text: str): + self._pub_handler.pub_text(text) + + def save_message(self, llm_message: LLMMessage, by_stopped: bool = False): + model_name = self.app_model_config.model_dict.get('name') + + message_tokens = llm_message.prompt_tokens + answer_tokens = llm_message.completion_tokens + message_unit_price = llm_constant.model_prices[model_name]['prompt'] + answer_unit_price = llm_constant.model_prices[model_name]['completion'] + + total_price = self.calc_total_price(message_tokens, message_unit_price, 
answer_tokens, answer_unit_price) + + self.message.message = llm_message.prompt + self.message.message_tokens = message_tokens + self.message.message_unit_price = message_unit_price + self.message.answer = llm_message.completion.strip() if llm_message.completion else '' + self.message.answer_tokens = answer_tokens + self.message.answer_unit_price = answer_unit_price + self.message.provider_response_latency = llm_message.latency + self.message.total_price = total_price + + self.update_provider_quota() + + db.session.commit() + + message_was_created.send( + self.message, + conversation=self.conversation, + is_first_message=self.is_new_conversation + ) + + if not by_stopped: + self._pub_handler.pub_end() + + def update_provider_quota(self): + llm_provider_service = LLMProviderService( + tenant_id=self.app.tenant_id, + provider_name=self.message.model_provider, + ) + + provider = llm_provider_service.get_provider_db_record() + if provider and provider.provider_type == ProviderType.SYSTEM.value: + db.session.query(Provider).filter( + Provider.tenant_id == self.app.tenant_id, + Provider.quota_limit > Provider.quota_used + ).update({'quota_used': Provider.quota_used + 1}) + + def init_chain(self, chain_result: ChainResult): + message_chain = MessageChain( + message_id=self.message.id, + type=chain_result.type, + input=json.dumps(chain_result.prompt), + output='' + ) + + db.session.add(message_chain) + db.session.flush() + + return message_chain + + def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult): + message_chain.output = json.dumps(chain_result.completion) + + self._pub_handler.pub_chain(message_chain) + + def on_agent_end(self, message_chain: MessageChain, agent_model_name: str, + agent_loop: AgentLoop): + agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt'] + agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion'] + + loop_message_tokens = agent_loop.prompt_tokens + loop_answer_tokens = agent_loop.completion_tokens + + loop_total_price = self.calc_total_price( + loop_message_tokens, + agent_message_unit_price, + loop_answer_tokens, + agent_answer_unit_price + ) + + message_agent_loop = MessageAgentThought( + message_id=self.message.id, + message_chain_id=message_chain.id, + position=agent_loop.position, + thought=agent_loop.thought, + tool=agent_loop.tool_name, + tool_input=agent_loop.tool_input, + observation=agent_loop.tool_output, + tool_process_data='', # currently not supported + message=agent_loop.prompt, + message_token=loop_message_tokens, + message_unit_price=agent_message_unit_price, + answer=agent_loop.completion, + answer_token=loop_answer_tokens, + answer_unit_price=agent_answer_unit_price, + latency=agent_loop.latency, + tokens=agent_loop.prompt_tokens + agent_loop.completion_tokens, + total_price=loop_total_price, + currency=llm_constant.model_currency, + created_by_role=('account' if isinstance(self.user, Account) else 'end_user'), + created_by=self.user.id + ) + + db.session.add(message_agent_loop) + db.session.flush() + + self._pub_handler.pub_agent_thought(message_agent_loop) + + def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj): + dataset_query = DatasetQuery( + dataset_id=dataset_query_obj.dataset_id, + content=dataset_query_obj.query, + source='app', + source_app_id=self.app.id, + created_by_role=('account' if isinstance(self.user, Account) else 'end_user'), + created_by=self.user.id + ) + + db.session.add(dataset_query) + + def calc_total_price(self, message_tokens, 
message_unit_price, answer_tokens, answer_unit_price): + message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'), + rounding=decimal.ROUND_HALF_UP) + answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'), + rounding=decimal.ROUND_HALF_UP) + + total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price + return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) + + +class PubHandler: + def __init__(self, user: Union[Account, User], task_id: str, + message: Message, conversation: Conversation, + chain_pub: bool = False, agent_thought_pub: bool = False): + self._channel = PubHandler.generate_channel_name(user, task_id) + self._stopped_cache_key = PubHandler.generate_stopped_cache_key(user, task_id) + + self._task_id = task_id + self._message = message + self._conversation = conversation + self._chain_pub = chain_pub + self._agent_thought_pub = agent_thought_pub + + @classmethod + def generate_channel_name(cls, user: Union[Account, User], task_id: str): + user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id + return "generate_result:{}-{}".format(user_str, task_id) + + @classmethod + def generate_stopped_cache_key(cls, user: Union[Account, User], task_id: str): + user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id + return "generate_result_stopped:{}-{}".format(user_str, task_id) + + def pub_text(self, text: str): + content = { + 'event': 'message', + 'data': { + 'task_id': self._task_id, + 'message_id': self._message.id, + 'text': text, + 'mode': self._conversation.mode, + 'conversation_id': self._conversation.id + } + } + + redis_client.publish(self._channel, json.dumps(content)) + + if self._is_stopped(): + self.pub_end() + raise ConversationTaskStoppedException() + + def pub_chain(self, message_chain: MessageChain): + if self._chain_pub: + content = { + 'event': 'chain', + 'data': { + 'task_id': self._task_id, + 'message_id': self._message.id, + 'chain_id': message_chain.id, + 'type': message_chain.type, + 'input': json.loads(message_chain.input), + 'output': json.loads(message_chain.output), + 'mode': self._conversation.mode, + 'conversation_id': self._conversation.id + } + } + + redis_client.publish(self._channel, json.dumps(content)) + + if self._is_stopped(): + self.pub_end() + raise ConversationTaskStoppedException() + + def pub_agent_thought(self, message_agent_thought: MessageAgentThought): + if self._agent_thought_pub: + content = { + 'event': 'agent_thought', + 'data': { + 'task_id': self._task_id, + 'message_id': self._message.id, + 'chain_id': message_agent_thought.message_chain_id, + 'agent_thought_id': message_agent_thought.id, + 'position': message_agent_thought.position, + 'thought': message_agent_thought.thought, + 'tool': message_agent_thought.tool, + 'tool_input': message_agent_thought.tool_input, + 'observation': message_agent_thought.observation, + 'answer': message_agent_thought.answer, + 'mode': self._conversation.mode, + 'conversation_id': self._conversation.id + } + } + + redis_client.publish(self._channel, json.dumps(content)) + + if self._is_stopped(): + self.pub_end() + raise ConversationTaskStoppedException() + + + def pub_end(self): + content = { + 'event': 'end', + } + + redis_client.publish(self._channel, json.dumps(content)) + + @classmethod + def pub_error(cls, user: Union[Account, User], task_id: str, e): + content = { + 
'error': type(e).__name__, + 'description': e.description if getattr(e, 'description', None) is not None else str(e) + } + + channel = cls.generate_channel_name(user, task_id) + redis_client.publish(channel, json.dumps(content)) + + def _is_stopped(self): + return redis_client.get(self._stopped_cache_key) is not None + + @classmethod + def stop(cls, user: Union[Account, User], task_id: str): + stopped_cache_key = cls.generate_stopped_cache_key(user, task_id) + redis_client.setex(stopped_cache_key, 600, 1) + + +class ConversationTaskStoppedException(Exception): + pass diff --git a/api/core/docstore/dataset_docstore.py b/api/core/docstore/dataset_docstore.py new file mode 100644 index 0000000000..b3b968532c --- /dev/null +++ b/api/core/docstore/dataset_docstore.py @@ -0,0 +1,190 @@ +from typing import Any, Dict, Optional, Sequence + +import tiktoken +from llama_index.data_structs import Node +from llama_index.docstore.types import BaseDocumentStore +from llama_index.docstore.utils import json_to_doc +from llama_index.schema import BaseDocument +from sqlalchemy import func + +from core.llm.token_calculator import TokenCalculator +from extensions.ext_database import db +from models.dataset import Dataset, DocumentSegment + + +class DatesetDocumentStore(BaseDocumentStore): + def __init__( + self, + dataset: Dataset, + user_id: str, + embedding_model_name: str, + document_id: Optional[str] = None, + ): + self._dataset = dataset + self._user_id = user_id + self._embedding_model_name = embedding_model_name + self._document_id = document_id + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any]) -> "DatesetDocumentStore": + return cls(**config_dict) + + def to_dict(self) -> Dict[str, Any]: + """Serialize to dict.""" + return { + "dataset_id": self._dataset.id, + } + + @property + def dateset_id(self) -> Any: + return self._dataset.id + + @property + def user_id(self) -> Any: + return self._user_id + + @property + def embedding_model_name(self) -> Any: + return self._embedding_model_name + + @property + def docs(self) -> Dict[str, BaseDocument]: + document_segments = db.session.query(DocumentSegment).filter( + DocumentSegment.dataset_id == self._dataset.id + ).all() + + output = {} + for document_segment in document_segments: + doc_id = document_segment.index_node_id + result = self.segment_to_dict(document_segment) + output[doc_id] = json_to_doc(result) + + return output + + def add_documents( + self, docs: Sequence[BaseDocument], allow_update: bool = True + ) -> None: + max_position = db.session.query(func.max(DocumentSegment.position)).filter( + DocumentSegment.document_id == self._document_id + ).scalar() + + if max_position is None: + max_position = 0 + + for doc in docs: + if doc.is_doc_id_none: + raise ValueError("doc_id not set") + + if not isinstance(doc, Node): + raise ValueError("doc must be a Node") + + segment_document = self.get_document(doc_id=doc.get_doc_id(), raise_error=False) + + # NOTE: doc could already exist in the store, but we overwrite it + if not allow_update and segment_document: + raise ValueError( + f"doc_id {doc.get_doc_id()} already exists. " + "Set allow_update to True to overwrite." 
+ ) + + # calc embedding use tokens + tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.get_text()) + + if not segment_document: + max_position += 1 + + segment_document = DocumentSegment( + tenant_id=self._dataset.tenant_id, + dataset_id=self._dataset.id, + document_id=self._document_id, + index_node_id=doc.get_doc_id(), + index_node_hash=doc.get_doc_hash(), + position=max_position, + content=doc.get_text(), + word_count=len(doc.get_text()), + tokens=tokens, + created_by=self._user_id, + ) + db.session.add(segment_document) + else: + segment_document.content = doc.get_text() + segment_document.index_node_hash = doc.get_doc_hash() + segment_document.word_count = len(doc.get_text()) + segment_document.tokens = tokens + + db.session.commit() + + def document_exists(self, doc_id: str) -> bool: + """Check if document exists.""" + result = self.get_document_segment(doc_id) + return result is not None + + def get_document( + self, doc_id: str, raise_error: bool = True + ) -> Optional[BaseDocument]: + document_segment = self.get_document_segment(doc_id) + + if document_segment is None: + if raise_error: + raise ValueError(f"doc_id {doc_id} not found.") + else: + return None + + result = self.segment_to_dict(document_segment) + return json_to_doc(result) + + def delete_document(self, doc_id: str, raise_error: bool = True) -> None: + document_segment = self.get_document_segment(doc_id) + + if document_segment is None: + if raise_error: + raise ValueError(f"doc_id {doc_id} not found.") + else: + return None + + db.session.delete(document_segment) + db.session.commit() + + def set_document_hash(self, doc_id: str, doc_hash: str) -> None: + """Set the hash for a given doc_id.""" + document_segment = self.get_document_segment(doc_id) + + if document_segment is None: + return None + + document_segment.index_node_hash = doc_hash + db.session.commit() + + def get_document_hash(self, doc_id: str) -> Optional[str]: + """Get the stored hash for a document, if it exists.""" + document_segment = self.get_document_segment(doc_id) + + if document_segment is None: + return None + + return document_segment.index_node_hash + + def update_docstore(self, other: "BaseDocumentStore") -> None: + """Update docstore. 
+ + Args: + other (BaseDocumentStore): docstore to update from + + """ + self.add_documents(list(other.docs.values())) + + def get_document_segment(self, doc_id: str) -> DocumentSegment: + document_segment = db.session.query(DocumentSegment).filter( + DocumentSegment.dataset_id == self._dataset.id, + DocumentSegment.index_node_id == doc_id + ).first() + + return document_segment + + def segment_to_dict(self, segment: DocumentSegment) -> Dict[str, Any]: + return { + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "text": segment.content, + "__type__": Node.get_type() + } diff --git a/api/core/docstore/empty_docstore.py b/api/core/docstore/empty_docstore.py new file mode 100644 index 0000000000..e19f1824cb --- /dev/null +++ b/api/core/docstore/empty_docstore.py @@ -0,0 +1,51 @@ +from typing import Any, Dict, Optional, Sequence +from llama_index.docstore.types import BaseDocumentStore +from llama_index.schema import BaseDocument + + +class EmptyDocumentStore(BaseDocumentStore): + @classmethod + def from_dict(cls, config_dict: Dict[str, Any]) -> "EmptyDocumentStore": + return cls() + + def to_dict(self) -> Dict[str, Any]: + """Serialize to dict.""" + return {} + + @property + def docs(self) -> Dict[str, BaseDocument]: + return {} + + def add_documents( + self, docs: Sequence[BaseDocument], allow_update: bool = True + ) -> None: + pass + + def document_exists(self, doc_id: str) -> bool: + """Check if document exists.""" + return False + + def get_document( + self, doc_id: str, raise_error: bool = True + ) -> Optional[BaseDocument]: + return None + + def delete_document(self, doc_id: str, raise_error: bool = True) -> None: + pass + + def set_document_hash(self, doc_id: str, doc_hash: str) -> None: + """Set the hash for a given doc_id.""" + pass + + def get_document_hash(self, doc_id: str) -> Optional[str]: + """Get the stored hash for a document, if it exists.""" + return None + + def update_docstore(self, other: "BaseDocumentStore") -> None: + """Update docstore. + + Args: + other (BaseDocumentStore): docstore to update from + + """ + self.add_documents(list(other.docs.values())) diff --git a/api/core/embedding/openai_embedding.py b/api/core/embedding/openai_embedding.py new file mode 100644 index 0000000000..0938397423 --- /dev/null +++ b/api/core/embedding/openai_embedding.py @@ -0,0 +1,176 @@ +from typing import Optional, Any, List + +import openai +from llama_index.embeddings.base import BaseEmbedding +from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \ + _TEXT_MODE_MODEL_DICT +from tenacity import wait_random_exponential, retry, stop_after_attempt + +from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async + + +@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) +def get_embedding( + text: str, + engine: Optional[str] = None, + openai_api_key: Optional[str] = None, +) -> List[float]: + """Get embedding. + + NOTE: Copied from OpenAI's embedding utils: + https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py + + Copied here to avoid importing unnecessary dependencies + like matplotlib, plotly, scipy, sklearn. 
+ + """ + text = text.replace("\n", " ") + return openai.Embedding.create(input=[text], engine=engine, api_key=openai_api_key)["data"][0]["embedding"] + + +@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) +async def aget_embedding(text: str, engine: Optional[str] = None, openai_api_key: Optional[str] = None) -> List[float]: + """Asynchronously get embedding. + + NOTE: Copied from OpenAI's embedding utils: + https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py + + Copied here to avoid importing unnecessary dependencies + like matplotlib, plotly, scipy, sklearn. + + """ + # replace newlines, which can negatively affect performance. + text = text.replace("\n", " ") + + return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=openai_api_key))["data"][0][ + "embedding" + ] + + +@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) +def get_embeddings( + list_of_text: List[str], + engine: Optional[str] = None, + openai_api_key: Optional[str] = None +) -> List[List[float]]: + """Get embeddings. + + NOTE: Copied from OpenAI's embedding utils: + https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py + + Copied here to avoid importing unnecessary dependencies + like matplotlib, plotly, scipy, sklearn. + + """ + assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048." + + # replace newlines, which can negatively affect performance. + list_of_text = [text.replace("\n", " ") for text in list_of_text] + + data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=openai_api_key).data + data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input. + return [d["embedding"] for d in data] + + +@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) +async def aget_embeddings( + list_of_text: List[str], engine: Optional[str] = None, openai_api_key: Optional[str] = None +) -> List[List[float]]: + """Asynchronously get embeddings. + + NOTE: Copied from OpenAI's embedding utils: + https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py + + Copied here to avoid importing unnecessary dependencies + like matplotlib, plotly, scipy, sklearn. + + """ + assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048." + + # replace newlines, which can negatively affect performance. + list_of_text = [text.replace("\n", " ") for text in list_of_text] + + data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=openai_api_key)).data + data = sorted(data, key=lambda x: x["index"]) # maintain the same order as input. 
+ return [d["embedding"] for d in data] + + +class OpenAIEmbedding(BaseEmbedding): + + def __init__( + self, + mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE, + model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002, + deployment_name: Optional[str] = None, + openai_api_key: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Init params.""" + super().__init__(**kwargs) + self.mode = OpenAIEmbeddingMode(mode) + self.model = OpenAIEmbeddingModelType(model) + self.deployment_name = deployment_name + self.openai_api_key = openai_api_key + + @handle_llm_exceptions + def _get_query_embedding(self, query: str) -> List[float]: + """Get query embedding.""" + if self.deployment_name is not None: + engine = self.deployment_name + else: + key = (self.mode, self.model) + if key not in _QUERY_MODE_MODEL_DICT: + raise ValueError(f"Invalid mode, model combination: {key}") + engine = _QUERY_MODE_MODEL_DICT[key] + return get_embedding(query, engine=engine, openai_api_key=self.openai_api_key) + + def _get_text_embedding(self, text: str) -> List[float]: + """Get text embedding.""" + if self.deployment_name is not None: + engine = self.deployment_name + else: + key = (self.mode, self.model) + if key not in _TEXT_MODE_MODEL_DICT: + raise ValueError(f"Invalid mode, model combination: {key}") + engine = _TEXT_MODE_MODEL_DICT[key] + return get_embedding(text, engine=engine, openai_api_key=self.openai_api_key) + + async def _aget_text_embedding(self, text: str) -> List[float]: + """Asynchronously get text embedding.""" + if self.deployment_name is not None: + engine = self.deployment_name + else: + key = (self.mode, self.model) + if key not in _TEXT_MODE_MODEL_DICT: + raise ValueError(f"Invalid mode, model combination: {key}") + engine = _TEXT_MODE_MODEL_DICT[key] + return await aget_embedding(text, engine=engine, openai_api_key=self.openai_api_key) + + def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]: + """Get text embeddings. + + By default, this is a wrapper around _get_text_embedding. + Can be overriden for batch queries. 
+ + """ + if self.deployment_name is not None: + engine = self.deployment_name + else: + key = (self.mode, self.model) + if key not in _TEXT_MODE_MODEL_DICT: + raise ValueError(f"Invalid mode, model combination: {key}") + engine = _TEXT_MODE_MODEL_DICT[key] + embeddings = get_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key) + return embeddings + + async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]: + """Asynchronously get text embeddings.""" + if self.deployment_name is not None: + engine = self.deployment_name + else: + key = (self.mode, self.model) + if key not in _TEXT_MODE_MODEL_DICT: + raise ValueError(f"Invalid mode, model combination: {key}") + engine = _TEXT_MODE_MODEL_DICT[key] + embeddings = await aget_embeddings(texts, engine=engine, openai_api_key=self.openai_api_key) + return embeddings diff --git a/api/core/generator/llm_generator.py b/api/core/generator/llm_generator.py new file mode 100644 index 0000000000..67e5753007 --- /dev/null +++ b/api/core/generator/llm_generator.py @@ -0,0 +1,120 @@ +import logging + +from langchain.chat_models.base import BaseChatModel +from langchain.schema import HumanMessage + +from core.constant import llm_constant +from core.llm.llm_builder import LLMBuilder +from core.llm.streamable_open_ai import StreamableOpenAI +from core.llm.token_calculator import TokenCalculator + +from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser +from core.prompt.prompt_template import OutLinePromptTemplate +from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT + + +# gpt-3.5-turbo works not well +generate_base_model = 'text-davinci-003' + + +class LLMGenerator: + @classmethod + def generate_conversation_name(cls, tenant_id: str, query, answer): + prompt = CONVERSATION_TITLE_PROMPT + prompt = prompt.format(query=query, answer=answer) + llm: StreamableOpenAI = LLMBuilder.to_llm( + tenant_id=tenant_id, + model_name=generate_base_model, + max_tokens=50 + ) + + if isinstance(llm, BaseChatModel): + prompt = [HumanMessage(content=prompt)] + + response = llm.generate([prompt]) + answer = response.generations[0][0].text + return answer.strip() + + @classmethod + def generate_conversation_summary(cls, tenant_id: str, messages): + max_tokens = 200 + + prompt = CONVERSATION_SUMMARY_PROMPT + prompt_with_empty_context = prompt.format(context='') + prompt_tokens = TokenCalculator.get_num_tokens(generate_base_model, prompt_with_empty_context) + rest_tokens = llm_constant.max_context_token_length[generate_base_model] - prompt_tokens - max_tokens + + context = '' + for message in messages: + if not message.answer: + continue + + message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n" + if rest_tokens - TokenCalculator.get_num_tokens(generate_base_model, context + message_qa_text) > 0: + context += message_qa_text + + prompt = prompt.format(context=context) + + llm: StreamableOpenAI = LLMBuilder.to_llm( + tenant_id=tenant_id, + model_name=generate_base_model, + max_tokens=max_tokens + ) + + if isinstance(llm, BaseChatModel): + prompt = [HumanMessage(content=prompt)] + + response = llm.generate([prompt]) + answer = response.generations[0][0].text + return answer.strip() + + @classmethod + def generate_introduction(cls, tenant_id: str, pre_prompt: str): + prompt = INTRODUCTION_GENERATE_PROMPT + prompt = prompt.format(prompt=pre_prompt) + + llm: StreamableOpenAI = LLMBuilder.to_llm( + 
tenant_id=tenant_id, + model_name=generate_base_model, + ) + + if isinstance(llm, BaseChatModel): + prompt = [HumanMessage(content=prompt)] + + response = llm.generate([prompt]) + answer = response.generations[0][0].text + return answer.strip() + + @classmethod + def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str): + output_parser = SuggestedQuestionsAfterAnswerOutputParser() + format_instructions = output_parser.get_format_instructions() + + prompt = OutLinePromptTemplate( + template="{histories}\n{format_instructions}\nquestions:\n", + input_variables=["histories"], + partial_variables={"format_instructions": format_instructions} + ) + + _input = prompt.format_prompt(histories=histories) + + llm: StreamableOpenAI = LLMBuilder.to_llm( + tenant_id=tenant_id, + model_name=generate_base_model, + temperature=0, + max_tokens=256 + ) + + if isinstance(llm, BaseChatModel): + query = [HumanMessage(content=_input.to_string())] + else: + query = _input.to_string() + + try: + output = llm(query) + questions = output_parser.parse(output) + except Exception: + logging.exception("Error generating suggested questions after answer") + questions = [] + + return questions diff --git a/api/core/index/index_builder.py b/api/core/index/index_builder.py new file mode 100644 index 0000000000..baf16b0f3a --- /dev/null +++ b/api/core/index/index_builder.py @@ -0,0 +1,45 @@ +from langchain.callbacks import CallbackManager +from llama_index import ServiceContext, PromptHelper, LLMPredictor +from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler +from core.embedding.openai_embedding import OpenAIEmbedding +from core.llm.llm_builder import LLMBuilder + + +class IndexBuilder: + @classmethod + def get_default_service_context(cls, tenant_id: str) -> ServiceContext: + # set number of output tokens + num_output = 512 + + # only for verbose + callback_manager = CallbackManager([DifyStdOutCallbackHandler()]) + + llm = LLMBuilder.to_llm( + tenant_id=tenant_id, + model_name='text-davinci-003', + temperature=0, + max_tokens=num_output, + callback_manager=callback_manager, + ) + + llm_predictor = LLMPredictor(llm=llm) + + # These parameters affect how the final synthesized response is segmented. + # The number of refinement iterations in the synthesis process depends + # on whether the length of the segmented output exceeds the max_input_size. 
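+ # Rough illustration (assumed arithmetic, not a llama_index guarantee): with
+ # max_input_size=3500 and num_output=512 below, about 3500 - 512 = 2,988
+ # tokens remain for packing context chunks before a refine round is needed.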
+ prompt_helper = PromptHelper( + max_input_size=3500, + num_output=num_output, + max_chunk_overlap=20 + ) + + model_credentials = LLMBuilder.get_model_credentials( + tenant_id=tenant_id, + model_name='text-embedding-ada-002' + ) + + return ServiceContext.from_defaults( + llm_predictor=llm_predictor, + prompt_helper=prompt_helper, + embed_model=OpenAIEmbedding(**model_credentials), + ) diff --git a/api/core/index/keyword_table/jieba_keyword_table.py b/api/core/index/keyword_table/jieba_keyword_table.py new file mode 100644 index 0000000000..89dcca5802 --- /dev/null +++ b/api/core/index/keyword_table/jieba_keyword_table.py @@ -0,0 +1,159 @@ +import re +from typing import ( + Any, + Dict, + List, + Set, + Optional +) + +import jieba.analyse + +from core.index.keyword_table.stopwords import STOPWORDS +from llama_index.indices.query.base import IS +from llama_index import QueryMode +from llama_index.indices.base import QueryMap +from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex +from llama_index.indices.keyword_table.query import BaseGPTKeywordTableQuery +from llama_index.docstore import BaseDocumentStore +from llama_index.indices.postprocessor.node import ( + BaseNodePostprocessor, +) +from llama_index.indices.response.response_builder import ResponseMode +from llama_index.indices.service_context import ServiceContext +from llama_index.optimization.optimizer import BaseTokenUsageOptimizer +from llama_index.prompts.prompts import ( + QuestionAnswerPrompt, + RefinePrompt, + SimpleInputPrompt, +) + +from core.index.query.synthesizer import EnhanceResponseSynthesizer + + +def jieba_extract_keywords( + text_chunk: str, + max_keywords: Optional[int] = None, + expand_with_subtokens: bool = True, +) -> Set[str]: + """Extract keywords with JIEBA tfidf.""" + keywords = jieba.analyse.extract_tags( + sentence=text_chunk, + topK=max_keywords, + ) + + if expand_with_subtokens: + return set(expand_tokens_with_subtokens(keywords)) + else: + return set(keywords) + + +def expand_tokens_with_subtokens(tokens: Set[str]) -> Set[str]: + """Get subtokens from a list of tokens, filtering for stopwords.""" + results = set() + for token in tokens: + results.add(token) + sub_tokens = re.findall(r"\w+", token) + if len(sub_tokens) > 1: + results.update({w for w in sub_tokens if w not in STOPWORDS}) + + return results + + +class GPTJIEBAKeywordTableIndex(BaseGPTKeywordTableIndex): + """GPT JIEBA Keyword Table Index. + + This index uses a JIEBA keyword extractor to extract keywords from the text. 
+ + """ + + def _extract_keywords(self, text: str) -> Set[str]: + """Extract keywords from text.""" + return jieba_extract_keywords(text, max_keywords=self.max_keywords_per_chunk) + + @classmethod + def get_query_map(self) -> QueryMap: + """Get query map.""" + super_map = super().get_query_map() + super_map[QueryMode.DEFAULT] = GPTKeywordTableJIEBAQuery + return super_map + + def _delete(self, doc_id: str, **delete_kwargs: Any) -> None: + """Delete a document.""" + # get set of ids that correspond to node + node_idxs_to_delete = {doc_id} + + # delete node_idxs from keyword to node idxs mapping + keywords_to_delete = set() + for keyword, node_idxs in self._index_struct.table.items(): + if node_idxs_to_delete.intersection(node_idxs): + self._index_struct.table[keyword] = node_idxs.difference( + node_idxs_to_delete + ) + if not self._index_struct.table[keyword]: + keywords_to_delete.add(keyword) + + for keyword in keywords_to_delete: + del self._index_struct.table[keyword] + + +class GPTKeywordTableJIEBAQuery(BaseGPTKeywordTableQuery): + """GPT Keyword Table Index JIEBA Query. + + Extracts keywords using JIEBA keyword extractor. + Set when `mode="jieba"` in `query` method of `GPTKeywordTableIndex`. + + .. code-block:: python + + response = index.query("", mode="jieba") + + See BaseGPTKeywordTableQuery for arguments. + + """ + + @classmethod + def from_args( + cls, + index_struct: IS, + service_context: ServiceContext, + docstore: Optional[BaseDocumentStore] = None, + node_postprocessors: Optional[List[BaseNodePostprocessor]] = None, + verbose: bool = False, + # response synthesizer args + response_mode: ResponseMode = ResponseMode.DEFAULT, + text_qa_template: Optional[QuestionAnswerPrompt] = None, + refine_template: Optional[RefinePrompt] = None, + simple_template: Optional[SimpleInputPrompt] = None, + response_kwargs: Optional[Dict] = None, + use_async: bool = False, + streaming: bool = False, + optimizer: Optional[BaseTokenUsageOptimizer] = None, + # class-specific args + **kwargs: Any, + ) -> "BaseGPTIndexQuery": + response_synthesizer = EnhanceResponseSynthesizer.from_args( + service_context=service_context, + text_qa_template=text_qa_template, + refine_template=refine_template, + simple_template=simple_template, + response_mode=response_mode, + response_kwargs=response_kwargs, + use_async=use_async, + streaming=streaming, + optimizer=optimizer, + ) + return cls( + index_struct=index_struct, + service_context=service_context, + response_synthesizer=response_synthesizer, + docstore=docstore, + node_postprocessors=node_postprocessors, + verbose=verbose, + **kwargs, + ) + + def _get_keywords(self, query_str: str) -> List[str]: + """Extract keywords.""" + return list( + jieba_extract_keywords(query_str, max_keywords=self.max_keywords_per_query) + ) diff --git a/api/core/index/keyword_table/stopwords.py b/api/core/index/keyword_table/stopwords.py new file mode 100644 index 0000000000..c616a15cf0 --- /dev/null +++ b/api/core/index/keyword_table/stopwords.py @@ -0,0 +1,90 @@ +STOPWORDS = { + "during", "when", "but", "then", "further", "isn", "mustn't", "until", "own", "i", "couldn", "y", "only", "you've", + "ours", "who", "where", "ourselves", "has", "to", "was", "didn't", "themselves", "if", "against", "through", "her", + "an", "your", "can", "those", "didn", "about", "aren't", "shan't", "be", "not", "these", "again", "so", "t", + "theirs", "weren", "won't", "won", "itself", "just", "same", "while", "why", "doesn", "aren", "him", "haven", + "for", "you'll", "that", "we", "am", "d", "by", 
"having", "wasn't", "than", "weren't", "out", "from", "now", + "their", "too", "hadn", "o", "needn", "most", "it", "under", "needn't", "any", "some", "few", "ll", "hers", "which", + "m", "you're", "off", "other", "had", "she", "you'd", "do", "you", "does", "s", "will", "each", "wouldn't", "hasn't", + "such", "more", "whom", "she's", "my", "yours", "yourself", "of", "on", "very", "hadn't", "with", "yourselves", + "been", "ma", "them", "mightn't", "shan", "mustn", "they", "what", "both", "that'll", "how", "is", "he", "because", + "down", "haven't", "are", "no", "it's", "our", "being", "the", "or", "above", "myself", "once", "don't", "doesn't", + "as", "nor", "here", "herself", "hasn", "mightn", "have", "its", "all", "were", "ain", "this", "at", "after", + "over", "shouldn't", "into", "before", "don", "wouldn", "re", "couldn't", "wasn", "in", "should", "there", + "himself", "isn't", "should've", "doing", "ve", "shouldn", "a", "did", "and", "his", "between", "me", "up", "below", + "人民", "末##末", "啊", "阿", "哎", "哎呀", "哎哟", "唉", "俺", "俺们", "按", "按照", "吧", "吧哒", "把", "罢了", "被", "本", + "本着", "比", "比方", "比如", "鄙人", "彼", "彼此", "边", "别", "别的", "别说", "并", "并且", "不比", "不成", "不单", "不但", + "不独", "不管", "不光", "不过", "不仅", "不拘", "不论", "不怕", "不然", "不如", "不特", "不惟", "不问", "不只", "朝", "朝着", + "趁", "趁着", "乘", "冲", "除", "除此之外", "除非", "除了", "此", "此间", "此外", "从", "从而", "打", "待", "但", "但是", "当", + "当着", "到", "得", "的", "的话", "等", "等等", "地", "第", "叮咚", "对", "对于", "多", "多少", "而", "而况", "而且", "而是", + "而外", "而言", "而已", "尔后", "反过来", "反过来说", "反之", "非但", "非徒", "否则", "嘎", "嘎登", "该", "赶", "个", "各", + "各个", "各位", "各种", "各自", "给", "根据", "跟", "故", "故此", "固然", "关于", "管", "归", "果然", "果真", "过", "哈", + "哈哈", "呵", "和", "何", "何处", "何况", "何时", "嘿", "哼", "哼唷", "呼哧", "乎", "哗", "还是", "还有", "换句话说", "换言之", + "或", "或是", "或者", "极了", "及", "及其", "及至", "即", "即便", "即或", "即令", "即若", "即使", "几", "几时", "己", "既", + "既然", "既是", "继而", "加之", "假如", "假若", "假使", "鉴于", "将", "较", "较之", "叫", "接着", "结果", "借", "紧接着", + "进而", "尽", "尽管", "经", "经过", "就", "就是", "就是说", "据", "具体地说", "具体说来", "开始", "开外", "靠", "咳", "可", + "可见", "可是", "可以", "况且", "啦", "来", "来着", "离", "例如", "哩", "连", "连同", "两者", "了", "临", "另", "另外", + "另一方面", "论", "嘛", "吗", "慢说", "漫说", "冒", "么", "每", "每当", "们", "莫若", "某", "某个", "某些", "拿", "哪", + "哪边", "哪儿", "哪个", "哪里", "哪年", "哪怕", "哪天", "哪些", "哪样", "那", "那边", "那儿", "那个", "那会儿", "那里", "那么", + "那么些", "那么样", "那时", "那些", "那样", "乃", "乃至", "呢", "能", "你", "你们", "您", "宁", "宁可", "宁肯", "宁愿", "哦", + "呕", "啪达", "旁人", "呸", "凭", "凭借", "其", "其次", "其二", "其他", "其它", "其一", "其余", "其中", "起", "起见", "岂但", + "恰恰相反", "前后", "前者", "且", "然而", "然后", "然则", "让", "人家", "任", "任何", "任凭", "如", "如此", "如果", "如何", + "如其", "如若", "如上所述", "若", "若非", "若是", "啥", "上下", "尚且", "设若", "设使", "甚而", "甚么", "甚至", "省得", "时候", + "什么", "什么样", "使得", "是", "是的", "首先", "谁", "谁知", "顺", "顺着", "似的", "虽", "虽然", "虽说", "虽则", "随", "随着", + "所", "所以", "他", "他们", "他人", "它", "它们", "她", "她们", "倘", "倘或", "倘然", "倘若", "倘使", "腾", "替", "通过", "同", + "同时", "哇", "万一", "往", "望", "为", "为何", "为了", "为什么", "为着", "喂", "嗡嗡", "我", "我们", "呜", "呜呼", "乌乎", + "无论", "无宁", "毋宁", "嘻", "吓", "相对而言", "像", "向", "向着", "嘘", "呀", "焉", "沿", "沿着", "要", "要不", "要不然", + "要不是", "要么", "要是", "也", "也罢", "也好", "一", "一般", "一旦", "一方面", "一来", "一切", "一样", "一则", "依", "依照", + "矣", "以", "以便", "以及", "以免", "以至", "以至于", "以致", "抑或", "因", "因此", "因而", "因为", "哟", "用", "由", + "由此可见", "由于", "有", "有的", "有关", "有些", "又", "于", "于是", "于是乎", "与", "与此同时", "与否", "与其", "越是", + "云云", "哉", "再说", "再者", "在", "在下", "咱", "咱们", "则", "怎", "怎么", "怎么办", "怎么样", "怎样", "咋", "照", "照着", + "者", "这", "这边", 
"这儿", "这个", "这会儿", "这就是说", "这里", "这么", "这么点儿", "这么些", "这么样", "这时", "这些", "这样", + "正如", "吱", "之", "之类", "之所以", "之一", "只是", "只限", "只要", "只有", "至", "至于", "诸位", "着", "着呢", "自", "自从", + "自个儿", "自各儿", "自己", "自家", "自身", "综上所述", "总的来看", "总的来说", "总的说来", "总而言之", "总之", "纵", "纵令", + "纵然", "纵使", "遵照", "作为", "兮", "呃", "呗", "咚", "咦", "喏", "啐", "喔唷", "嗬", "嗯", "嗳", "~", "!", ".", ":", + "\"", "'", "(", ")", "*", "A", "白", "社会主义", "--", "..", ">>", " [", " ]", "", "<", ">", "/", "\\", "|", "-", "_", + "+", "=", "&", "^", "%", "#", "@", "`", ";", "$", "(", ")", "——", "—", "¥", "·", "...", "‘", "’", "〉", "〈", "…", + " ", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "二", + "三", "四", "五", "六", "七", "八", "九", "零", ">", "<", "@", "#", "$", "%", "︿", "&", "*", "+", "~", "|", "[", + "]", "{", "}", "啊哈", "啊呀", "啊哟", "挨次", "挨个", "挨家挨户", "挨门挨户", "挨门逐户", "挨着", "按理", "按期", "按时", + "按说", "暗地里", "暗中", "暗自", "昂然", "八成", "白白", "半", "梆", "保管", "保险", "饱", "背地里", "背靠背", "倍感", "倍加", + "本人", "本身", "甭", "比起", "比如说", "比照", "毕竟", "必", "必定", "必将", "必须", "便", "别人", "并非", "并肩", "并没", + "并没有", "并排", "并无", "勃然", "不", "不必", "不常", "不大", "不但...而且", "不得", "不得不", "不得了", "不得已", "不迭", + "不定", "不对", "不妨", "不管怎样", "不会", "不仅...而且", "不仅仅", "不仅仅是", "不经意", "不可开交", "不可抗拒", "不力", "不了", + "不料", "不满", "不免", "不能不", "不起", "不巧", "不然的话", "不日", "不少", "不胜", "不时", "不是", "不同", "不能", "不要", + "不外", "不外乎", "不下", "不限", "不消", "不已", "不亦乐乎", "不由得", "不再", "不择手段", "不怎么", "不曾", "不知不觉", "不止", + "不止一次", "不至于", "才", "才能", "策略地", "差不多", "差一点", "常", "常常", "常言道", "常言说", "常言说得好", "长此下去", + "长话短说", "长期以来", "长线", "敞开儿", "彻夜", "陈年", "趁便", "趁机", "趁热", "趁势", "趁早", "成年", "成年累月", "成心", + "乘机", "乘胜", "乘势", "乘隙", "乘虚", "诚然", "迟早", "充分", "充其极", "充其量", "抽冷子", "臭", "初", "出", "出来", "出去", + "除此", "除此而外", "除此以外", "除开", "除去", "除却", "除外", "处处", "川流不息", "传", "传说", "传闻", "串行", "纯", "纯粹", + "此后", "此中", "次第", "匆匆", "从不", "从此", "从此以后", "从古到今", "从古至今", "从今以后", "从宽", "从来", "从轻", "从速", + "从头", "从未", "从无到有", "从小", "从新", "从严", "从优", "从早到晚", "从中", "从重", "凑巧", "粗", "存心", "达旦", "打从", + "打开天窗说亮话", "大", "大不了", "大大", "大抵", "大都", "大多", "大凡", "大概", "大家", "大举", "大略", "大面儿上", "大事", + "大体", "大体上", "大约", "大张旗鼓", "大致", "呆呆地", "带", "殆", "待到", "单", "单纯", "单单", "但愿", "弹指之间", "当场", + "当儿", "当即", "当口儿", "当然", "当庭", "当头", "当下", "当真", "当中", "倒不如", "倒不如说", "倒是", "到处", "到底", "到了儿", + "到目前为止", "到头", "到头来", "得起", "得天独厚", "的确", "等到", "叮当", "顶多", "定", "动不动", "动辄", "陡然", "都", "独", + "独自", "断然", "顿时", "多次", "多多", "多多少少", "多多益善", "多亏", "多年来", "多年前", "而后", "而论", "而又", "尔等", + "二话不说", "二话没说", "反倒", "反倒是", "反而", "反手", "反之亦然", "反之则", "方", "方才", "方能", "放量", "非常", "非得", + "分期", "分期分批", "分头", "奋勇", "愤然", "风雨无阻", "逢", "弗", "甫", "嘎嘎", "该当", "概", "赶快", "赶早不赶晚", "敢", + "敢情", "敢于", "刚", "刚才", "刚好", "刚巧", "高低", "格外", "隔日", "隔夜", "个人", "各式", "更", "更加", "更进一步", "更为", + "公然", "共", "共总", "够瞧的", "姑且", "古来", "故而", "故意", "固", "怪", "怪不得", "惯常", "光", "光是", "归根到底", + "归根结底", "过于", "毫不", "毫无", "毫无保留地", "毫无例外", "好在", "何必", "何尝", "何妨", "何苦", "何乐而不为", "何须", + "何止", "很", "很多", "很少", "轰然", "后来", "呼啦", "忽地", "忽然", "互", "互相", "哗啦", "话说", "还", "恍然", "会", "豁然", + "活", "伙同", "或多或少", "或许", "基本", "基本上", "基于", "极", "极大", "极度", "极端", "极力", "极其", "极为", "急匆匆", + "即将", "即刻", "即是说", "几度", "几番", "几乎", "几经", "既...又", "继之", "加上", "加以", "间或", "简而言之", "简言之", + "简直", "见", "将才", "将近", "将要", "交口", "较比", "较为", "接连不断", "接下来", "皆可", "截然", "截至", "藉以", "借此", + "借以", "届时", "仅", "仅仅", "谨", "进来", "进去", "近", "近几年来", "近来", "近年来", "尽管如此", "尽可能", "尽快", "尽量", + "尽然", "尽如人意", "尽心竭力", "尽心尽力", "尽早", "精光", "经常", "竟", "竟然", "究竟", "就此", "就地", 
"就算", "居然", "局外", + "举凡", "据称", "据此", "据实", "据说", "据我所知", "据悉", "具体来说", "决不", "决非", "绝", "绝不", "绝顶", "绝对", "绝非", + "均", "喀", "看", "看来", "看起来", "看上去", "看样子", "可好", "可能", "恐怕", "快", "快要", "来不及", "来得及", "来讲", + "来看", "拦腰", "牢牢", "老", "老大", "老老实实", "老是", "累次", "累年", "理当", "理该", "理应", "历", "立", "立地", "立刻", + "立马", "立时", "联袂", "连连", "连日", "连日来", "连声", "连袂", "临到", "另方面", "另行", "另一个", "路经", "屡", "屡次", + "屡次三番", "屡屡", "缕缕", "率尔", "率然", "略", "略加", "略微", "略为", "论说", "马上", "蛮", "满", "没", "没有", "每逢", + "每每", "每时每刻", "猛然", "猛然间", "莫", "莫不", "莫非", "莫如", "默默地", "默然", "呐", "那末", "奈", "难道", "难得", "难怪", + "难说", "内", "年复一年", "凝神", "偶而", "偶尔", "怕", "砰", "碰巧", "譬如", "偏偏", "乒", "平素", "颇", "迫于", "扑通", + "其后", "其实", "奇", "齐", "起初", "起来", "起首", "起头", "起先", "岂", "岂非", "岂止", "迄", "恰逢", "恰好", "恰恰", "恰巧", + "恰如", "恰似", "千", "千万", "千万千万", "切", "切不可", "切莫", "切切", "切勿", "窃", "亲口", "亲身", "亲手", "亲眼", "亲自", + "顷", "顷刻", "顷刻间", "顷刻之间", "请勿", "穷年累月", "取道", "去", "权时", "全都", "全力", "全年", "全然", "全身心", "然", + "人人", "仍", "仍旧", "仍然", "日复一日", "日见", "日渐", "日益", "日臻", "如常", "如此等等", "如次", "如今", "如期", "如前所述", + "如上", "如下", "汝", "三番两次", "三番五次", "三天两头", "瑟瑟", "沙沙", "上", "上来", "上去", "一个", "月", "日", "\n" +} diff --git a/api/core/index/keyword_table_index.py b/api/core/index/keyword_table_index.py new file mode 100644 index 0000000000..f0b3905557 --- /dev/null +++ b/api/core/index/keyword_table_index.py @@ -0,0 +1,135 @@ +import json +from typing import List, Optional + +from llama_index import ServiceContext, LLMPredictor, OpenAIEmbedding +from llama_index.data_structs import KeywordTable, Node +from llama_index.indices.keyword_table.base import BaseGPTKeywordTableIndex +from llama_index.indices.registry import load_index_struct_from_dict + +from core.docstore.dataset_docstore import DatesetDocumentStore +from core.docstore.empty_docstore import EmptyDocumentStore +from core.index.index_builder import IndexBuilder +from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex +from core.llm.llm_builder import LLMBuilder +from extensions.ext_database import db +from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment + + +class KeywordTableIndex: + + def __init__(self, dataset: Dataset): + self._dataset = dataset + + def add_nodes(self, nodes: List[Node]): + llm = LLMBuilder.to_llm( + tenant_id=self._dataset.tenant_id, + model_name='fake' + ) + + service_context = ServiceContext.from_defaults( + llm_predictor=LLMPredictor(llm=llm), + embed_model=OpenAIEmbedding() + ) + + dataset_keyword_table = self.get_keyword_table() + if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict: + index_struct = KeywordTable() + else: + index_struct_dict = dataset_keyword_table.keyword_table_dict + index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict) + + # create index + index = GPTJIEBAKeywordTableIndex( + index_struct=index_struct, + docstore=EmptyDocumentStore(), + service_context=service_context + ) + + for node in nodes: + keywords = index._extract_keywords(node.get_text()) + self.update_segment_keywords(node.doc_id, list(keywords)) + index._index_struct.add_node(list(keywords), node) + + index_struct_dict = index.index_struct.to_dict() + + if not dataset_keyword_table: + dataset_keyword_table = DatasetKeywordTable( + dataset_id=self._dataset.id, + keyword_table=json.dumps(index_struct_dict) + ) + db.session.add(dataset_keyword_table) + else: + dataset_keyword_table.keyword_table = json.dumps(index_struct_dict) + + db.session.commit() + + def del_nodes(self, node_ids: 
List[str]): + llm = LLMBuilder.to_llm( + tenant_id=self._dataset.tenant_id, + model_name='fake' + ) + + service_context = ServiceContext.from_defaults( + llm_predictor=LLMPredictor(llm=llm), + embed_model=OpenAIEmbedding() + ) + + dataset_keyword_table = self.get_keyword_table() + if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict: + return + else: + index_struct_dict = dataset_keyword_table.keyword_table_dict + index_struct: KeywordTable = load_index_struct_from_dict(index_struct_dict) + + # create index + index = GPTJIEBAKeywordTableIndex( + index_struct=index_struct, + docstore=EmptyDocumentStore(), + service_context=service_context + ) + + for node_id in node_ids: + index.delete(node_id) + + index_struct_dict = index.index_struct.to_dict() + + if not dataset_keyword_table: + dataset_keyword_table = DatasetKeywordTable( + dataset_id=self._dataset.id, + keyword_table=json.dumps(index_struct_dict) + ) + db.session.add(dataset_keyword_table) + else: + dataset_keyword_table.keyword_table = json.dumps(index_struct_dict) + + db.session.commit() + + @property + def query_index(self) -> Optional[BaseGPTKeywordTableIndex]: + docstore = DatesetDocumentStore( + dataset=self._dataset, + user_id=self._dataset.created_by, + embedding_model_name="text-embedding-ada-002" + ) + + service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id) + + dataset_keyword_table = self.get_keyword_table() + if not dataset_keyword_table or not dataset_keyword_table.keyword_table_dict: + return None + + index_struct: KeywordTable = load_index_struct_from_dict(dataset_keyword_table.keyword_table_dict) + + return GPTJIEBAKeywordTableIndex(index_struct=index_struct, docstore=docstore, service_context=service_context) + + def get_keyword_table(self): + dataset_keyword_table = self._dataset.dataset_keyword_table + if dataset_keyword_table: + return dataset_keyword_table + return None + + def update_segment_keywords(self, node_id: str, keywords: List[str]): + document_segment = db.session.query(DocumentSegment).filter(DocumentSegment.index_node_id == node_id).first() + if document_segment: + document_segment.keywords = keywords + db.session.commit() diff --git a/api/core/index/query/synthesizer.py b/api/core/index/query/synthesizer.py new file mode 100644 index 0000000000..7ab8b4a8ca --- /dev/null +++ b/api/core/index/query/synthesizer.py @@ -0,0 +1,79 @@ +from typing import ( + Any, + Dict, + Optional, Sequence, +) + +from llama_index.indices.response.response_synthesis import ResponseSynthesizer +from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder +from llama_index.indices.service_context import ServiceContext +from llama_index.optimization.optimizer import BaseTokenUsageOptimizer +from llama_index.prompts.prompts import ( + QuestionAnswerPrompt, + RefinePrompt, + SimpleInputPrompt, +) +from llama_index.types import RESPONSE_TEXT_TYPE + + +class EnhanceResponseSynthesizer(ResponseSynthesizer): + @classmethod + def from_args( + cls, + service_context: ServiceContext, + streaming: bool = False, + use_async: bool = False, + text_qa_template: Optional[QuestionAnswerPrompt] = None, + refine_template: Optional[RefinePrompt] = None, + simple_template: Optional[SimpleInputPrompt] = None, + response_mode: ResponseMode = ResponseMode.DEFAULT, + response_kwargs: Optional[Dict] = None, + optimizer: Optional[BaseTokenUsageOptimizer] = None, + ) -> "ResponseSynthesizer": + response_builder: 
Optional[BaseResponseBuilder] = None + if response_mode != ResponseMode.NO_TEXT: + if response_mode == 'no_synthesizer': + response_builder = NoSynthesizer( + service_context=service_context, + simple_template=simple_template, + streaming=streaming, + ) + else: + response_builder = get_response_builder( + service_context, + text_qa_template, + refine_template, + simple_template, + response_mode, + use_async=use_async, + streaming=streaming, + ) + return cls(response_builder, response_mode, response_kwargs, optimizer) + + +class NoSynthesizer(BaseResponseBuilder): + def __init__( + self, + service_context: ServiceContext, + simple_template: Optional[SimpleInputPrompt] = None, + streaming: bool = False, + ) -> None: + super().__init__(service_context, streaming) + + async def aget_response( + self, + query_str: str, + text_chunks: Sequence[str], + prev_response: Optional[str] = None, + **response_kwargs: Any, + ) -> RESPONSE_TEXT_TYPE: + return "\n".join(text_chunks) + + def get_response( + self, + query_str: str, + text_chunks: Sequence[str], + prev_response: Optional[str] = None, + **response_kwargs: Any, + ) -> RESPONSE_TEXT_TYPE: + return "\n".join(text_chunks) \ No newline at end of file diff --git a/api/core/index/readers/html_parser.py b/api/core/index/readers/html_parser.py new file mode 100644 index 0000000000..2afadb284c --- /dev/null +++ b/api/core/index/readers/html_parser.py @@ -0,0 +1,22 @@ +from pathlib import Path +from typing import Dict + +from bs4 import BeautifulSoup +from llama_index.readers.file.base_parser import BaseParser + + +class HTMLParser(BaseParser): + """HTML parser.""" + + def _init_parser(self) -> Dict: + """Init parser.""" + return {} + + def parse_file(self, file: Path, errors: str = "ignore") -> str: + """Parse file.""" + with open(file, "rb") as fp: + soup = BeautifulSoup(fp, 'html.parser') + text = soup.get_text() + text = text.strip() if text else '' + + return text diff --git a/api/core/index/readers/pdf_parser.py b/api/core/index/readers/pdf_parser.py new file mode 100644 index 0000000000..81c4840c60 --- /dev/null +++ b/api/core/index/readers/pdf_parser.py @@ -0,0 +1,56 @@ +from pathlib import Path +from typing import Dict + +from flask import current_app +from llama_index.readers.file.base_parser import BaseParser +from pypdf import PdfReader + +from extensions.ext_storage import storage +from models.model import UploadFile + + +class PDFParser(BaseParser): + """PDF parser.""" + + def _init_parser(self) -> Dict: + """Init parser.""" + return {} + + def parse_file(self, file: Path, errors: str = "ignore") -> str: + """Parse file.""" + if not current_app.config.get('PDF_PREVIEW', True): + return '' + + plaintext_file_key = '' + plaintext_file_exists = False + if self._parser_config and 'upload_file' in self._parser_config and self._parser_config['upload_file']: + upload_file: UploadFile = self._parser_config['upload_file'] + if upload_file.hash: + plaintext_file_key = 'upload_files/' + upload_file.tenant_id + '/' + upload_file.hash + '.plaintext' + try: + text = storage.load(plaintext_file_key).decode('utf-8') + plaintext_file_exists = True + return text + except FileNotFoundError: + pass + + text_list = [] + with open(file, "rb") as fp: + # Create a PDF object + pdf = PdfReader(fp) + + # Get the number of pages in the PDF document + num_pages = len(pdf.pages) + + # Iterate over every page + for page in range(num_pages): + # Extract the text from the page + page_text = pdf.pages[page].extract_text() + text_list.append(page_text) + text = 
"\n".join(text_list) + + # save plaintext file for caching + if not plaintext_file_exists and plaintext_file_key: + storage.save(plaintext_file_key, text.encode('utf-8')) + + return text diff --git a/api/core/index/vector_index.py b/api/core/index/vector_index.py new file mode 100644 index 0000000000..f9d8542a8c --- /dev/null +++ b/api/core/index/vector_index.py @@ -0,0 +1,136 @@ +import json +import logging +from typing import List, Optional + +from llama_index.data_structs import Node +from requests import ReadTimeout +from sqlalchemy.exc import IntegrityError +from tenacity import retry, stop_after_attempt, retry_if_exception_type + +from core.index.index_builder import IndexBuilder +from core.vector_store.base import BaseGPTVectorStoreIndex +from extensions.ext_vector_store import vector_store +from extensions.ext_database import db +from models.dataset import Dataset, Embedding + + +class VectorIndex: + + def __init__(self, dataset: Dataset): + self._dataset = dataset + + def add_nodes(self, nodes: List[Node], duplicate_check: bool = False): + if not self._dataset.index_struct_dict: + index_id = "Vector_index_" + self._dataset.id.replace("-", "_") + self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id)) + db.session.commit() + + service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id) + + index = vector_store.get_index( + service_context=service_context, + index_struct=self._dataset.index_struct_dict + ) + + if duplicate_check: + nodes = self._filter_duplicate_nodes(index, nodes) + + embedding_queue_nodes = [] + embedded_nodes = [] + for node in nodes: + node_hash = node.doc_hash + + # if node hash in cached embedding tables, use cached embedding + embedding = db.session.query(Embedding).filter_by(hash=node_hash).first() + if embedding: + node.embedding = embedding.get_embedding() + embedded_nodes.append(node) + else: + embedding_queue_nodes.append(node) + + if embedding_queue_nodes: + embedding_results = index._get_node_embedding_results( + embedding_queue_nodes, + set(), + ) + + # pre embed nodes for cached embedding + for embedding_result in embedding_results: + node = embedding_result.node + node.embedding = embedding_result.embedding + + try: + embedding = Embedding(hash=node.doc_hash) + embedding.set_embedding(node.embedding) + db.session.add(embedding) + db.session.commit() + except IntegrityError: + db.session.rollback() + continue + except: + logging.exception('Failed to add embedding to db') + continue + + embedded_nodes.append(node) + + self.index_insert_nodes(index, embedded_nodes) + + @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3)) + def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]): + index.insert_nodes(nodes) + + def del_nodes(self, node_ids: List[str]): + if not self._dataset.index_struct_dict: + return + + service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id) + + index = vector_store.get_index( + service_context=service_context, + index_struct=self._dataset.index_struct_dict + ) + + for node_id in node_ids: + self.index_delete_node(index, node_id) + + @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3)) + def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str): + index.delete_node(node_id) + + def del_doc(self, doc_id: str): + if not self._dataset.index_struct_dict: + return + + service_context = 
IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id) + + index = vector_store.get_index( + service_context=service_context, + index_struct=self._dataset.index_struct_dict + ) + + self.index_delete_doc(index, doc_id) + + @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3)) + def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str): + index.delete(doc_id) + + @property + def query_index(self) -> Optional[BaseGPTVectorStoreIndex]: + if not self._dataset.index_struct_dict: + return None + + service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id) + + return vector_store.get_index( + service_context=service_context, + index_struct=self._dataset.index_struct_dict + ) + + def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]: + for node in nodes: + node_id = node.doc_id + exists_duplicate_node = index.exists_by_node_id(node_id) + if exists_duplicate_node: + nodes.remove(node) + + return nodes diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py new file mode 100644 index 0000000000..fd9e430116 --- /dev/null +++ b/api/core/indexing_runner.py @@ -0,0 +1,467 @@ +import datetime +import json +import re +import tempfile +import time +from pathlib import Path +from typing import Optional, List +from langchain.text_splitter import RecursiveCharacterTextSplitter + +from llama_index import SimpleDirectoryReader +from llama_index.data_structs import Node +from llama_index.data_structs.node_v2 import DocumentRelationship +from llama_index.node_parser import SimpleNodeParser, NodeParser +from llama_index.readers.file.base import DEFAULT_FILE_EXTRACTOR +from llama_index.readers.file.markdown_parser import MarkdownParser + +from core.docstore.dataset_docstore import DatesetDocumentStore +from core.index.keyword_table_index import KeywordTableIndex +from core.index.readers.html_parser import HTMLParser +from core.index.readers.pdf_parser import PDFParser +from core.index.vector_index import VectorIndex +from core.llm.token_calculator import TokenCalculator +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from extensions.ext_storage import storage +from models.dataset import Document, Dataset, DocumentSegment, DatasetProcessRule +from models.model import UploadFile + + +class IndexingRunner: + + def __init__(self, embedding_model_name: str = "text-embedding-ada-002"): + self.storage = storage + self.embedding_model_name = embedding_model_name + + def run(self, document: Document): + """Run the indexing process.""" + # get dataset + dataset = Dataset.query.filter_by( + id=document.dataset_id + ).first() + + if not dataset: + raise ValueError("no dataset found") + + # load file + text_docs = self._load_data(document) + + # get the process rule + processing_rule = db.session.query(DatasetProcessRule). \ + filter(DatasetProcessRule.id == document.dataset_process_rule_id). 
\ + first() + + # get node parser for splitting + node_parser = self._get_node_parser(processing_rule) + + # split to nodes + nodes = self._step_split( + text_docs=text_docs, + node_parser=node_parser, + dataset=dataset, + document=document, + processing_rule=processing_rule + ) + + # build index + self._build_index( + dataset=dataset, + document=document, + nodes=nodes + ) + + def run_in_splitting_status(self, document: Document): + """Run the indexing process when the index_status is splitting.""" + # get dataset + dataset = Dataset.query.filter_by( + id=document.dataset_id + ).first() + + if not dataset: + raise ValueError("no dataset found") + + # get exist document_segment list and delete + document_segments = DocumentSegment.query.filter_by( + dataset_id=dataset.id, + document_id=document.id + ).all() + db.session.delete(document_segments) + db.session.commit() + # load file + text_docs = self._load_data(document) + + # get the process rule + processing_rule = db.session.query(DatasetProcessRule). \ + filter(DatasetProcessRule.id == document.dataset_process_rule_id). \ + first() + + # get node parser for splitting + node_parser = self._get_node_parser(processing_rule) + + # split to nodes + nodes = self._step_split( + text_docs=text_docs, + node_parser=node_parser, + dataset=dataset, + document=document, + processing_rule=processing_rule + ) + + # build index + self._build_index( + dataset=dataset, + document=document, + nodes=nodes + ) + + def run_in_indexing_status(self, document: Document): + """Run the indexing process when the index_status is indexing.""" + # get dataset + dataset = Dataset.query.filter_by( + id=document.dataset_id + ).first() + + if not dataset: + raise ValueError("no dataset found") + + # get exist document_segment list and delete + document_segments = DocumentSegment.query.filter_by( + dataset_id=dataset.id, + document_id=document.id + ).all() + nodes = [] + if document_segments: + for document_segment in document_segments: + # transform segment to node + if document_segment.status != "completed": + relationships = { + DocumentRelationship.SOURCE: document_segment.document_id, + } + + previous_segment = document_segment.previous_segment + if previous_segment: + relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id + + next_segment = document_segment.next_segment + if next_segment: + relationships[DocumentRelationship.NEXT] = next_segment.index_node_id + node = Node( + doc_id=document_segment.index_node_id, + doc_hash=document_segment.index_node_hash, + text=document_segment.content, + extra_info=None, + node_info=None, + relationships=relationships + ) + nodes.append(node) + + # build index + self._build_index( + dataset=dataset, + document=document, + nodes=nodes + ) + + def indexing_estimate(self, file_detail: UploadFile, tmp_processing_rule: dict) -> dict: + """ + Estimate the indexing for the document. 
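For orientation, a hypothetical call to this estimator (the `upload_file` would normally be a `models.model.UploadFile` row fetched elsewhere; all rule values are illustrative and mirror the keys read by `_get_node_parser` and `_document_clean` below):

```python
from core.indexing_runner import IndexingRunner

runner = IndexingRunner(embedding_model_name="text-embedding-ada-002")

estimate = runner.indexing_estimate(
    file_detail=upload_file,  # an UploadFile row, assumed to exist
    tmp_processing_rule={
        "mode": "custom",
        "rules": {
            "pre_processing_rules": [
                {"id": "remove_extra_spaces", "enabled": True},
                {"id": "remove_urls_emails", "enabled": False},
            ],
            "segmentation": {"separator": "\n", "max_tokens": 500},
        },
    },
)
# Shape of the result (values illustrative):
# {"total_segments": 42, "tokens": 18000, "total_price": "0.00720",
#  "currency": "USD", "preview": ["first segment text", ...]}
```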
+ """ + # load data from file + text_docs = self._load_data_from_file(file_detail) + + processing_rule = DatasetProcessRule( + mode=tmp_processing_rule["mode"], + rules=json.dumps(tmp_processing_rule["rules"]) + ) + + # get node parser for splitting + node_parser = self._get_node_parser(processing_rule) + + # split to nodes + nodes = self._split_to_nodes( + text_docs=text_docs, + node_parser=node_parser, + processing_rule=processing_rule + ) + + tokens = 0 + preview_texts = [] + for node in nodes: + if len(preview_texts) < 5: + preview_texts.append(node.get_text()) + + tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) + + return { + "total_segments": len(nodes), + "tokens": tokens, + "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)), + "currency": TokenCalculator.get_currency(self.embedding_model_name), + "preview": preview_texts + } + + def _load_data(self, document: Document) -> List[Document]: + # load file + if document.data_source_type != "upload_file": + return [] + + data_source_info = document.data_source_info_dict + if not data_source_info or 'upload_file_id' not in data_source_info: + raise ValueError("no upload file found") + + file_detail = db.session.query(UploadFile). \ + filter(UploadFile.id == data_source_info['upload_file_id']). \ + one_or_none() + + text_docs = self._load_data_from_file(file_detail) + + # update document status to splitting + self._update_document_index_status( + document_id=document.id, + after_indexing_status="splitting", + extra_update_params={ + Document.file_id: file_detail.id, + Document.word_count: sum([len(text_doc.text) for text_doc in text_docs]), + Document.parsing_completed_at: datetime.datetime.utcnow() + } + ) + + # replace doc id to document model id + for text_doc in text_docs: + # remove invalid symbol + text_doc.text = self.filter_string(text_doc.get_text()) + text_doc.doc_id = document.id + + return text_docs + + def filter_string(self, text): + pattern = re.compile('[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\x80-\xFF]') + return pattern.sub('', text) + + def _load_data_from_file(self, upload_file: UploadFile) -> List[Document]: + with tempfile.TemporaryDirectory() as temp_dir: + suffix = Path(upload_file.key).suffix + filepath = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" + self.storage.download(upload_file.key, filepath) + + file_extractor = DEFAULT_FILE_EXTRACTOR.copy() + file_extractor[".markdown"] = MarkdownParser() + file_extractor[".html"] = HTMLParser() + file_extractor[".htm"] = HTMLParser() + file_extractor[".pdf"] = PDFParser({'upload_file': upload_file}) + + loader = SimpleDirectoryReader(input_files=[filepath], file_extractor=file_extractor) + text_docs = loader.load_data() + + return text_docs + + def _get_node_parser(self, processing_rule: DatasetProcessRule) -> NodeParser: + """ + Get the NodeParser object according to the processing rule. 
+ """ + if processing_rule.mode == "custom": + # The user-defined segmentation rule + rules = json.loads(processing_rule.rules) + segmentation = rules["segmentation"] + if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000: + raise ValueError("Custom segment length should be between 50 and 1000.") + + separator = segmentation["separator"] + if not separator: + separators = ["\n\n", "。", ".", " ", ""] + else: + separator = separator.replace('\\n', '\n') + separators = [separator, ""] + + character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( + chunk_size=segmentation["max_tokens"], + chunk_overlap=0, + separators=separators + ) + else: + # Automatic segmentation + character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( + chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], + chunk_overlap=0, + separators=["\n\n", "。", ".", " ", ""] + ) + + return SimpleNodeParser(text_splitter=character_splitter, include_extra_info=True) + + def _step_split(self, text_docs: List[Document], node_parser: NodeParser, + dataset: Dataset, document: Document, processing_rule: DatasetProcessRule) -> List[Node]: + """ + Split the text documents into nodes and save them to the document segment. + """ + nodes = self._split_to_nodes( + text_docs=text_docs, + node_parser=node_parser, + processing_rule=processing_rule + ) + + # save node to document segment + doc_store = DatesetDocumentStore( + dataset=dataset, + user_id=document.created_by, + embedding_model_name=self.embedding_model_name, + document_id=document.id + ) + + doc_store.add_documents(nodes) + + # update document status to indexing + cur_time = datetime.datetime.utcnow() + self._update_document_index_status( + document_id=document.id, + after_indexing_status="indexing", + extra_update_params={ + Document.cleaning_completed_at: cur_time, + Document.splitting_completed_at: cur_time, + } + ) + + # update segment status to indexing + self._update_segments_by_document( + document_id=document.id, + update_params={ + DocumentSegment.status: "indexing", + DocumentSegment.indexing_at: datetime.datetime.utcnow() + } + ) + + return nodes + + def _split_to_nodes(self, text_docs: List[Document], node_parser: NodeParser, + processing_rule: DatasetProcessRule) -> List[Node]: + """ + Split the text documents into nodes. + """ + all_nodes = [] + for text_doc in text_docs: + # document clean + document_text = self._document_clean(text_doc.get_text(), processing_rule) + text_doc.text = document_text + + # parse document to nodes + nodes = node_parser.get_nodes_from_documents([text_doc]) + + all_nodes.extend(nodes) + + return all_nodes + + def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str: + """ + Clean the document text according to the processing rules. 
+ """ + if processing_rule.mode == "automatic": + rules = DatasetProcessRule.AUTOMATIC_RULES + else: + rules = json.loads(processing_rule.rules) if processing_rule.rules else {} + + if 'pre_processing_rules' in rules: + pre_processing_rules = rules["pre_processing_rules"] + for pre_processing_rule in pre_processing_rules: + if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True: + # Remove extra spaces + pattern = r'\n{3,}' + text = re.sub(pattern, '\n\n', text) + pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}' + text = re.sub(pattern, ' ', text) + elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True: + # Remove email + pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)' + text = re.sub(pattern, '', text) + + # Remove URL + pattern = r'https?://[^\s]+' + text = re.sub(pattern, '', text) + + return text + + def _build_index(self, dataset: Dataset, document: Document, nodes: List[Node]) -> None: + """ + Build the index for the document. + """ + vector_index = VectorIndex(dataset=dataset) + keyword_table_index = KeywordTableIndex(dataset=dataset) + + # chunk nodes by chunk size + indexing_start_at = time.perf_counter() + tokens = 0 + chunk_size = 100 + for i in range(0, len(nodes), chunk_size): + # check document is paused + self._check_document_paused_status(document.id) + chunk_nodes = nodes[i:i + chunk_size] + + tokens += sum( + TokenCalculator.get_num_tokens(self.embedding_model_name, node.get_text()) for node in chunk_nodes + ) + + # save vector index + if dataset.indexing_technique == "high_quality": + vector_index.add_nodes(chunk_nodes) + + # save keyword index + keyword_table_index.add_nodes(chunk_nodes) + + node_ids = [node.doc_id for node in chunk_nodes] + db.session.query(DocumentSegment).filter( + DocumentSegment.document_id == document.id, + DocumentSegment.index_node_id.in_(node_ids), + DocumentSegment.status == "indexing" + ).update({ + DocumentSegment.status: "completed", + DocumentSegment.completed_at: datetime.datetime.utcnow() + }) + + db.session.commit() + + indexing_end_at = time.perf_counter() + + # update document status to completed + self._update_document_index_status( + document_id=document.id, + after_indexing_status="completed", + extra_update_params={ + Document.tokens: tokens, + Document.completed_at: datetime.datetime.utcnow(), + Document.indexing_latency: indexing_end_at - indexing_start_at, + } + ) + + def _check_document_paused_status(self, document_id: str): + indexing_cache_key = 'document_{}_is_paused'.format(document_id) + result = redis_client.get(indexing_cache_key) + if result: + raise DocumentIsPausedException() + + def _update_document_index_status(self, document_id: str, after_indexing_status: str, + extra_update_params: Optional[dict] = None) -> None: + """ + Update the document indexing status. + """ + count = Document.query.filter_by(id=document_id, is_paused=True).count() + if count > 0: + raise DocumentIsPausedException() + + update_params = { + Document.indexing_status: after_indexing_status + } + + if extra_update_params: + update_params.update(extra_update_params) + + Document.query.filter_by(id=document_id).update(update_params) + db.session.commit() + + def _update_segments_by_document(self, document_id: str, update_params: dict) -> None: + """ + Update the document segment by document id. 
+ """ + DocumentSegment.query.filter_by(document_id=document_id).update(update_params) + db.session.commit() + + +class DocumentIsPausedException(Exception): + pass diff --git a/api/core/llm/error.py b/api/core/llm/error.py new file mode 100644 index 0000000000..883d282e8a --- /dev/null +++ b/api/core/llm/error.py @@ -0,0 +1,55 @@ +from typing import Optional + + +class LLMError(Exception): + """Base class for all LLM exceptions.""" + description: Optional[str] = None + + def __init__(self, description: Optional[str] = None) -> None: + self.description = description + + +class LLMBadRequestError(LLMError): + """Raised when the LLM returns bad request.""" + description = "Bad Request" + + +class LLMAPIConnectionError(LLMError): + """Raised when the LLM returns API connection error.""" + description = "API Connection Error" + + +class LLMAPIUnavailableError(LLMError): + """Raised when the LLM returns API unavailable error.""" + description = "API Unavailable Error" + + +class LLMRateLimitError(LLMError): + """Raised when the LLM returns rate limit error.""" + description = "Rate Limit Error" + + +class LLMAuthorizationError(LLMError): + """Raised when the LLM returns authorization error.""" + description = "Authorization Error" + + +class ProviderTokenNotInitError(Exception): + """ + Custom exception raised when the provider token is not initialized. + """ + description = "Provider Token Not Init" + + +class QuotaExceededError(Exception): + """ + Custom exception raised when the quota for a provider has been exceeded. + """ + description = "Quota Exceeded" + + +class ModelCurrentlyNotSupportError(Exception): + """ + Custom exception raised when the model not support + """ + description = "Model Currently Not Support" diff --git a/api/core/llm/error_handle_wraps.py b/api/core/llm/error_handle_wraps.py new file mode 100644 index 0000000000..ae9a0278bb --- /dev/null +++ b/api/core/llm/error_handle_wraps.py @@ -0,0 +1,51 @@ +import logging +from functools import wraps + +import openai + +from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \ + LLMBadRequestError + + +def handle_llm_exceptions(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except openai.error.InvalidRequestError as e: + logging.exception("Invalid request to OpenAI API.") + raise LLMBadRequestError(str(e)) + except openai.error.APIConnectionError as e: + logging.exception("Failed to connect to OpenAI API.") + raise LLMAPIConnectionError(str(e)) + except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e: + logging.exception("OpenAI service unavailable.") + raise LLMAPIUnavailableError(str(e)) + except openai.error.RateLimitError as e: + raise LLMRateLimitError(str(e)) + except openai.error.AuthenticationError as e: + raise LLMAuthorizationError(str(e)) + + return wrapper + + +def handle_llm_exceptions_async(func): + @wraps(func) + async def wrapper(*args, **kwargs): + try: + return await func(*args, **kwargs) + except openai.error.InvalidRequestError as e: + logging.exception("Invalid request to OpenAI API.") + raise LLMBadRequestError(str(e)) + except openai.error.APIConnectionError as e: + logging.exception("Failed to connect to OpenAI API.") + raise LLMAPIConnectionError(str(e)) + except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e: + logging.exception("OpenAI service unavailable.") + raise LLMAPIUnavailableError(str(e)) + except 
openai.error.RateLimitError as e: + raise LLMRateLimitError(str(e)) + except openai.error.AuthenticationError as e: + raise LLMAuthorizationError(str(e)) + + return wrapper diff --git a/api/core/llm/llm_builder.py b/api/core/llm/llm_builder.py new file mode 100644 index 0000000000..4355593c5d --- /dev/null +++ b/api/core/llm/llm_builder.py @@ -0,0 +1,103 @@ +from typing import Union, Optional + +from langchain.callbacks import CallbackManager +from langchain.llms.fake import FakeListLLM + +from core.constant import llm_constant +from core.llm.provider.llm_provider_service import LLMProviderService +from core.llm.streamable_chat_open_ai import StreamableChatOpenAI +from core.llm.streamable_open_ai import StreamableOpenAI + + +class LLMBuilder: + """ + This class handles the following logic: + 1. For providers with the name 'OpenAI', the OPENAI_API_KEY value is stored directly in encrypted_config. + 2. For providers with the name 'Azure OpenAI', encrypted_config stores the serialized values of four fields, as shown below: + OPENAI_API_TYPE=azure + OPENAI_API_VERSION=2022-12-01 + OPENAI_API_BASE=https://your-resource-name.openai.azure.com + OPENAI_API_KEY= + 3. For providers with the name 'Anthropic', the ANTHROPIC_API_KEY value is stored directly in encrypted_config. + 4. For providers with the name 'Cohere', the COHERE_API_KEY value is stored directly in encrypted_config. + 5. For providers with the name 'HUGGINGFACEHUB', the HUGGINGFACEHUB_API_KEY value is stored directly in encrypted_config. + 6. Providers with the provider_type 'CUSTOM' can be created through the admin interface, while 'System' providers cannot be created through the admin interface. + 7. If both CUSTOM and System providers exist in the records, the CUSTOM provider is preferred by default, but this preference can be changed via an input parameter. + 8. For providers with the provider_type 'System', the quota_used must not exceed quota_limit. If the quota is exceeded, the provider cannot be used. Currently, only the TRIAL quota_type is supported, which is permanently non-resetting. 
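A hypothetical use of the builder (the tenant id is illustrative, and the model name is assumed to be registered in `llm_constant.models_by_mode` as a chat-mode model):

```python
from core.llm.llm_builder import LLMBuilder

llm = LLMBuilder.to_llm(
    tenant_id="tenant-123",       # illustrative
    model_name="gpt-3.5-turbo",   # assumed to be a registered chat model
    temperature=0.7,
    max_tokens=512,
    streaming=False,
)
# -> StreamableChatOpenAI for chat models, StreamableOpenAI for
#    completion models, FakeListLLM for the special name 'fake'.
```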
+ """ + + @classmethod + def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI, FakeListLLM]: + if model_name == 'fake': + return FakeListLLM(responses=[]) + + mode = cls.get_mode_by_model(model_name) + if mode == 'chat': + # llm_cls = StreamableAzureChatOpenAI + llm_cls = StreamableChatOpenAI + elif mode == 'completion': + llm_cls = StreamableOpenAI + else: + raise ValueError(f"model name {model_name} is not supported.") + + model_credentials = cls.get_model_credentials(tenant_id, model_name) + + return llm_cls( + model_name=model_name, + temperature=kwargs.get('temperature', 0), + max_tokens=kwargs.get('max_tokens', 256), + top_p=kwargs.get('top_p', 1), + frequency_penalty=kwargs.get('frequency_penalty', 0), + presence_penalty=kwargs.get('presence_penalty', 0), + callback_manager=kwargs.get('callback_manager', None), + streaming=kwargs.get('streaming', False), + # request_timeout=None + **model_credentials + ) + + @classmethod + def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False, + callback_manager: Optional[CallbackManager] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]: + model_name = model.get("name") + completion_params = model.get("completion_params", {}) + + return cls.to_llm( + tenant_id=tenant_id, + model_name=model_name, + temperature=completion_params.get('temperature', 0), + max_tokens=completion_params.get('max_tokens', 256), + top_p=completion_params.get('top_p', 0), + frequency_penalty=completion_params.get('frequency_penalty', 0.1), + presence_penalty=completion_params.get('presence_penalty', 0.1), + streaming=streaming, + callback_manager=callback_manager + ) + + @classmethod + def get_mode_by_model(cls, model_name: str) -> str: + if not model_name: + raise ValueError(f"empty model name is not supported.") + + if model_name in llm_constant.models_by_mode['chat']: + return "chat" + elif model_name in llm_constant.models_by_mode['completion']: + return "completion" + else: + raise ValueError(f"model name {model_name} is not supported.") + + @classmethod + def get_model_credentials(cls, tenant_id: str, model_name: str) -> dict: + """ + Returns the API credentials for the given tenant_id and model_name, based on the model's provider. + Raises an exception if the model_name is not found or if the provider is not found. 
+ """ + if not model_name: + raise Exception('model name not found') + + if model_name not in llm_constant.models: + raise Exception('model {} not found'.format(model_name)) + + model_provider = llm_constant.models[model_name] + + provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider) + return provider_service.get_credentials(model_name) diff --git a/api/core/llm/moderation.py b/api/core/llm/moderation.py new file mode 100644 index 0000000000..d18d6fc5c2 --- /dev/null +++ b/api/core/llm/moderation.py @@ -0,0 +1,15 @@ +import openai +from models.provider import ProviderName + + +class Moderation: + + def __init__(self, provider: str, api_key: str): + self.provider = provider + self.api_key = api_key + + if self.provider == ProviderName.OPENAI.value: + self.client = openai.Moderation + + def moderate(self, text): + return self.client.create(input=text, api_key=self.api_key) diff --git a/api/core/llm/provider/anthropic_provider.py b/api/core/llm/provider/anthropic_provider.py new file mode 100644 index 0000000000..4c7756305e --- /dev/null +++ b/api/core/llm/provider/anthropic_provider.py @@ -0,0 +1,23 @@ +from typing import Optional + +from core.llm.provider.base import BaseProvider +from models.provider import ProviderName + + +class AnthropicProvider(BaseProvider): + def get_models(self, model_id: Optional[str] = None) -> list[dict]: + credentials = self.get_credentials(model_id) + # todo + return [] + + def get_credentials(self, model_id: Optional[str] = None) -> dict: + """ + Returns the API credentials for Azure OpenAI as a dictionary, for the given tenant_id. + The dictionary contains keys: azure_api_type, azure_api_version, azure_api_base, and azure_api_key. + """ + return { + 'anthropic_api_key': self.get_provider_api_key(model_id=model_id) + } + + def get_provider_name(self): + return ProviderName.ANTHROPIC \ No newline at end of file diff --git a/api/core/llm/provider/azure_provider.py b/api/core/llm/provider/azure_provider.py new file mode 100644 index 0000000000..e0ba0d0734 --- /dev/null +++ b/api/core/llm/provider/azure_provider.py @@ -0,0 +1,105 @@ +import json +from typing import Optional, Union + +import requests + +from core.llm.provider.base import BaseProvider +from models.provider import ProviderName + + +class AzureProvider(BaseProvider): + def get_models(self, model_id: Optional[str] = None) -> list[dict]: + credentials = self.get_credentials(model_id) + url = "{}/openai/deployments?api-version={}".format( + credentials.get('openai_api_base'), + credentials.get('openai_api_version') + ) + + headers = { + "api-key": credentials.get('openai_api_key'), + "content-type": "application/json; charset=utf-8" + } + + response = requests.get(url, headers=headers) + + if response.status_code == 200: + result = response.json() + return [{ + 'id': deployment['id'], + 'name': '{} ({})'.format(deployment['id'], deployment['model']) + } for deployment in result['data'] if deployment['status'] == 'succeeded'] + else: + # TODO: optimize in future + raise Exception('Failed to get deployments from Azure OpenAI. Status code: {}'.format(response.status_code)) + + def get_credentials(self, model_id: Optional[str] = None) -> dict: + """ + Returns the API credentials for Azure OpenAI as a dictionary. 
+ """ + encrypted_config = self.get_provider_api_key(model_id=model_id) + config = json.loads(encrypted_config) + config['openai_api_type'] = 'azure' + config['deployment_name'] = model_id + return config + + def get_provider_name(self): + return ProviderName.AZURE_OPENAI + + def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]: + """ + Returns the provider configs. + """ + try: + config = self.get_provider_api_key() + config = json.loads(config) + except: + config = { + 'openai_api_type': 'azure', + 'openai_api_version': '2023-03-15-preview', + 'openai_api_base': 'https://foo.microsoft.com/bar', + 'openai_api_key': '' + } + + if obfuscated: + if not config.get('openai_api_key'): + config = { + 'openai_api_type': 'azure', + 'openai_api_version': '2023-03-15-preview', + 'openai_api_base': 'https://foo.microsoft.com/bar', + 'openai_api_key': '' + } + + config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key')) + return config + + return config + + def get_token_type(self): + # TODO: change to dict when implemented + return lambda value: value + + def config_validate(self, config: Union[dict | str]): + """ + Validates the given config. + """ + # TODO: implement + pass + + def get_encrypted_token(self, config: Union[dict | str]): + """ + Returns the encrypted token. + """ + return json.dumps({ + 'openai_api_type': 'azure', + 'openai_api_version': '2023-03-15-preview', + 'openai_api_base': config['openai_api_base'], + 'openai_api_key': self.encrypt_token(config['openai_api_key']) + }) + + def get_decrypted_token(self, token: str): + """ + Returns the decrypted token. + """ + config = json.loads(token) + config['openai_api_key'] = self.decrypt_token(config['openai_api_key']) + return config diff --git a/api/core/llm/provider/base.py b/api/core/llm/provider/base.py new file mode 100644 index 0000000000..89343ff62a --- /dev/null +++ b/api/core/llm/provider/base.py @@ -0,0 +1,124 @@ +import base64 +from abc import ABC, abstractmethod +from typing import Optional, Union + +from core import hosted_llm_credentials +from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError +from extensions.ext_database import db +from libs import rsa +from models.account import Tenant +from models.provider import Provider, ProviderType, ProviderName + + +class BaseProvider(ABC): + def __init__(self, tenant_id: str): + self.tenant_id = tenant_id + + def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> str: + """ + Returns the decrypted API key for the given tenant_id and provider_name. + If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError. + If the provider is not found or not valid, raises a ProviderTokenNotInitError. + """ + provider = self.get_provider(prefer_custom) + if not provider: + raise ProviderTokenNotInitError() + + if provider.provider_type == ProviderType.SYSTEM.value: + quota_used = provider.quota_used if provider.quota_used is not None else 0 + quota_limit = provider.quota_limit if provider.quota_limit is not None else 0 + + if model_id and model_id == 'gpt-4': + raise ModelCurrentlyNotSupportError() + + if quota_used >= quota_limit: + raise QuotaExceededError() + + return self.get_hosted_credentials() + else: + return self.get_decrypted_token(provider.encrypted_config) + + def get_provider(self, prefer_custom: bool) -> Optional[Provider]: + """ + Returns the Provider instance for the given tenant_id and provider_name. 
+ If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. + """ + providers = db.session.query(Provider).filter( + Provider.tenant_id == self.tenant_id, + Provider.provider_name == self.get_provider_name().value + ).order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all() + + custom_provider = None + system_provider = None + + for provider in providers: + if provider.provider_type == ProviderType.CUSTOM.value: + custom_provider = provider + elif provider.provider_type == ProviderType.SYSTEM.value: + system_provider = provider + + if custom_provider and custom_provider.is_valid and custom_provider.encrypted_config: + return custom_provider + elif system_provider and system_provider.is_valid: + return system_provider + else: + return None + + def get_hosted_credentials(self) -> str: + if self.get_provider_name() != ProviderName.OPENAI: + raise ProviderTokenNotInitError() + + if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key: + raise ProviderTokenNotInitError() + + return hosted_llm_credentials.openai.api_key + + def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]: + """ + Returns the provider configs. + """ + try: + config = self.get_provider_api_key() + except: + config = 'THIS-IS-A-MOCK-TOKEN' + + if obfuscated: + return self.obfuscated_token(config) + + return config + + def obfuscated_token(self, token: str): + return token[:6] + '*' * (len(token) - 8) + token[-2:] + + def get_token_type(self): + return str + + def get_encrypted_token(self, config: Union[dict | str]): + return self.encrypt_token(config) + + def get_decrypted_token(self, token: str): + return self.decrypt_token(token) + + def encrypt_token(self, token): + tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) + return base64.b64encode(encrypted_token).decode() + + def decrypt_token(self, token): + return rsa.decrypt(base64.b64decode(token), self.tenant_id) + + @abstractmethod + def get_provider_name(self): + raise NotImplementedError + + @abstractmethod + def get_credentials(self, model_id: Optional[str] = None) -> dict: + raise NotImplementedError + + @abstractmethod + def get_models(self, model_id: Optional[str] = None) -> list[dict]: + raise NotImplementedError + + @abstractmethod + def config_validate(self, config: str): + raise NotImplementedError diff --git a/api/core/llm/provider/errors.py b/api/core/llm/provider/errors.py new file mode 100644 index 0000000000..407b7f7906 --- /dev/null +++ b/api/core/llm/provider/errors.py @@ -0,0 +1,2 @@ +class ValidateFailedError(Exception): + description = "Provider Validate failed" diff --git a/api/core/llm/provider/huggingface_provider.py b/api/core/llm/provider/huggingface_provider.py new file mode 100644 index 0000000000..b3dd3ed573 --- /dev/null +++ b/api/core/llm/provider/huggingface_provider.py @@ -0,0 +1,22 @@ +from typing import Optional + +from core.llm.provider.base import BaseProvider +from models.provider import ProviderName + + +class HuggingfaceProvider(BaseProvider): + def get_models(self, model_id: Optional[str] = None) -> list[dict]: + credentials = self.get_credentials(model_id) + # todo + return [] + + def get_credentials(self, model_id: Optional[str] = None) -> dict: + """ + Returns the API credentials for Huggingface as a dictionary, for the given tenant_id. 
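Stepping back to the `BaseProvider` helpers above, credential storage is a per-tenant RSA round trip; a hedged sketch (tenant id and key material are illustrative):

```python
from core.llm.provider.openai_provider import OpenAIProvider

# Any concrete provider (e.g. OpenAIProvider, defined further down)
# inherits these helpers from BaseProvider.
provider = OpenAIProvider(tenant_id="tenant-123")

token = provider.encrypt_token("sk-abcdefghijklmnop")  # RSA + base64
assert provider.decrypt_token(token) == "sk-abcdefghijklmnop"

# Obfuscation keeps the first 6 and last 2 characters:
provider.obfuscated_token("sk-abcdefghijklmnop")  # 'sk-abc***********op'
```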
+ """ + return { + 'huggingface_api_key': self.get_provider_api_key(model_id=model_id) + } + + def get_provider_name(self): + return ProviderName.HUGGINGFACEHUB \ No newline at end of file diff --git a/api/core/llm/provider/llm_provider_service.py b/api/core/llm/provider/llm_provider_service.py new file mode 100644 index 0000000000..ca4f8bec6d --- /dev/null +++ b/api/core/llm/provider/llm_provider_service.py @@ -0,0 +1,53 @@ +from typing import Optional, Union + +from core.llm.provider.anthropic_provider import AnthropicProvider +from core.llm.provider.azure_provider import AzureProvider +from core.llm.provider.base import BaseProvider +from core.llm.provider.huggingface_provider import HuggingfaceProvider +from core.llm.provider.openai_provider import OpenAIProvider +from models.provider import Provider + + +class LLMProviderService: + + def __init__(self, tenant_id: str, provider_name: str): + self.provider = self.init_provider(tenant_id, provider_name) + + def init_provider(self, tenant_id: str, provider_name: str) -> BaseProvider: + if provider_name == 'openai': + return OpenAIProvider(tenant_id) + elif provider_name == 'azure_openai': + return AzureProvider(tenant_id) + elif provider_name == 'anthropic': + return AnthropicProvider(tenant_id) + elif provider_name == 'huggingface': + return HuggingfaceProvider(tenant_id) + else: + raise Exception('provider {} not found'.format(provider_name)) + + def get_models(self, model_id: Optional[str] = None) -> list[dict]: + return self.provider.get_models(model_id) + + def get_credentials(self, model_id: Optional[str] = None) -> dict: + return self.provider.get_credentials(model_id) + + def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]: + return self.provider.get_provider_configs(obfuscated) + + def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]: + return self.provider.get_provider(prefer_custom) + + def config_validate(self, config: Union[dict | str]): + """ + Validates the given config. + + :param config: + :raises: ValidateFailedError + """ + return self.provider.config_validate(config) + + def get_token_type(self): + return self.provider.get_token_type() + + def get_encrypted_token(self, config: Union[dict | str]): + return self.provider.get_encrypted_token(config) diff --git a/api/core/llm/provider/openai_provider.py b/api/core/llm/provider/openai_provider.py new file mode 100644 index 0000000000..8257ad3aab --- /dev/null +++ b/api/core/llm/provider/openai_provider.py @@ -0,0 +1,44 @@ +import logging +from typing import Optional, Union + +import openai +from openai.error import AuthenticationError, OpenAIError + +from core.llm.moderation import Moderation +from core.llm.provider.base import BaseProvider +from core.llm.provider.errors import ValidateFailedError +from models.provider import ProviderName + + +class OpenAIProvider(BaseProvider): + def get_models(self, model_id: Optional[str] = None) -> list[dict]: + credentials = self.get_credentials(model_id) + response = openai.Model.list(**credentials) + + return [{ + 'id': model['id'], + 'name': model['id'], + } for model in response['data']] + + def get_credentials(self, model_id: Optional[str] = None) -> dict: + """ + Returns the credentials for the given tenant_id and provider_name. + """ + return { + 'openai_api_key': self.get_provider_api_key(model_id=model_id) + } + + def get_provider_name(self): + return ProviderName.OPENAI + + def config_validate(self, config: Union[dict | str]): + """ + Validates the given config. 
+ """ + try: + Moderation(self.get_provider_name().value, config).moderate('test') + except (AuthenticationError, OpenAIError) as ex: + raise ValidateFailedError(str(ex)) + except Exception as ex: + logging.exception('OpenAI config validation failed') + raise ex diff --git a/api/core/llm/streamable_azure_chat_open_ai.py b/api/core/llm/streamable_azure_chat_open_ai.py new file mode 100644 index 0000000000..539ce92774 --- /dev/null +++ b/api/core/llm/streamable_azure_chat_open_ai.py @@ -0,0 +1,89 @@ +import requests +from langchain.schema import BaseMessage, ChatResult, LLMResult +from langchain.chat_models import AzureChatOpenAI +from typing import Optional, List + +from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async + + +class StreamableAzureChatOpenAI(AzureChatOpenAI): + def get_messages_tokens(self, messages: List[BaseMessage]) -> int: + """Get the number of tokens in a list of messages. + + Args: + messages: The messages to count the tokens of. + + Returns: + The number of tokens in the messages. + """ + tokens_per_message = 5 + tokens_per_request = 3 + + message_tokens = tokens_per_request + message_strs = '' + for message in messages: + message_strs += message.content + message_tokens += tokens_per_message + + # calc once + message_tokens += self.get_num_tokens(message_strs) + + return message_tokens + + def _generate( + self, messages: List[BaseMessage], stop: Optional[List[str]] = None + ) -> ChatResult: + self.callback_manager.on_llm_start( + {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], + verbose=self.verbose + ) + + chat_result = super()._generate(messages, stop) + + result = LLMResult( + generations=[chat_result.generations], + llm_output=chat_result.llm_output + ) + self.callback_manager.on_llm_end(result, verbose=self.verbose) + + return chat_result + + async def _agenerate( + self, messages: List[BaseMessage], stop: Optional[List[str]] = None + ) -> ChatResult: + if self.callback_manager.is_async: + await self.callback_manager.on_llm_start( + {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], + verbose=self.verbose + ) + else: + self.callback_manager.on_llm_start( + {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], + verbose=self.verbose + ) + + chat_result = super()._generate(messages, stop) + + result = LLMResult( + generations=[chat_result.generations], + llm_output=chat_result.llm_output + ) + + if self.callback_manager.is_async: + await self.callback_manager.on_llm_end(result, verbose=self.verbose) + else: + self.callback_manager.on_llm_end(result, verbose=self.verbose) + + return chat_result + + @handle_llm_exceptions + def generate( + self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None + ) -> LLMResult: + return super().generate(messages, stop) + + @handle_llm_exceptions_async + async def agenerate( + self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None + ) -> LLMResult: + return await super().agenerate(messages, stop) diff --git a/api/core/llm/streamable_chat_open_ai.py b/api/core/llm/streamable_chat_open_ai.py new file mode 100644 index 0000000000..59391e4ce0 --- /dev/null +++ b/api/core/llm/streamable_chat_open_ai.py @@ -0,0 +1,86 @@ +from langchain.schema import BaseMessage, ChatResult, LLMResult +from langchain.chat_models import ChatOpenAI +from typing import Optional, List + +from core.llm.error_handle_wraps import 
handle_llm_exceptions, handle_llm_exceptions_async
+
+
+class StreamableChatOpenAI(ChatOpenAI):
+
+    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
+        """Get the number of tokens in a list of messages.
+
+        Args:
+            messages: The messages to count the tokens of.
+
+        Returns:
+            The number of tokens in the messages.
+        """
+        tokens_per_message = 5
+        tokens_per_request = 3
+
+        message_tokens = tokens_per_request
+        message_strs = ''
+        for message in messages:
+            message_strs += message.content
+            message_tokens += tokens_per_message
+
+        # calc once
+        message_tokens += self.get_num_tokens(message_strs)
+
+        return message_tokens
+
+    def _generate(
+            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+    ) -> ChatResult:
+        self.callback_manager.on_llm_start(
+            {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
+        )
+
+        chat_result = super()._generate(messages, stop)
+
+        result = LLMResult(
+            generations=[chat_result.generations],
+            llm_output=chat_result.llm_output
+        )
+        self.callback_manager.on_llm_end(result, verbose=self.verbose)
+
+        return chat_result
+
+    async def _agenerate(
+            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+    ) -> ChatResult:
+        if self.callback_manager.is_async:
+            await self.callback_manager.on_llm_start(
+                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
+            )
+        else:
+            self.callback_manager.on_llm_start(
+                {"name": self.__class__.__name__}, [(message.type + ": " + message.content) for message in messages], verbose=self.verbose
+            )
+
+        chat_result = await super()._agenerate(messages, stop)
+
+        result = LLMResult(
+            generations=[chat_result.generations],
+            llm_output=chat_result.llm_output
+        )
+
+        if self.callback_manager.is_async:
+            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
+        else:
+            self.callback_manager.on_llm_end(result, verbose=self.verbose)
+
+        return chat_result
+
+    @handle_llm_exceptions
+    def generate(
+            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        return super().generate(messages, stop)
+
+    @handle_llm_exceptions_async
+    async def agenerate(
+            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        return await super().agenerate(messages, stop)
diff --git a/api/core/llm/streamable_open_ai.py b/api/core/llm/streamable_open_ai.py
new file mode 100644
index 0000000000..94754af30e
--- /dev/null
+++ b/api/core/llm/streamable_open_ai.py
@@ -0,0 +1,20 @@
+from langchain.schema import LLMResult
+from typing import Optional, List
+from langchain import OpenAI
+
+from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+
+
+class StreamableOpenAI(OpenAI):
+
+    @handle_llm_exceptions
+    def generate(
+            self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        return super().generate(prompts, stop)
+
+    @handle_llm_exceptions_async
+    async def agenerate(
+            self, prompts: List[str], stop: Optional[List[str]] = None
+    ) -> LLMResult:
+        return await super().agenerate(prompts, stop)
diff --git a/api/core/llm/token_calculator.py b/api/core/llm/token_calculator.py
new file mode 100644
index 0000000000..e45f2b4d62
--- /dev/null
+++ b/api/core/llm/token_calculator.py
@@ -0,0 +1,41 @@
+import decimal
+from typing import Optional
+
+import tiktoken
+
+from core.constant import llm_constant
+
+
+class TokenCalculator:
+    @classmethod
+    def 
get_num_tokens(cls, model_name: str, text: str): + if len(text) == 0: + return 0 + + enc = tiktoken.encoding_for_model(model_name) + + tokenized_text = enc.encode(text) + + # calculate the number of tokens in the encoded text + return len(tokenized_text) + + @classmethod + def get_token_price(cls, model_name: str, tokens: int, text_type: Optional[str] = None) -> decimal.Decimal: + if model_name in llm_constant.models_by_mode['embedding']: + unit_price = llm_constant.model_prices[model_name]['usage'] + elif text_type == 'prompt': + unit_price = llm_constant.model_prices[model_name]['prompt'] + elif text_type == 'completion': + unit_price = llm_constant.model_prices[model_name]['completion'] + else: + raise Exception('Invalid text type') + + tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), + rounding=decimal.ROUND_HALF_UP) + + total_price = tokens_per_1k * unit_price + return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) + + @classmethod + def get_currency(cls, model_name: str): + return llm_constant.model_currency diff --git a/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py b/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py new file mode 100644 index 0000000000..16f982c592 --- /dev/null +++ b/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py @@ -0,0 +1,77 @@ +from typing import Any, List, Dict, Union + +from langchain.memory.chat_memory import BaseChatMemory +from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage + +from core.llm.streamable_chat_open_ai import StreamableChatOpenAI +from core.llm.streamable_open_ai import StreamableOpenAI +from extensions.ext_database import db +from models.model import Conversation, Message + + +class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory): + conversation: Conversation + human_prefix: str = "Human" + ai_prefix: str = "AI" + llm: Union[StreamableChatOpenAI | StreamableOpenAI] + memory_key: str = "chat_history" + max_token_limit: int = 2000 + message_limit: int = 10 + + @property + def buffer(self) -> List[BaseMessage]: + """String buffer of memory.""" + # fetch limited messages desc, and return reversed + messages = db.session.query(Message).filter( + Message.conversation_id == self.conversation.id, + Message.answer_tokens > 0 + ).order_by(Message.created_at.desc()).limit(self.message_limit).all() + + messages = list(reversed(messages)) + + chat_messages: List[BaseMessage] = [] + for message in messages: + chat_messages.append(HumanMessage(content=message.query)) + chat_messages.append(AIMessage(content=message.answer)) + + if not chat_messages: + return chat_messages + + # prune the chat message if it exceeds the max token limit + curr_buffer_length = self.llm.get_messages_tokens(chat_messages) + if curr_buffer_length > self.max_token_limit: + pruned_memory = [] + while curr_buffer_length > self.max_token_limit and chat_messages: + pruned_memory.append(chat_messages.pop(0)) + curr_buffer_length = self.llm.get_messages_tokens(chat_messages) + + return chat_messages + + @property + def memory_variables(self) -> List[str]: + """Will always return list of memory variables. 
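+        For this memory implementation that is always the single
+        `memory_key` ("chat_history" by default).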
+ + :meta private: + """ + return [self.memory_key] + + def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Return history buffer.""" + buffer: Any = self.buffer + if self.return_messages: + final_buffer: Any = buffer + else: + final_buffer = get_buffer_string( + buffer, + human_prefix=self.human_prefix, + ai_prefix=self.ai_prefix, + ) + return {self.memory_key: final_buffer} + + def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: + """Nothing should be saved or changed""" + pass + + def clear(self) -> None: + """Nothing to clear, got a memory like a vault.""" + pass diff --git a/api/core/memory/read_only_conversation_token_db_string_buffer_shared_memory.py b/api/core/memory/read_only_conversation_token_db_string_buffer_shared_memory.py new file mode 100644 index 0000000000..e5933931a2 --- /dev/null +++ b/api/core/memory/read_only_conversation_token_db_string_buffer_shared_memory.py @@ -0,0 +1,36 @@ +from typing import Any, List, Dict + +from langchain.memory.chat_memory import BaseChatMemory +from langchain.schema import get_buffer_string, BaseMessage, BaseLanguageModel + +from core.memory.read_only_conversation_token_db_buffer_shared_memory import \ + ReadOnlyConversationTokenDBBufferSharedMemory + + +class ReadOnlyConversationTokenDBStringBufferSharedMemory(BaseChatMemory): + memory: ReadOnlyConversationTokenDBBufferSharedMemory + + @property + def memory_variables(self) -> List[str]: + """Return memory variables.""" + return self.memory.memory_variables + + def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]: + """Load memory variables from memory.""" + buffer: Any = self.memory.buffer + + final_buffer = get_buffer_string( + buffer, + human_prefix=self.memory.human_prefix, + ai_prefix=self.memory.ai_prefix, + ) + + return {self.memory.memory_key: final_buffer} + + def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None: + """Nothing should be saved or changed""" + pass + + def clear(self) -> None: + """Nothing to clear, got a memory like a vault.""" + pass \ No newline at end of file diff --git a/api/core/prompt/output_parser/suggested_questions_after_answer.py b/api/core/prompt/output_parser/suggested_questions_after_answer.py new file mode 100644 index 0000000000..7898d08262 --- /dev/null +++ b/api/core/prompt/output_parser/suggested_questions_after_answer.py @@ -0,0 +1,16 @@ +import json +from typing import Any + +from langchain.schema import BaseOutputParser +from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT + + +class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser): + + def get_format_instructions(self) -> str: + return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT + + def parse(self, text: str) -> Any: + json_string = text.strip() + json_obj = json.loads(json_string) + return json_obj diff --git a/api/core/prompt/prompt_builder.py b/api/core/prompt/prompt_builder.py new file mode 100644 index 0000000000..cbe41576f1 --- /dev/null +++ b/api/core/prompt/prompt_builder.py @@ -0,0 +1,37 @@ +import re + +from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate +from langchain.schema import BaseMessage + +from core.prompt.prompt_template import OutLinePromptTemplate + + +class PromptBuilder: + @classmethod + def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage: + prompt_template = OutLinePromptTemplate.from_template(prompt_content) + system_prompt_template = 
SystemMessagePromptTemplate(prompt=prompt_template) + prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs} + system_message = system_prompt_template.format(**prompt_inputs) + return system_message + + @classmethod + def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage: + prompt_template = OutLinePromptTemplate.from_template(prompt_content) + ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template) + prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs} + ai_message = ai_prompt_template.format(**prompt_inputs) + return ai_message + + @classmethod + def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage: + prompt_template = OutLinePromptTemplate.from_template(prompt_content) + human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template) + human_message = human_prompt_template.format(**inputs) + return human_message + + @classmethod + def process_template(cls, template: str): + processed_template = re.sub(r'\{(.+?)\}', r'\1', template) + processed_template = re.sub(r'\{\{(.+?)\}\}', r'{\1}', processed_template) + return processed_template diff --git a/api/core/prompt/prompt_template.py b/api/core/prompt/prompt_template.py new file mode 100644 index 0000000000..6799c5a733 --- /dev/null +++ b/api/core/prompt/prompt_template.py @@ -0,0 +1,37 @@ +import re +from typing import Any + +from langchain import PromptTemplate +from langchain.formatting import StrictFormatter + + +class OutLinePromptTemplate(PromptTemplate): + @classmethod + def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate: + """Load a prompt template from a template.""" + input_variables = { + v for _, v, _, _ in OneLineFormatter().parse(template) if v is not None + } + return cls( + input_variables=list(sorted(input_variables)), template=template, **kwargs + ) + + +class OneLineFormatter(StrictFormatter): + def parse(self, format_string): + last_end = 0 + results = [] + for match in re.finditer(r"{([a-zA-Z_]\w*)}", format_string): + field_name = match.group(1) + start, end = match.span() + + literal_text = format_string[last_end:start] + last_end = end + + results.append((literal_text, field_name, '', None)) + + remaining_literal_text = format_string[last_end:] + if remaining_literal_text: + results.append((remaining_literal_text, None, None, None)) + + return results diff --git a/api/core/prompt/prompts.py b/api/core/prompt/prompts.py new file mode 100644 index 0000000000..1d9c00990c --- /dev/null +++ b/api/core/prompt/prompts.py @@ -0,0 +1,63 @@ +from llama_index import QueryKeywordExtractPrompt + +CONVERSATION_TITLE_PROMPT = ( + "Human:{query}\n-----\n" + "Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n" + "If the human said is conducted in Chinese, you should return a Chinese title.\n" + "If the human said is conducted in English, you should return an English title.\n" + "title:" +) + +CONVERSATION_SUMMARY_PROMPT = ( + "Please generate a short summary of the following conversation.\n" + "If the conversation communicating in Chinese, you should return a Chinese summary.\n" + "If the conversation communicating in English, you should return an English summary.\n" + "[Conversation Start]\n" + "{context}\n" + "[Conversation End]\n\n" + "summary:" +) + +INTRODUCTION_GENERATE_PROMPT = ( + "I am designing a product for users to interact with an AI through dialogue. 
" + "The Prompt given to the AI before the conversation is:\n\n" + "```\n{prompt}\n```\n\n" + "Please generate a brief introduction of no more than 50 words that greets the user, based on this Prompt. " + "Do not reveal the developer's motivation or deep logic behind the Prompt, " + "but focus on building a relationship with the user:\n" +) + +MORE_LIKE_THIS_GENERATE_PROMPT = ( + "-----\n" + "{original_completion}\n" + "-----\n\n" + "Please use the above content as a sample for generating the result, " + "and include key information points related to the original sample in the result. " + "Try to rephrase this information in different ways and predict according to the rules below.\n\n" + "-----\n" + "{prompt}\n" +) + +SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( + "Please help me predict the three most likely questions that human would ask, " + "and keeping each question under 20 characters.\n" + "The output must be in JSON format following the specified schema:\n" + "[\"question1\",\"question2\",\"question3\"]\n" +) + +QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL = ( + "A question is provided below. Given the question, extract up to {max_keywords} " + "keywords from the text. Focus on extracting the keywords that we can use " + "to best lookup answers to the question. Avoid stopwords." + "I am not sure which language the following question is in. " + "If the user asked the question in Chinese, please return the keywords in Chinese. " + "If the user asked the question in English, please return the keywords in English.\n" + "---------------------\n" + "{question}\n" + "---------------------\n" + "Provide keywords in the following comma-separated format: 'KEYWORDS: '\n" +) + +QUERY_KEYWORD_EXTRACT_TEMPLATE = QueryKeywordExtractPrompt( + QUERY_KEYWORD_EXTRACT_TEMPLATE_TMPL +) diff --git a/api/core/tool/dataset_tool_builder.py b/api/core/tool/dataset_tool_builder.py new file mode 100644 index 0000000000..b31b15511a --- /dev/null +++ b/api/core/tool/dataset_tool_builder.py @@ -0,0 +1,83 @@ +from typing import Optional + +from langchain.callbacks import CallbackManager +from llama_index.langchain_helpers.agents import IndexToolConfig + +from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler +from core.index.keyword_table_index import KeywordTableIndex +from core.index.vector_index import VectorIndex +from core.prompt.prompts import QUERY_KEYWORD_EXTRACT_TEMPLATE +from core.tool.llama_index_tool import EnhanceLlamaIndexTool +from extensions.ext_database import db +from models.dataset import Dataset + + +class DatasetToolBuilder: + @classmethod + def build_dataset_tool(cls, tenant_id: str, dataset_id: str, + response_mode: str = "no_synthesizer", + callback_handler: Optional[DatasetToolCallbackHandler] = None): + # get dataset from dataset id + dataset = db.session.query(Dataset).filter( + Dataset.tenant_id == tenant_id, + Dataset.id == dataset_id + ).first() + + if not dataset: + return None + + if dataset.indexing_technique == "economy": + # use keyword table query + index = KeywordTableIndex(dataset=dataset).query_index + + if not index: + return None + + query_kwargs = { + "mode": "default", + "response_mode": response_mode, + "query_keyword_extract_template": QUERY_KEYWORD_EXTRACT_TEMPLATE, + "max_keywords_per_query": 5, + # If num_chunks_per_query is too large, + # it will 
slow down the synthesis process due to multiple iterations of refinement. + "num_chunks_per_query": 2 + } + else: + index = VectorIndex(dataset=dataset).query_index + + if not index: + return None + + query_kwargs = { + "mode": "default", + "response_mode": response_mode, + # If top_k is too large, + # it will slow down the synthesis process due to multiple iterations of refinement. + "similarity_top_k": 2 + } + + # fulfill description when it is empty + description = dataset.description + if not description: + description = 'useful for when you want to answer queries about the ' + dataset.name + + index_tool_config = IndexToolConfig( + index=index, + name=f"dataset-{dataset_id}", + description=description, + index_query_kwargs=query_kwargs, + tool_kwargs={ + "callback_manager": CallbackManager([callback_handler, DifyStdOutCallbackHandler()]) + }, + # tool_kwargs={"return_direct": True}, + # return_direct: Whether to return LLM results directly or process the output data with an Output Parser + ) + + index_callback_handler = DatasetIndexToolCallbackHandler(dataset_id=dataset_id) + + return EnhanceLlamaIndexTool.from_tool_config( + tool_config=index_tool_config, + callback_handler=index_callback_handler + ) diff --git a/api/core/tool/llama_index_tool.py b/api/core/tool/llama_index_tool.py new file mode 100644 index 0000000000..ffb216771b --- /dev/null +++ b/api/core/tool/llama_index_tool.py @@ -0,0 +1,43 @@ +from typing import Dict + +from langchain.tools import BaseTool +from llama_index.indices.base import BaseGPTIndex +from llama_index.langchain_helpers.agents import IndexToolConfig +from pydantic import Field + +from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler + + +class EnhanceLlamaIndexTool(BaseTool): + """Tool for querying a LlamaIndex.""" + + # NOTE: name/description still needs to be set + index: BaseGPTIndex + query_kwargs: Dict = Field(default_factory=dict) + return_sources: bool = False + callback_handler: IndexToolCallbackHandler + + @classmethod + def from_tool_config(cls, tool_config: IndexToolConfig, + callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool": + """Create a tool from a tool config.""" + return_sources = tool_config.tool_kwargs.pop("return_sources", False) + return cls( + index=tool_config.index, + callback_handler=callback_handler, + name=tool_config.name, + description=tool_config.description, + return_sources=return_sources, + query_kwargs=tool_config.index_query_kwargs, + **tool_config.tool_kwargs, + ) + + def _run(self, tool_input: str) -> str: + response = self.index.query(tool_input, **self.query_kwargs) + self.callback_handler.on_tool_end(response) + return str(response) + + async def _arun(self, tool_input: str) -> str: + response = await self.index.aquery(tool_input, **self.query_kwargs) + self.callback_handler.on_tool_end(response) + return str(response) diff --git a/api/core/vector_store/base.py b/api/core/vector_store/base.py new file mode 100644 index 0000000000..526f83831d --- /dev/null +++ b/api/core/vector_store/base.py @@ -0,0 +1,34 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from llama_index import ServiceContext, GPTVectorStoreIndex +from llama_index.data_structs import Node +from llama_index.vector_stores.types import VectorStore + + +class BaseVectorStoreClient(ABC): + @abstractmethod + def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex: + raise NotImplementedError + + @abstractmethod + def to_index_config(self, index_id: 
str) -> dict: + raise NotImplementedError + + +class BaseGPTVectorStoreIndex(GPTVectorStoreIndex): + def delete_node(self, node_id: str): + self._vector_store.delete_node(node_id) + + def exists_by_node_id(self, node_id: str) -> bool: + return self._vector_store.exists_by_node_id(node_id) + + +class EnhanceVectorStore(ABC): + @abstractmethod + def delete_node(self, node_id: str): + pass + + @abstractmethod + def exists_by_node_id(self, node_id: str) -> bool: + pass diff --git a/api/core/vector_store/qdrant_vector_store_client.py b/api/core/vector_store/qdrant_vector_store_client.py new file mode 100644 index 0000000000..1188c121e3 --- /dev/null +++ b/api/core/vector_store/qdrant_vector_store_client.py @@ -0,0 +1,147 @@ +import os +from typing import cast, List + +from llama_index.data_structs import Node +from llama_index.data_structs.node_v2 import DocumentRelationship +from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult +from qdrant_client.http.models import Payload, Filter + +import qdrant_client +from llama_index import ServiceContext, GPTVectorStoreIndex, GPTQdrantIndex +from llama_index.data_structs.data_structs_v2 import QdrantIndexDict +from llama_index.vector_stores import QdrantVectorStore +from qdrant_client.local.qdrant_local import QdrantLocal + +from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore + + +class QdrantVectorStoreClient(BaseVectorStoreClient): + + def __init__(self, url: str, api_key: str, root_path: str): + self._client = self.init_from_config(url, api_key, root_path) + + @classmethod + def init_from_config(cls, url: str, api_key: str, root_path: str): + if url and url.startswith('path:'): + path = url.replace('path:', '') + if not os.path.isabs(path): + path = os.path.join(root_path, path) + + return qdrant_client.QdrantClient( + path=path + ) + else: + return qdrant_client.QdrantClient( + url=url, + api_key=api_key, + ) + + def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex: + index_struct = QdrantIndexDict() + + if self._client is None: + raise Exception("Vector client is not initialized.") + + # {"collection_name": "Gpt_index_xxx"} + collection_name = config.get('collection_name') + if not collection_name: + raise Exception("collection_name cannot be None.") + + return GPTQdrantEnhanceIndex( + service_context=service_context, + index_struct=index_struct, + vector_store=QdrantEnhanceVectorStore( + client=self._client, + collection_name=collection_name + ) + ) + + def to_index_config(self, index_id: str) -> dict: + return {"collection_name": index_id} + + +class GPTQdrantEnhanceIndex(GPTQdrantIndex, BaseGPTVectorStoreIndex): + pass + + +class QdrantEnhanceVectorStore(QdrantVectorStore, EnhanceVectorStore): + def delete_node(self, node_id: str): + """ + Delete node from the index. + + :param node_id: node id + """ + from qdrant_client.http import models as rest + + self._reload_if_needed() + + self._client.delete( + collection_name=self._collection_name, + points_selector=rest.Filter( + must=[ + rest.FieldCondition( + key="id", match=rest.MatchValue(value=node_id) + ) + ] + ), + ) + + def exists_by_node_id(self, node_id: str) -> bool: + """ + Get node from the index by node id. 
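+        Returns True when a point with the given id can be retrieved from
+        the collection, False otherwise.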
+ + :param node_id: node id + """ + self._reload_if_needed() + + response = self._client.retrieve( + collection_name=self._collection_name, + ids=[node_id] + ) + + return len(response) > 0 + + def query( + self, + query: VectorStoreQuery, + ) -> VectorStoreQueryResult: + """Query index for top k most similar nodes. + + Args: + query (VectorStoreQuery): query + """ + query_embedding = cast(List[float], query.query_embedding) + + self._reload_if_needed() + + response = self._client.search( + collection_name=self._collection_name, + query_vector=query_embedding, + limit=cast(int, query.similarity_top_k), + query_filter=cast(Filter, self._build_query_filter(query)), + with_vectors=True + ) + + nodes = [] + similarities = [] + ids = [] + for point in response: + payload = cast(Payload, point.payload) + node = Node( + doc_id=str(point.id), + text=payload.get("text"), + embedding=point.vector, + extra_info=payload.get("extra_info"), + relationships={ + DocumentRelationship.SOURCE: payload.get("doc_id", "None"), + }, + ) + nodes.append(node) + similarities.append(point.score) + ids.append(str(point.id)) + + return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids) + + def _reload_if_needed(self): + if isinstance(self._client._client, QdrantLocal): + self._client._client._load() diff --git a/api/core/vector_store/vector_store.py b/api/core/vector_store/vector_store.py new file mode 100644 index 0000000000..56b5fd0f97 --- /dev/null +++ b/api/core/vector_store/vector_store.py @@ -0,0 +1,61 @@ +from flask import Flask +from llama_index import ServiceContext, GPTVectorStoreIndex +from requests import ReadTimeout +from tenacity import retry, retry_if_exception_type, stop_after_attempt + +from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient +from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient + +SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant'] + + +class VectorStore: + + def __init__(self): + self._vector_store = None + self._client = None + + def init_app(self, app: Flask): + if not app.config['VECTOR_STORE']: + return + + self._vector_store = app.config['VECTOR_STORE'] + if self._vector_store not in SUPPORTED_VECTOR_STORES: + raise ValueError(f"Vector store {self._vector_store} is not supported.") + + if self._vector_store == 'weaviate': + self._client = WeaviateVectorStoreClient( + endpoint=app.config['WEAVIATE_ENDPOINT'], + api_key=app.config['WEAVIATE_API_KEY'], + grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED'] + ) + elif self._vector_store == 'qdrant': + self._client = QdrantVectorStoreClient( + url=app.config['QDRANT_URL'], + api_key=app.config['QDRANT_API_KEY'], + root_path=app.root_path + ) + + app.extensions['vector_store'] = self + + @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3)) + def get_index(self, service_context: ServiceContext, index_struct: dict) -> GPTVectorStoreIndex: + vector_store_config: dict = index_struct.get('vector_store') + index = self.get_client().get_index( + service_context=service_context, + config=vector_store_config + ) + + return index + + def to_index_struct(self, index_id: str) -> dict: + return { + "type": self._vector_store, + "vector_store": self.get_client().to_index_config(index_id) + } + + def get_client(self): + if not self._client: + raise Exception("Vector store client is not initialized.") + + return self._client diff --git a/api/core/vector_store/vector_store_index_query.py b/api/core/vector_store/vector_store_index_query.py new 
file mode 100644 index 0000000000..f29de83f9e --- /dev/null +++ b/api/core/vector_store/vector_store_index_query.py @@ -0,0 +1,66 @@ +from llama_index.indices.query.base import IS +from typing import ( + Any, + Dict, + List, + Optional +) + +from llama_index.docstore import BaseDocumentStore +from llama_index.indices.postprocessor.node import ( + BaseNodePostprocessor, +) +from llama_index.indices.vector_store import GPTVectorStoreIndexQuery +from llama_index.indices.response.response_builder import ResponseMode +from llama_index.indices.service_context import ServiceContext +from llama_index.optimization.optimizer import BaseTokenUsageOptimizer +from llama_index.prompts.prompts import ( + QuestionAnswerPrompt, + RefinePrompt, + SimpleInputPrompt, +) + +from core.index.query.synthesizer import EnhanceResponseSynthesizer + + +class EnhanceGPTVectorStoreIndexQuery(GPTVectorStoreIndexQuery): + @classmethod + def from_args( + cls, + index_struct: IS, + service_context: ServiceContext, + docstore: Optional[BaseDocumentStore] = None, + node_postprocessors: Optional[List[BaseNodePostprocessor]] = None, + verbose: bool = False, + # response synthesizer args + response_mode: ResponseMode = ResponseMode.DEFAULT, + text_qa_template: Optional[QuestionAnswerPrompt] = None, + refine_template: Optional[RefinePrompt] = None, + simple_template: Optional[SimpleInputPrompt] = None, + response_kwargs: Optional[Dict] = None, + use_async: bool = False, + streaming: bool = False, + optimizer: Optional[BaseTokenUsageOptimizer] = None, + # class-specific args + **kwargs: Any, + ) -> "BaseGPTIndexQuery": + response_synthesizer = EnhanceResponseSynthesizer.from_args( + service_context=service_context, + text_qa_template=text_qa_template, + refine_template=refine_template, + simple_template=simple_template, + response_mode=response_mode, + response_kwargs=response_kwargs, + use_async=use_async, + streaming=streaming, + optimizer=optimizer, + ) + return cls( + index_struct=index_struct, + service_context=service_context, + response_synthesizer=response_synthesizer, + docstore=docstore, + node_postprocessors=node_postprocessors, + verbose=verbose, + **kwargs, + ) diff --git a/api/core/vector_store/weaviate_vector_store_client.py b/api/core/vector_store/weaviate_vector_store_client.py new file mode 100644 index 0000000000..2310278cf9 --- /dev/null +++ b/api/core/vector_store/weaviate_vector_store_client.py @@ -0,0 +1,258 @@ +import json +import weaviate +from dataclasses import field +from typing import List, Any, Dict, Optional + +from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore +from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex +from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node +from llama_index.data_structs.node_v2 import DocumentRelationship +from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger +from llama_index.vector_stores import WeaviateVectorStore +from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode +from llama_index.readers.weaviate.utils import ( + parse_get_response, + validate_client, +) + + +class WeaviateVectorStoreClient(BaseVectorStoreClient): + + def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool): + self._client = self.init_from_config(endpoint, api_key, grpc_enabled) + + def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool): + auth_config = 
weaviate.auth.AuthApiKey(api_key=api_key) + + weaviate.connect.connection.has_grpc = grpc_enabled + + return weaviate.Client( + url=endpoint, + auth_client_secret=auth_config, + timeout_config=(5, 15), + startup_period=None + ) + + def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex: + index_struct = WeaviateIndexDict() + + if self._client is None: + raise Exception("Vector client is not initialized.") + + # {"class_prefix": "Gpt_index_xxx"} + class_prefix = config.get('class_prefix') + if not class_prefix: + raise Exception("class_prefix cannot be None.") + + return GPTWeaviateEnhanceIndex( + service_context=service_context, + index_struct=index_struct, + vector_store=WeaviateWithSimilaritiesVectorStore( + weaviate_client=self._client, + class_prefix=class_prefix + ) + ) + + def to_index_config(self, index_id: str) -> dict: + return {"class_prefix": index_id} + + +class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore): + def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult: + """Query index for top k most similar nodes.""" + nodes = self.weaviate_query( + self._client, + self._class_prefix, + query, + ) + nodes = nodes[: query.similarity_top_k] + node_idxs = [str(i) for i in range(len(nodes))] + + similarities = [] + for node in nodes: + similarities.append(node.extra_info['similarity']) + del node.extra_info['similarity'] + + return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities) + + def weaviate_query( + self, + client: Any, + class_prefix: str, + query_spec: VectorStoreQuery, + ) -> List[Node]: + """Convert to LlamaIndex list.""" + validate_client(client) + + class_name = _class_name(class_prefix) + prop_names = [p["name"] for p in NODE_SCHEMA] + vector = query_spec.query_embedding + + # build query + query = client.query.get(class_name, prop_names).with_additional(["id", "vector", "certainty"]) + if query_spec.mode == VectorStoreQueryMode.DEFAULT: + _logger.debug("Using vector search") + if vector is not None: + query = query.with_near_vector( + { + "vector": vector, + } + ) + elif query_spec.mode == VectorStoreQueryMode.HYBRID: + _logger.debug(f"Using hybrid search with alpha {query_spec.alpha}") + query = query.with_hybrid( + query=query_spec.query_str, + alpha=query_spec.alpha, + vector=vector, + ) + query = query.with_limit(query_spec.similarity_top_k) + _logger.debug(f"Using limit of {query_spec.similarity_top_k}") + + # execute query + query_result = query.do() + + # parse results + parsed_result = parse_get_response(query_result) + entries = parsed_result[class_name] + results = [self._to_node(entry) for entry in entries] + return results + + def _to_node(self, entry: Dict) -> Node: + """Convert to Node.""" + extra_info_str = entry["extra_info"] + if extra_info_str == "": + extra_info = None + else: + extra_info = json.loads(extra_info_str) + + if 'certainty' in entry['_additional']: + if extra_info: + extra_info['similarity'] = entry['_additional']['certainty'] + else: + extra_info = {'similarity': entry['_additional']['certainty']} + + node_info_str = entry["node_info"] + if node_info_str == "": + node_info = None + else: + node_info = json.loads(node_info_str) + + relationships_str = entry["relationships"] + relationships: Dict[DocumentRelationship, str] + if relationships_str == "": + relationships = field(default_factory=dict) + else: + relationships = { + DocumentRelationship(k): v for k, v in json.loads(relationships_str).items() + } + + return Node( + 
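+            # Rebuild the llama_index Node from the Weaviate entry; the
+            # embedding comes from the `_additional.vector` field requested
+            # in weaviate_query().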
text=entry["text"], + doc_id=entry["doc_id"], + embedding=entry["_additional"]["vector"], + extra_info=extra_info, + node_info=node_info, + relationships=relationships, + ) + + def delete(self, doc_id: str, **delete_kwargs: Any) -> None: + """Delete a document. + + Args: + doc_id (str): document id + + """ + delete_document(self._client, doc_id, self._class_prefix) + + def delete_node(self, node_id: str): + """ + Delete node from the index. + + :param node_id: node id + """ + delete_node(self._client, node_id, self._class_prefix) + + def exists_by_node_id(self, node_id: str) -> bool: + """ + Get node from the index by node id. + + :param node_id: node id + """ + entry = get_by_node_id(self._client, node_id, self._class_prefix) + return True if entry else False + + +class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex): + pass + + +def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None: + """Delete entry.""" + validate_client(client) + # make sure that each entry + class_name = _class_name(class_prefix) + where_filter = { + "path": ["ref_doc_id"], + "operator": "Equal", + "valueString": ref_doc_id, + } + query = ( + client.query.get(class_name).with_additional(["id"]).with_where(where_filter) + ) + + query_result = query.do() + parsed_result = parse_get_response(query_result) + entries = parsed_result[class_name] + for entry in entries: + client.data_object.delete(entry["_additional"]["id"], class_name) + + while len(entries) > 0: + query_result = query.do() + parsed_result = parse_get_response(query_result) + entries = parsed_result[class_name] + for entry in entries: + client.data_object.delete(entry["_additional"]["id"], class_name) + + +def delete_node(client: Any, node_id: str, class_prefix: str) -> None: + """Delete entry.""" + validate_client(client) + # make sure that each entry + class_name = _class_name(class_prefix) + where_filter = { + "path": ["doc_id"], + "operator": "Equal", + "valueString": node_id, + } + query = ( + client.query.get(class_name).with_additional(["id"]).with_where(where_filter) + ) + + query_result = query.do() + parsed_result = parse_get_response(query_result) + entries = parsed_result[class_name] + for entry in entries: + client.data_object.delete(entry["_additional"]["id"], class_name) + + +def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]: + """Delete entry.""" + validate_client(client) + # make sure that each entry + class_name = _class_name(class_prefix) + where_filter = { + "path": ["doc_id"], + "operator": "Equal", + "valueString": node_id, + } + query = ( + client.query.get(class_name).with_additional(["id"]).with_where(where_filter) + ) + + query_result = query.do() + parsed_result = parse_get_response(query_result) + entries = parsed_result[class_name] + if len(entries) == 0: + return None + + return entries[0] diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh new file mode 100644 index 0000000000..50b8cbd86a --- /dev/null +++ b/api/docker/entrypoint.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +set -e + +if [[ "${MIGRATION_ENABLED}" == "true" ]]; then + echo "Running migrations" + flask db upgrade +fi + +if [[ "${MODE}" == "worker" ]]; then + celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} -c ${CELERY_WORKER_AMOUNT:-1} --loglevel INFO +else + if [[ "${DEBUG}" == "true" ]]; then + flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug + else + gunicorn \ + --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \ + --workers 
${SERVER_WORKER_AMOUNT:-1} \ + --worker-class ${SERVER_WORKER_CLASS:-gevent} \ + --timeout ${GUNICORN_TIMEOUT:-200} \ + --preload \ + app:app + fi +fi \ No newline at end of file diff --git a/api/events/app_event.py b/api/events/app_event.py new file mode 100644 index 0000000000..938478d3b7 --- /dev/null +++ b/api/events/app_event.py @@ -0,0 +1,10 @@ +from blinker import signal + +# sender: app +app_was_created = signal('app-was-created') + +# sender: app +app_was_deleted = signal('app-was-deleted') + +# sender: app, kwargs: old_app_model_config, new_app_model_config +app_model_config_was_updated = signal('app-model-config-was-updated') diff --git a/api/events/dataset_event.py b/api/events/dataset_event.py new file mode 100644 index 0000000000..d4a2b6f313 --- /dev/null +++ b/api/events/dataset_event.py @@ -0,0 +1,4 @@ +from blinker import signal + +# sender: dataset +dataset_was_deleted = signal('dataset-was-deleted') diff --git a/api/events/document_event.py b/api/events/document_event.py new file mode 100644 index 0000000000..f95326630b --- /dev/null +++ b/api/events/document_event.py @@ -0,0 +1,4 @@ +from blinker import signal + +# sender: document +document_was_deleted = signal('document-was-deleted') diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py new file mode 100644 index 0000000000..c858ac0880 --- /dev/null +++ b/api/events/event_handlers/__init__.py @@ -0,0 +1,9 @@ +from .create_installed_app_when_app_created import handle +from .delete_installed_app_when_app_deleted import handle +from .create_provider_when_tenant_created import handle +from .create_provider_when_tenant_updated import handle +from .clean_when_document_deleted import handle +from .clean_when_dataset_deleted import handle +from .update_app_dataset_join_when_app_model_config_updated import handle +from .generate_conversation_name_when_first_message_created import handle +from .generate_conversation_summary_when_few_message_created import handle diff --git a/api/events/event_handlers/clean_when_dataset_deleted.py b/api/events/event_handlers/clean_when_dataset_deleted.py new file mode 100644 index 0000000000..e9975c92bc --- /dev/null +++ b/api/events/event_handlers/clean_when_dataset_deleted.py @@ -0,0 +1,8 @@ +from events.dataset_event import dataset_was_deleted +from tasks.clean_dataset_task import clean_dataset_task + + +@dataset_was_deleted.connect +def handle(sender, **kwargs): + dataset = sender + clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, dataset.index_struct) diff --git a/api/events/event_handlers/clean_when_document_deleted.py b/api/events/event_handlers/clean_when_document_deleted.py new file mode 100644 index 0000000000..d6553b385e --- /dev/null +++ b/api/events/event_handlers/clean_when_document_deleted.py @@ -0,0 +1,9 @@ +from events.document_event import document_was_deleted +from tasks.clean_document_task import clean_document_task + + +@document_was_deleted.connect +def handle(sender, **kwargs): + document_id = sender + dataset_id = kwargs.get('dataset_id') + clean_document_task.delay(document_id, dataset_id) diff --git a/api/events/event_handlers/create_installed_app_when_app_created.py b/api/events/event_handlers/create_installed_app_when_app_created.py new file mode 100644 index 0000000000..31084ce0fe --- /dev/null +++ b/api/events/event_handlers/create_installed_app_when_app_created.py @@ -0,0 +1,16 @@ +from events.app_event import app_was_created +from extensions.ext_database import db +from models.model import 
InstalledApp + + +@app_was_created.connect +def handle(sender, **kwargs): + """Create an installed app when an app is created.""" + app = sender + installed_app = InstalledApp( + tenant_id=app.tenant_id, + app_id=app.id, + app_owner_tenant_id=app.tenant_id + ) + db.session.add(installed_app) + db.session.commit() diff --git a/api/events/event_handlers/create_provider_when_tenant_created.py b/api/events/event_handlers/create_provider_when_tenant_created.py new file mode 100644 index 0000000000..e967a5d071 --- /dev/null +++ b/api/events/event_handlers/create_provider_when_tenant_created.py @@ -0,0 +1,9 @@ +from events.tenant_event import tenant_was_updated +from services.provider_service import ProviderService + + +@tenant_was_updated.connect +def handle(sender, **kwargs): + tenant = sender + if tenant.status == 'normal': + ProviderService.create_system_provider(tenant) diff --git a/api/events/event_handlers/create_provider_when_tenant_updated.py b/api/events/event_handlers/create_provider_when_tenant_updated.py new file mode 100644 index 0000000000..81a7d40ff6 --- /dev/null +++ b/api/events/event_handlers/create_provider_when_tenant_updated.py @@ -0,0 +1,9 @@ +from events.tenant_event import tenant_was_created +from services.provider_service import ProviderService + + +@tenant_was_created.connect +def handle(sender, **kwargs): + tenant = sender + if tenant.status == 'normal': + ProviderService.create_system_provider(tenant) diff --git a/api/events/event_handlers/delete_installed_app_when_app_deleted.py b/api/events/event_handlers/delete_installed_app_when_app_deleted.py new file mode 100644 index 0000000000..1d6271a466 --- /dev/null +++ b/api/events/event_handlers/delete_installed_app_when_app_deleted.py @@ -0,0 +1,12 @@ +from events.app_event import app_was_deleted +from extensions.ext_database import db +from models.model import InstalledApp + + +@app_was_deleted.connect +def handle(sender, **kwargs): + app = sender + installed_apps = db.session.query(InstalledApp).filter(InstalledApp.app_id == app.id).all() + for installed_app in installed_apps: + db.session.delete(installed_app) + db.session.commit() diff --git a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py new file mode 100644 index 0000000000..4c1bbee53e --- /dev/null +++ b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py @@ -0,0 +1,29 @@ +import logging + +from core.generator.llm_generator import LLMGenerator +from events.message_event import message_was_created +from extensions.ext_database import db + + +@message_was_created.connect +def handle(sender, **kwargs): + message = sender + conversation = kwargs.get('conversation') + is_first_message = kwargs.get('is_first_message') + + if is_first_message: + if conversation.mode == 'chat': + app_model = conversation.app + if not app_model: + return + + # generate conversation name + try: + name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query, message.answer) + conversation.name = name + except: + conversation.name = 'New Chat' + logging.exception('generate_conversation_name failed') + + db.session.add(conversation) + db.session.commit() diff --git a/api/events/event_handlers/generate_conversation_summary_when_few_message_created.py b/api/events/event_handlers/generate_conversation_summary_when_few_message_created.py new file mode 100644 index 0000000000..df62a90b8e --- /dev/null +++ 
b/api/events/event_handlers/generate_conversation_summary_when_few_message_created.py @@ -0,0 +1,14 @@ +from events.message_event import message_was_created +from tasks.generate_conversation_summary_task import generate_conversation_summary_task + + +@message_was_created.connect +def handle(sender, **kwargs): + message = sender + conversation = kwargs.get('conversation') + is_first_message = kwargs.get('is_first_message') + + if not is_first_message and conversation.mode == 'chat' and not conversation.summary: + history_message_count = conversation.message_count + if history_message_count >= 5: + generate_conversation_summary_task.delay(conversation.id) diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py new file mode 100644 index 0000000000..d165b014d6 --- /dev/null +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -0,0 +1,66 @@ +from events.app_event import app_model_config_was_updated +from extensions.ext_database import db +from models.dataset import AppDatasetJoin +from models.model import AppModelConfig + + +@app_model_config_was_updated.connect +def handle(sender, **kwargs): + app_model = sender + app_model_config = kwargs.get('app_model_config') + + dataset_ids = get_dataset_ids_from_model_config(app_model_config) + + app_dataset_joins = db.session.query(AppDatasetJoin).filter( + AppDatasetJoin.app_id == app_model.id + ).all() + + removed_dataset_ids = [] + if not app_dataset_joins: + added_dataset_ids = dataset_ids + else: + old_dataset_ids = set() + for app_dataset_join in app_dataset_joins: + old_dataset_ids.add(app_dataset_join.dataset_id) + + added_dataset_ids = dataset_ids - old_dataset_ids + removed_dataset_ids = old_dataset_ids - dataset_ids + + if removed_dataset_ids: + for dataset_id in removed_dataset_ids: + db.session.query(AppDatasetJoin).filter( + AppDatasetJoin.app_id == app_model.id, + AppDatasetJoin.dataset_id == dataset_id + ).delete() + + if added_dataset_ids: + for dataset_id in added_dataset_ids: + app_dataset_join = AppDatasetJoin( + app_id=app_model.id, + dataset_id=dataset_id + ) + db.session.add(app_dataset_join) + + db.session.commit() + + +def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set: + dataset_ids = set() + if not app_model_config: + return dataset_ids + + agent_mode = app_model_config.agent_mode_dict + if agent_mode.get('enabled') is False: + return dataset_ids + + if not agent_mode.get('tools'): + return dataset_ids + + tools = agent_mode.get('tools') + for tool in tools: + tool_type = list(tool.keys())[0] + tool_config = list(tool.values())[0] + if tool_type == "dataset": + dataset_ids.add(tool_config.get("id")) + + return dataset_ids diff --git a/api/events/message_event.py b/api/events/message_event.py new file mode 100644 index 0000000000..21da83f249 --- /dev/null +++ b/api/events/message_event.py @@ -0,0 +1,4 @@ +from blinker import signal + +# sender: message, kwargs: conversation +message_was_created = signal('message-was-created') diff --git a/api/events/tenant_event.py b/api/events/tenant_event.py new file mode 100644 index 0000000000..942f709917 --- /dev/null +++ b/api/events/tenant_event.py @@ -0,0 +1,7 @@ +from blinker import signal + +# sender: tenant +tenant_was_created = signal('tenant-was-created') + +# sender: tenant +tenant_was_updated = signal('tenant-was-updated') diff --git a/api/extensions/ext_celery.py 
b/api/extensions/ext_celery.py new file mode 100644 index 0000000000..f738b984d9 --- /dev/null +++ b/api/extensions/ext_celery.py @@ -0,0 +1,23 @@ +from celery import Task, Celery +from flask import Flask + + +def init_app(app: Flask) -> Celery: + class FlaskTask(Task): + def __call__(self, *args: object, **kwargs: object) -> object: + with app.app_context(): + return self.run(*args, **kwargs) + + celery_app = Celery( + app.name, + task_cls=FlaskTask, + broker=app.config["CELERY_BROKER_URL"], + backend=app.config["CELERY_BACKEND"], + task_ignore_result=True, + ) + celery_app.conf.update( + result_backend=app.config["CELERY_RESULT_BACKEND"], + ) + celery_app.set_default() + app.extensions["celery"] = celery_app + return celery_app diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py new file mode 100644 index 0000000000..9121c6ead9 --- /dev/null +++ b/api/extensions/ext_database.py @@ -0,0 +1,7 @@ +from flask_sqlalchemy import SQLAlchemy + +db = SQLAlchemy() + + +def init_app(app): + db.init_app(app) diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py new file mode 100644 index 0000000000..f7d5cffdda --- /dev/null +++ b/api/extensions/ext_login.py @@ -0,0 +1,7 @@ +import flask_login + +login_manager = flask_login.LoginManager() + + +def init_app(app): + login_manager.init_app(app) diff --git a/api/extensions/ext_migrate.py b/api/extensions/ext_migrate.py new file mode 100644 index 0000000000..e7b278fc38 --- /dev/null +++ b/api/extensions/ext_migrate.py @@ -0,0 +1,5 @@ +import flask_migrate + + +def init(app, db): + flask_migrate.Migrate(app, db) diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py new file mode 100644 index 0000000000..c3e021e798 --- /dev/null +++ b/api/extensions/ext_redis.py @@ -0,0 +1,18 @@ +import redis + + +redis_client = redis.Redis() + + +def init_app(app): + redis_client.connection_pool = redis.ConnectionPool(**{ + 'host': app.config.get('REDIS_HOST', 'localhost'), + 'port': app.config.get('REDIS_PORT', 6379), + 'password': app.config.get('REDIS_PASSWORD', None), + 'db': app.config.get('REDIS_DB', 0), + 'encoding': 'utf-8', + 'encoding_errors': 'strict', + 'decode_responses': False + }) + + app.extensions['redis'] = redis_client diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py new file mode 100644 index 0000000000..f05c10bc08 --- /dev/null +++ b/api/extensions/ext_sentry.py @@ -0,0 +1,20 @@ +import sentry_sdk +from sentry_sdk.integrations.celery import CeleryIntegration +from sentry_sdk.integrations.flask import FlaskIntegration +from werkzeug.exceptions import HTTPException + + +def init_app(app): + if app.config.get('SENTRY_DSN'): + sentry_sdk.init( + dsn=app.config.get('SENTRY_DSN'), + integrations=[ + FlaskIntegration(), + CeleryIntegration() + ], + ignore_errors=[HTTPException, ValueError], + traces_sample_rate=app.config.get('SENTRY_TRACES_SAMPLE_RATE', 1.0), + profiles_sample_rate=app.config.get('SENTRY_PROFILES_SAMPLE_RATE', 1.0), + environment=app.config.get('DEPLOY_ENV'), + release=f"dify-{app.config.get('CURRENT_VERSION')}-{app.config.get('COMMIT_SHA')}" + ) diff --git a/api/extensions/ext_session.py b/api/extensions/ext_session.py new file mode 100644 index 0000000000..5b454d469e --- /dev/null +++ b/api/extensions/ext_session.py @@ -0,0 +1,168 @@ +import redis +from flask import request +from flask_session import Session, SqlAlchemySessionInterface, RedisSessionInterface +from flask_session.sessions import total_seconds +from itsdangerous import want_bytes + +from 
extensions.ext_database import db
+
+sess = Session()
+
+
+def init_app(app):
+    sqlalchemy_session_interface = CustomSqlAlchemySessionInterface(
+        app,
+        db,
+        app.config.get('SESSION_SQLALCHEMY_TABLE', 'sessions'),
+        app.config.get('SESSION_KEY_PREFIX', 'session:'),
+        app.config.get('SESSION_USE_SIGNER', False),
+        app.config.get('SESSION_PERMANENT', True)
+    )
+
+    session_type = app.config.get('SESSION_TYPE')
+    if session_type == 'sqlalchemy':
+        app.session_interface = sqlalchemy_session_interface
+    elif session_type == 'redis':
+        sess_redis_client = redis.Redis()
+        sess_redis_client.connection_pool = redis.ConnectionPool(**{
+            'host': app.config.get('SESSION_REDIS_HOST', 'localhost'),
+            'port': app.config.get('SESSION_REDIS_PORT', 6379),
+            'password': app.config.get('SESSION_REDIS_PASSWORD', None),
+            'db': app.config.get('SESSION_REDIS_DB', 2),
+            'encoding': 'utf-8',
+            'encoding_errors': 'strict',
+            'decode_responses': False
+        })
+
+        app.extensions['session_redis'] = sess_redis_client
+
+        app.session_interface = CustomRedisSessionInterface(
+            sess_redis_client,
+            app.config.get('SESSION_KEY_PREFIX', 'session:'),
+            app.config.get('SESSION_USE_SIGNER', False),
+            app.config.get('SESSION_PERMANENT', True)
+        )
+
+
+class CustomSqlAlchemySessionInterface(SqlAlchemySessionInterface):
+
+    def __init__(
+        self,
+        app,
+        db,
+        table,
+        key_prefix,
+        use_signer=False,
+        permanent=True,
+        sequence=None,
+        autodelete=False,
+    ):
+        if db is None:
+            from flask_sqlalchemy import SQLAlchemy
+
+            db = SQLAlchemy(app)
+        self.db = db
+        self.key_prefix = key_prefix
+        self.use_signer = use_signer
+        self.permanent = permanent
+        self.autodelete = autodelete
+        self.sequence = sequence
+        self.has_same_site_capability = hasattr(self, "get_cookie_samesite")
+
+        class Session(self.db.Model):
+            __tablename__ = table
+
+            if sequence:
+                id = self.db.Column(  # noqa: A003, VNE003, A001
+                    self.db.Integer, self.db.Sequence(sequence), primary_key=True
+                )
+            else:
+                id = self.db.Column(  # noqa: A003, VNE003, A001
+                    self.db.Integer, primary_key=True
+                )
+
+            session_id = self.db.Column(self.db.String(255), unique=True)
+            data = self.db.Column(self.db.LargeBinary)
+            expiry = self.db.Column(self.db.DateTime)
+
+            def __init__(self, session_id, data, expiry):
+                self.session_id = session_id
+                self.data = data
+                self.expiry = expiry
+
+            def __repr__(self):
+                return f"<Session data {self.data}>"
+
+        self.sql_session_model = Session
+
+    def save_session(self, *args, **kwargs):
+        if request.blueprint == 'service_api':
+            return
+        elif request.method == 'OPTIONS':
+            return
+        elif request.endpoint and request.endpoint == 'health':
+            return
+        return super().save_session(*args, **kwargs)
+
+
+class CustomRedisSessionInterface(RedisSessionInterface):
+
+    def save_session(self, app, session, response):
+        if request.blueprint == 'service_api':
+            return
+        elif request.method == 'OPTIONS':
+            return
+        elif request.endpoint and request.endpoint == 'health':
+            return
+
+        if not self.should_set_cookie(app, session):
+            return
+        domain = self.get_cookie_domain(app)
+        path = self.get_cookie_path(app)
+        if not session:
+            if session.modified:
+                self.redis.delete(self.key_prefix + session.sid)
+                response.delete_cookie(
+                    app.config["SESSION_COOKIE_NAME"], domain=domain, path=path
+                )
+            return
+
+        # Modification case. There are upsides and downsides to
+        # emitting a set-cookie header each request. The behavior
+        # is controlled by the :meth:`should_set_cookie` method
+        # which performs a quick check to figure out if the cookie
+        # should be set or not. This is controlled by the
+        # SESSION_REFRESH_EACH_REQUEST config flag as well as
+        # the permanent flag on the session itself.
+        # if not self.should_set_cookie(app, session):
+        #     return
+        conditional_cookie_kwargs = {}
+        httponly = self.get_cookie_httponly(app)
+        secure = self.get_cookie_secure(app)
+        if self.has_same_site_capability:
+            conditional_cookie_kwargs["samesite"] = self.get_cookie_samesite(app)
+        expires = self.get_expiration_time(app, session)
+
+        if session.permanent:
+            value = self.serializer.dumps(dict(session))
+            if value is not None:
+                self.redis.setex(
+                    name=self.key_prefix + session.sid,
+                    value=value,
+                    time=total_seconds(app.permanent_session_lifetime),
+                )
+
+        if self.use_signer:
+            session_id = self._get_signer(app).sign(want_bytes(session.sid)).decode("utf-8")
+        else:
+            session_id = session.sid
+        response.set_cookie(
+            app.config["SESSION_COOKIE_NAME"],
+            session_id,
+            expires=expires,
+            httponly=httponly,
+            domain=domain,
+            path=path,
+            secure=secure,
+            **conditional_cookie_kwargs,
+        )
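For reference, a sketch of the Flask config keys that `init_app` above reads; the values shown are illustrative defaults, not settings shipped in this commit:

    # illustrative config for the session extension (Python settings module)
    SESSION_TYPE = 'redis'            # or 'sqlalchemy'
    SESSION_SQLALCHEMY_TABLE = 'sessions'
    SESSION_KEY_PREFIX = 'session:'
    SESSION_USE_SIGNER = False
    SESSION_PERMANENT = True
    SESSION_REDIS_HOST = 'localhost'
    SESSION_REDIS_PORT = 6379
    SESSION_REDIS_DB = 2

Note that both custom interfaces skip `save_session` for the `service_api` blueprint, `OPTIONS` requests, and the `health` endpoint, so pure API traffic never receives a session cookie.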
diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py
new file mode 100644
index 0000000000..dc44892024
--- /dev/null
+++ b/api/extensions/ext_storage.py
@@ -0,0 +1,108 @@
+import os
+import shutil
+from contextlib import closing
+
+import boto3
+from botocore.exceptions import ClientError
+from flask import Flask
+
+
+class Storage:
+    def __init__(self):
+        self.storage_type = None
+        self.bucket_name = None
+        self.client = None
+        self.folder = None
+
+    def init_app(self, app: Flask):
+        self.storage_type = app.config.get('STORAGE_TYPE')
+        if self.storage_type == 's3':
+            self.bucket_name = app.config.get('S3_BUCKET_NAME')
+            self.client = boto3.client(
+                's3',
+                aws_secret_access_key=app.config.get('S3_SECRET_KEY'),
+                aws_access_key_id=app.config.get('S3_ACCESS_KEY'),
+                endpoint_url=app.config.get('S3_ENDPOINT'),
+                region_name=app.config.get('S3_REGION')
+            )
+        else:
+            self.folder = app.config.get('STORAGE_LOCAL_PATH')
+            if not os.path.isabs(self.folder):
+                self.folder = os.path.join(app.root_path, self.folder)
+
+    def save(self, filename, data):
+        if self.storage_type == 's3':
+            self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)
+        else:
+            if not self.folder or self.folder.endswith('/'):
+                filename = self.folder + filename
+            else:
+                filename = self.folder + '/' + filename
+
+            folder = os.path.dirname(filename)
+            os.makedirs(folder, exist_ok=True)
+
+            with open(os.path.join(os.getcwd(), filename), "wb") as f:
+                f.write(data)
+
+    def load(self, filename):
+        if self.storage_type == 's3':
+            try:
+                with closing(self.client) as client:
+                    data = client.get_object(Bucket=self.bucket_name, Key=filename)['Body'].read()
+            except ClientError as ex:
+                if ex.response['Error']['Code'] == 'NoSuchKey':
+                    raise FileNotFoundError("File not found")
+                else:
+                    raise
+        else:
+            if not self.folder or self.folder.endswith('/'):
+                filename = self.folder + filename
+            else:
+                filename = self.folder + '/' + filename
+
+            if not os.path.exists(filename):
+                raise FileNotFoundError("File not found")
+
+            with open(filename, "rb") as f:
+                data = f.read()
+
+        return data
+
+    def download(self, filename, target_filepath):
+        if self.storage_type == 's3':
+            with closing(self.client) as client:
+                client.download_file(self.bucket_name, filename, target_filepath)
+        else:
+            if not self.folder or self.folder.endswith('/'):
+                filename = self.folder + filename
+            else:
+                filename = self.folder + '/' + filename
+
+            if not os.path.exists(filename):
+                raise FileNotFoundError("File not found")
+
+            shutil.copyfile(filename, target_filepath)
+
+    def exists(self, filename):
+        if self.storage_type == 's3':
+            with closing(self.client) as client:
+                try:
+                    client.head_object(Bucket=self.bucket_name, Key=filename)
+                    return True
+                except ClientError:
+                    return False
+        else:
+            if not self.folder or self.folder.endswith('/'):
+                filename = self.folder + filename
+            else:
+                filename = self.folder + '/' + filename
+
+            return os.path.exists(filename)
+
+
+storage = Storage()
+
+
+def init_app(app: Flask):
+    storage.init_app(app)
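A brief usage sketch for the storage facade above, exercising the local backend; the paths and config values are illustrative, and S3 goes through the same methods:

    from flask import Flask
    from extensions.ext_storage import storage  # assumes the api/ directory is the import root

    app = Flask(__name__)
    app.config['STORAGE_TYPE'] = 'local'          # anything other than 's3' selects local files
    app.config['STORAGE_LOCAL_PATH'] = 'storage'  # relative paths resolve against app.root_path

    storage.init_app(app)
    storage.save('upload_files/hello.txt', b'hello')
    assert storage.load('upload_files/hello.txt') == b'hello'
    assert storage.exists('upload_files/hello.txt')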
diff --git a/api/extensions/ext_vector_store.py b/api/extensions/ext_vector_store.py
new file mode 100644
index 0000000000..4ed7a93422
--- /dev/null
+++ b/api/extensions/ext_vector_store.py
@@ -0,0 +1,7 @@
+from core.vector_store.vector_store import VectorStore
+
+vector_store = VectorStore()
+
+
+def init_app(app):
+    vector_store.init_app(app)
diff --git a/api/libs/__init__.py b/api/libs/__init__.py
new file mode 100644
index 0000000000..380474e035
--- /dev/null
+++ b/api/libs/__init__.py
@@ -0,0 +1 @@
+# -*- coding:utf-8 -*-
diff --git a/api/libs/ecc_aes.py b/api/libs/ecc_aes.py
new file mode 100644
index 0000000000..aef3214535
--- /dev/null
+++ b/api/libs/ecc_aes.py
@@ -0,0 +1,82 @@
+from Crypto.Cipher import AES
+from Crypto.Hash import SHA256
+from Crypto.PublicKey import ECC
+from Crypto.Util.Padding import pad, unpad
+
+
+class ECC_AES:
+    def __init__(self, curve='P-256'):
+        self.curve = curve
+        self._aes_key = None
+        self._private_key = None
+
+    def _derive_aes_key(self, ecc_key, nonce):
+        if not self._aes_key:
+            hasher = SHA256.new()
+            hasher.update(ecc_key.export_key(format='DER') + nonce.encode())
+            self._aes_key = hasher.digest()[:32]
+        return self._aes_key
+
+    def generate_key_pair(self):
+        private_key = ECC.generate(curve=self.curve)
+        public_key = private_key.public_key()
+
+        pem_private = private_key.export_key(format='PEM')
+        pem_public = public_key.export_key(format='PEM')
+
+        return pem_private, pem_public
+
+    def load_private_key(self, private_key_pem):
+        self._private_key = ECC.import_key(private_key_pem)
+        self._aes_key = None
+
+    def encrypt(self, text, nonce):
+        if not self._private_key:
+            raise ValueError("Private key not loaded")
+
+        # Derive the AES key from the ECC private key and nonce (cached after first use)
+        aes_key = self._derive_aes_key(self._private_key, nonce)
+
+        # Encrypt data using AES in ECB mode (deterministic: no per-message IV)
+        cipher = AES.new(aes_key, AES.MODE_ECB)
+        padded_text = pad(text.encode(), AES.block_size)
+        ciphertext = cipher.encrypt(padded_text)
+
+        return ciphertext
+
+    def decrypt(self, ciphertext, nonce):
+        if not self._private_key:
+            raise ValueError("Private key not loaded")
+
+        # Derive the AES key from the ECC private key and nonce (cached after first use)
+        aes_key = self._derive_aes_key(self._private_key, nonce)
+
+        # Decrypt data using AES key
+        cipher = AES.new(aes_key, AES.MODE_ECB)
+        padded_plaintext = cipher.decrypt(ciphertext)
+        plaintext = unpad(padded_plaintext, AES.block_size)
+
+        return plaintext.decode()
+
+
+if __name__ == '__main__':
+    ecc_aes = ECC_AES()
+
+    # Generate key pairs for the user
+    private_key, public_key = ecc_aes.generate_key_pair()
+    ecc_aes.load_private_key(private_key)
+    nonce = "THIS-IS-USER-ID"
+
+    print(private_key)
+
+    # Encrypt a message
+    message = "Hello, this is a secret message!"
+    encrypted_message = ecc_aes.encrypt(message, nonce)
+    print(f"Encrypted message: {encrypted_message.hex()}")
+
+    # Decrypt the message
+    decrypted_message = ecc_aes.decrypt(encrypted_message, nonce)
+    print(f"Decrypted message: {decrypted_message}")
+
+    # Check if the original message and decrypted message are the same
+    assert message == decrypted_message, "Original message and decrypted message do not match"
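One property of api/libs/ecc_aes.py worth calling out: `_derive_aes_key` caches the first derived key for a loaded private key, and `AES.MODE_ECB` adds no per-message IV, so encryption here is fully deterministic. A minimal check of that property, assuming `pycryptodome` is installed and the `api` directory is the import root:

    from libs.ecc_aes import ECC_AES

    ecc_aes = ECC_AES()
    private_key, _ = ecc_aes.generate_key_pair()
    ecc_aes.load_private_key(private_key)

    # identical plaintexts encrypt to identical ciphertexts,
    # since neither an IV nor a fresh nonce is mixed in per call
    c1 = ecc_aes.encrypt("same message", "THIS-IS-USER-ID")
    c2 = ecc_aes.encrypt("same message", "THIS-IS-USER-ID")
    assert c1 == c2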
diff --git a/api/libs/exception.py b/api/libs/exception.py
new file mode 100644
index 0000000000..567062f064
--- /dev/null
+++ b/api/libs/exception.py
@@ -0,0 +1,17 @@
+from typing import Optional
+
+from werkzeug.exceptions import HTTPException
+
+
+class BaseHTTPException(HTTPException):
+    error_code: str = 'unknown'
+    data: Optional[dict] = None
+
+    def __init__(self, description=None, response=None):
+        super().__init__(description, response)
+
+        self.data = {
+            "code": self.error_code,
+            "message": self.description,
+            "status": self.code,
+        }
\ No newline at end of file
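`BaseHTTPException` is the base for the API's typed errors: a subclass sets `error_code`, `description`, and the HTTP `code`, and the constructor packs them into `self.data`, which the `ExternalApi` error handler in the next file emits as the JSON body. A small sketch; the subclass name and wording are illustrative, not taken from this commit:

    from libs.exception import BaseHTTPException

    class AppUnavailableError(BaseHTTPException):
        error_code = 'app_unavailable'
        description = 'App unavailable.'
        code = 400  # the HTTP status, inherited from werkzeug's HTTPException

    # Raising AppUnavailableError() inside a request yields the body:
    #   {"code": "app_unavailable", "message": "App unavailable.", "status": 400}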
diff --git a/api/libs/external_api.py b/api/libs/external_api.py
new file mode 100644
index 0000000000..b5cc8fb9c5
--- /dev/null
+++ b/api/libs/external_api.py
@@ -0,0 +1,115 @@
+import re
+import sys
+
+from flask import got_request_exception, current_app
+from flask_restful import Api, http_status_message
+from werkzeug.datastructures import Headers
+from werkzeug.exceptions import HTTPException
+
+
+class ExternalApi(Api):
+
+    def handle_error(self, e):
+        """Error handler for the API transforms a raised exception into a Flask
+        response, with the appropriate HTTP status code and body.
+
+        :param e: the raised Exception object
+        :type e: Exception
+
+        """
+        got_request_exception.send(current_app, exception=e)
+
+        headers = Headers()
+        if isinstance(e, HTTPException):
+            if e.response is not None:
+                resp = e.get_response()
+                return resp
+
+            status_code = e.code
+            default_data = {
+                'code': re.sub(r'(?<!^)(?=[A-Z])', '_', type(e).__name__).lower(),
+                'message': getattr(e, 'description', http_status_message(status_code)),
+                'status': status_code
+            }
+        else:
+            status_code = 500
+            default_data = {
+                'message': http_status_message(status_code),
+            }
+
+        data = getattr(e, 'data', default_data)
+
+        # record the exception in the logs when we have a server error (5xx)
+        if status_code and status_code >= 500:
+            exc_info = sys.exc_info()
+            if exc_info[1] is None:
+                exc_info = None
+            current_app.log_exception(exc_info)
+
+        if status_code == 406 and self.default_mediatype is None:
+            # if we are handling NotAcceptable (406), make sure that
+            # make_response uses a representation we support as the
+            # default mediatype (so that make_response doesn't throw
+            # another NotAcceptable error).
+            supported_mediatypes = list(self.representations.keys())  # only application/json is supported
+            fallback_mediatype = supported_mediatypes[0] if supported_mediatypes else "text/plain"
+            data = {
+                'code': 'not_acceptable',
+                'message': data.get('message')
+            }
+            resp = self.make_response(
+                data,
+                status_code,
+                headers,
+                fallback_mediatype=fallback_mediatype
+            )
+        elif status_code == 400:
+            if isinstance(data.get('message'), dict):
+                param_key, param_value = list(data.get('message').items())[0]
+                data = {
+                    'code': 'invalid_param',
+                    'message': param_value,
+                    'params': param_key
+                }
+            else:
+                if 'code' not in data:
+                    data['code'] = 'unknown'
+
+            resp = self.make_response(data, status_code, headers)
+        else:
+            if 'code' not in data:
+                data['code'] = 'unknown'
+
+            resp = self.make_response(data, status_code, headers)
+
+        if status_code == 401:
+            resp = self.unauthorized(resp)
+        return resp
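A sketch of how `ExternalApi` is wired into an app, with a throwaway resource; exceptions raised under it are serialized by `handle_error` above into the `{code, message, status}` shape (the names below are illustrative):

    from flask import Flask
    from flask_restful import Resource
    from libs.external_api import ExternalApi

    app = Flask(__name__)
    api = ExternalApi(app)

    class PingApi(Resource):
        def get(self):
            return {'result': 'pong'}

    api.add_resource(PingApi, '/ping')
    # e.g. flask_restful.abort(404) inside a resource now returns
    # {"code": "not_found", "message": "...", "status": 404}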
diff --git a/api/libs/helper.py b/api/libs/helper.py
new file mode 100644
index 0000000000..bbf01cbad7
--- /dev/null
+++ b/api/libs/helper.py
@@ -0,0 +1,149 @@
+# -*- coding:utf-8 -*-
+import re
+import subprocess
+import uuid
+from datetime import datetime
+from zoneinfo import available_timezones
+import random
+import string
+
+from flask_restful import fields
+
+
+def run(script):
+    return subprocess.getstatusoutput('source /root/.bashrc && ' + script)
+
+
+class TimestampField(fields.Raw):
+    def format(self, value):
+        return int(value.timestamp())
+
+
+def email(email):
+    # Define a regex pattern for email addresses
+    pattern = r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$"
+    # Check if the email matches the pattern
+    if re.match(pattern, email) is not None:
+        return email
+
+    error = ('{email} is not a valid email.'
+             .format(email=email))
+    raise ValueError(error)
+
+
+def uuid_value(value):
+    if value == '':
+        return str(value)
+
+    try:
+        uuid_obj = uuid.UUID(value)
+        return str(uuid_obj)
+    except ValueError:
+        error = ('{value} is not a valid uuid.'
+                 .format(value=value))
+        raise ValueError(error)
+
+
+def timestamp_value(timestamp):
+    try:
+        int_timestamp = int(timestamp)
+        if int_timestamp < 0:
+            raise ValueError
+        return int_timestamp
+    except ValueError:
+        error = ('{timestamp} is not a valid timestamp.'
+                 .format(timestamp=timestamp))
+        raise ValueError(error)
+
+
+class str_len(object):
+    """ Restrict input string to a maximum length (inclusive) """
+
+    def __init__(self, max_length, argument='argument'):
+        self.max_length = max_length
+        self.argument = argument
+
+    def __call__(self, value):
+        length = len(value)
+        if length > self.max_length:
+            error = ('Invalid {arg}: {val}. {arg} cannot exceed length {length}'
+                     .format(arg=self.argument, val=value, length=self.max_length))
+            raise ValueError(error)
+
+        return value
+
+
+class float_range(object):
+    """ Restrict input to a float in a range (inclusive) """
+    def __init__(self, low, high, argument='argument'):
+        self.low = low
+        self.high = high
+        self.argument = argument
+
+    def __call__(self, value):
+        value = _get_float(value)
+        if value < self.low or value > self.high:
+            error = ('Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}'
+                     .format(arg=self.argument, val=value, lo=self.low, hi=self.high))
+            raise ValueError(error)
+
+        return value
+
+
+class datetime_string(object):
+    def __init__(self, format, argument='argument'):
+        self.format = format
+        self.argument = argument
+
+    def __call__(self, value):
+        try:
+            datetime.strptime(value, self.format)
+        except ValueError:
+            error = ('Invalid {arg}: {val}. {arg} must conform to the format {format}'
+                     .format(arg=self.argument, val=value, format=self.format))
+            raise ValueError(error)
+
+        return value
+
+
+def _get_float(value):
+    try:
+        return float(value)
+    except (TypeError, ValueError):
+        raise ValueError('{0} is not a valid float'.format(value))
+
+
+def supported_language(lang):
+    if lang in ['en-US', 'zh-Hans']:
+        return lang
+
+    error = ('{lang} is not a valid language.'
+             .format(lang=lang))
+    raise ValueError(error)
+
+
+def timezone(timezone_string):
+    if timezone_string and timezone_string in available_timezones():
+        return timezone_string
+
+    error = ('{timezone_string} is not a valid timezone.'
+             .format(timezone_string=timezone_string))
+    raise ValueError(error)
+
+
+def generate_string(n):
+    letters_digits = string.ascii_letters + string.digits
+    result = ""
+    for i in range(n):
+        result += random.choice(letters_digits)
+
+    return result
+
+
+def get_remote_ip(request):
+    if request.headers.get('CF-Connecting-IP'):
+        return request.headers.get('CF-Connecting-IP')
+    elif request.headers.getlist("X-Forwarded-For"):
+        return request.headers.getlist("X-Forwarded-For")[0]
+    else:
+        return request.remote_addr
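The validators above follow flask-restful's `type=` callable protocol: return the normalized value or raise `ValueError` with a readable message. A hedged sketch of combining them with `reqparse` (the field names are hypothetical):

    from flask_restful import reqparse
    from libs.helper import str_len, timestamp_value, uuid_value

    parser = reqparse.RequestParser()
    parser.add_argument('conversation_id', type=uuid_value, location='json')
    parser.add_argument('name', type=str_len(64, argument='name'), location='json')
    parser.add_argument('start', type=timestamp_value, location='args')

    # a failing validator raises ValueError; reqparse turns that into a 400
    # whose message is a {field: error} dict, which ExternalApi.handle_error
    # rewrites into the 'invalid_param' response shape
    args = parser.parse_args()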
diff --git a/api/libs/infinite_scroll_pagination.py b/api/libs/infinite_scroll_pagination.py
new file mode 100644
index 0000000000..076cb383b8
--- /dev/null
+++ b/api/libs/infinite_scroll_pagination.py
@@ -0,0 +1,7 @@
+# -*- coding:utf-8 -*-
+
+class InfiniteScrollPagination:
+    def __init__(self, data, limit, has_more):
+        self.data = data
+        self.limit = limit
+        self.has_more = has_more
diff --git a/api/libs/oauth.py b/api/libs/oauth.py
new file mode 100644
index 0000000000..ce41f0c22c
--- /dev/null
+++ b/api/libs/oauth.py
@@ -0,0 +1,136 @@
+import urllib.parse
+from dataclasses import dataclass
+
+import requests
+
+
+@dataclass
+class OAuthUserInfo:
+    id: str
+    name: str
+    email: str
+
+
+class OAuth:
+    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
+        self.client_id = client_id
+        self.client_secret = client_secret
+        self.redirect_uri = redirect_uri
+
+    def get_authorization_url(self):
+        raise NotImplementedError()
+
+    def get_access_token(self, code: str):
+        raise NotImplementedError()
+
+    def get_raw_user_info(self, token: str):
+        raise NotImplementedError()
+
+    def get_user_info(self, token: str) -> OAuthUserInfo:
+        raw_info = self.get_raw_user_info(token)
+        return self._transform_user_info(raw_info)
+
+    def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
+        raise NotImplementedError()
+
+
+class GitHubOAuth(OAuth):
+    _AUTH_URL = 'https://github.com/login/oauth/authorize'
+    _TOKEN_URL = 'https://github.com/login/oauth/access_token'
+    _USER_INFO_URL = 'https://api.github.com/user'
+    _EMAIL_INFO_URL = 'https://api.github.com/user/emails'
+
+    def get_authorization_url(self):
+        params = {
+            'client_id': self.client_id,
+            'redirect_uri': self.redirect_uri,
+            'scope': 'user:email'  # Request only basic user information
+        }
+        return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
+
+    def get_access_token(self, code: str):
+        data = {
+            'client_id': self.client_id,
+            'client_secret': self.client_secret,
+            'code': code,
+            'redirect_uri': self.redirect_uri
+        }
+        headers = {'Accept': 'application/json'}
+        response = requests.post(self._TOKEN_URL, data=data, headers=headers)
+
+        response_json = response.json()
+        access_token = response_json.get('access_token')
+
+        if not access_token:
+            raise ValueError(f"Error in GitHub OAuth: {response_json}")
+
+        return access_token
+
+    def get_raw_user_info(self, token: str):
+        headers = {'Authorization': f"token {token}"}
+        response = requests.get(self._USER_INFO_URL, headers=headers)
+        response.raise_for_status()
+        user_info = response.json()
+
+        email_response = requests.get(self._EMAIL_INFO_URL, headers=headers)
+        email_info = email_response.json()
+        primary_email = next((email for email in email_info if email.get('primary')), None)
+
+        return {**user_info, 'email': primary_email['email'] if primary_email else None}  # noreply fallback happens in _transform_user_info
+
+    def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
+        email = raw_info.get('email')
+        if not email:
+            email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
+        return OAuthUserInfo(
+            id=str(raw_info['id']),
+            name=raw_info['name'],
+            email=email
+        )
+
+
+class GoogleOAuth(OAuth):
+    _AUTH_URL = 'https://accounts.google.com/o/oauth2/v2/auth'
+    _TOKEN_URL = 'https://oauth2.googleapis.com/token'
+    _USER_INFO_URL = 'https://www.googleapis.com/oauth2/v3/userinfo'
+
+    def get_authorization_url(self):
+        params = {
+            'client_id': self.client_id,
+            'response_type': 'code',
+            'redirect_uri': self.redirect_uri,
+            'scope': 'openid email'
+        }
+        return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
+
+    def get_access_token(self, code: str):
+        data = {
+            'client_id': self.client_id,
+            'client_secret': self.client_secret,
+            'code': code,
+            'grant_type': 'authorization_code',
+            'redirect_uri': self.redirect_uri
+        }
+        headers = {'Accept': 'application/json'}
+        response = requests.post(self._TOKEN_URL, data=data, headers=headers)
+
+        response_json = response.json()
+        access_token = response_json.get('access_token')
+
+        if not access_token:
+            raise ValueError(f"Error in Google OAuth: {response_json}")
+
+        return access_token
+
+    def get_raw_user_info(self, token: str):
+        headers = {'Authorization': f"Bearer {token}"}
+        response = requests.get(self._USER_INFO_URL, headers=headers)
+        response.raise_for_status()
+        return response.json()
+
+    def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
+        return OAuthUserInfo(
+            id=str(raw_info['sub']),
+            name=None,
+            email=raw_info['email']
+        )
diff --git a/api/libs/password.py b/api/libs/password.py
new file mode 100644
index 0000000000..dde77c1046
--- /dev/null
+++ b/api/libs/password.py
@@ -0,0 +1,26 @@
+# -*- coding:utf-8 -*-
+import base64
+import binascii
+import hashlib
+import re
+
+password_pattern = r"^(?=.*[a-zA-Z])(?=.*\d).{8,}$"
+
+def valid_password(password):
+    # Define a regex pattern for password rules
+    pattern = password_pattern
+    # Check if the password matches the pattern
+    if re.match(pattern, password) is not None:
+        return password
+
+    raise ValueError('Not a valid password.')
+
+
+def hash_password(password_str, salt_byte):
+    dk = hashlib.pbkdf2_hmac('sha256', password_str.encode('utf-8'), salt_byte, 10000)
+    return binascii.hexlify(dk)
+
+
+def compare_password(password_str, password_hashed_base64, salt_base64):
+    # compare password for login
+    return hash_password(password_str, base64.b64decode(salt_base64)) == base64.b64decode(password_hashed_base64)
diff --git a/api/libs/rsa.py b/api/libs/rsa.py
new file mode 100644
index 0000000000..8741989a9a
--- /dev/null
+++ b/api/libs/rsa.py
@@ -0,0 +1,58 @@
+# -*- coding:utf-8 -*-
+import hashlib
+
+from Crypto.Cipher import PKCS1_OAEP
+from Crypto.PublicKey import RSA
+
+from extensions.ext_redis import redis_client
+from extensions.ext_storage import storage
+
+
+# Note: PKCS1_OAEP remains the recommended padding for RSA encryption; PKCS1_PSS is a signature scheme, not an encryption replacement.
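+# Example round trip with the helpers below, assuming Redis and storage are
+# initialized and the tenant's key pair has been generated:
+#
+#     pem_public = generate_key_pair(tenant_id)    # persists privkeys/<tenant_id>/private.pem
+#     token = encrypt('secret value', pem_public)  # RSA-OAEP; fits small payloads only
+#     assert decrypt(token, tenant_id) == 'secret value'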
+ + +def generate_key_pair(tenant_id): + private_key = RSA.generate(2048) + public_key = private_key.publickey() + + pem_private = private_key.export_key() + pem_public = public_key.export_key() + + filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem" + + storage.save(filepath, pem_private) + + return pem_public.decode() + + +def encrypt(text, public_key): + if isinstance(public_key, str): + public_key = public_key.encode() + + rsa_key = RSA.import_key(public_key) + cipher = PKCS1_OAEP.new(rsa_key) + encrypted_text = cipher.encrypt(text.encode()) + return encrypted_text + + +def decrypt(encrypted_text, tenant_id): + filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem" + + cache_key = 'tenant_privkey:{hash}'.format(hash=hashlib.sha3_256(filepath.encode()).hexdigest()) + private_key = redis_client.get(cache_key) + if not private_key: + try: + private_key = storage.load(filepath) + except FileNotFoundError: + raise PrivkeyNotFoundError("Private key not found, tenant_id: {tenant_id}".format(tenant_id=tenant_id)) + + redis_client.setex(cache_key, 120, private_key) + + rsa_key = RSA.import_key(private_key) + cipher = PKCS1_OAEP.new(rsa_key) + decrypted_text = cipher.decrypt(encrypted_text) + return decrypted_text.decode() + + +class PrivkeyNotFoundError(Exception): + pass diff --git a/api/migrations/README b/api/migrations/README new file mode 100644 index 0000000000..0e04844159 --- /dev/null +++ b/api/migrations/README @@ -0,0 +1 @@ +Single-database configuration for Flask. diff --git a/api/migrations/alembic.ini b/api/migrations/alembic.ini new file mode 100644 index 0000000000..ec9d45c26a --- /dev/null +++ b/api/migrations/alembic.ini @@ -0,0 +1,50 @@ +# A generic, single database configuration. + +[alembic] +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic,flask_migrate + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[logger_flask_migrate] +level = INFO +handlers = +qualname = flask_migrate + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/api/migrations/env.py b/api/migrations/env.py new file mode 100644 index 0000000000..0ac25ee989 --- /dev/null +++ b/api/migrations/env.py @@ -0,0 +1,113 @@ +import logging +from logging.config import fileConfig + +from flask import current_app + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. 
+fileConfig(config.config_file_name) +logger = logging.getLogger('alembic.env') + + +def get_engine(): + return current_app.extensions['migrate'].db.engine + + +def get_engine_url(): + try: + return get_engine().url.render_as_string(hide_password=False).replace( + '%', '%%') + except AttributeError: + return str(get_engine().url).replace('%', '%%') + + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +config.set_main_option('sqlalchemy.url', get_engine_url()) +target_db = current_app.extensions['migrate'].db + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def get_metadata(): + if hasattr(target_db, 'metadatas'): + return target_db.metadatas[None] + return target_db.metadata + + +def include_object(object, name, type_, reflected, compare_to): + if type_ == "foreign_key_constraint": + return False + else: + return True + + +def run_migrations_offline(): + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, target_metadata=get_metadata(), literal_binds=True + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online(): + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + + # this callback is used to prevent an auto-migration from being generated + # when there are no changes to the schema + # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html + def process_revision_directives(context, revision, directives): + if getattr(config.cmd_opts, 'autogenerate', False): + script = directives[0] + if script.upgrade_ops.is_empty(): + directives[:] = [] + logger.info('No changes in schema detected.') + + connectable = get_engine() + + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=get_metadata(), + process_revision_directives=process_revision_directives, + include_object=include_object, + **current_app.extensions['migrate'].configure_args + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/api/migrations/script.py.mako b/api/migrations/script.py.mako new file mode 100644 index 0000000000..2c0156303a --- /dev/null +++ b/api/migrations/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(): + ${upgrades if upgrades else "pass"} + + +def downgrade(): + ${downgrades if downgrades else "pass"} diff --git a/api/migrations/versions/64b051264f32_init.py b/api/migrations/versions/64b051264f32_init.py new file mode 100644 index 0000000000..84e1f2af17 --- /dev/null +++ b/api/migrations/versions/64b051264f32_init.py @@ -0,0 +1,793 @@ +"""init + +Revision ID: 64b051264f32 +Revises: +Create Date: 2023-05-13 14:26:59.085018 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '64b051264f32' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('account_integrates', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('account_id', postgresql.UUID(), nullable=False), + sa.Column('provider', sa.String(length=16), nullable=False), + sa.Column('open_id', sa.String(length=255), nullable=False), + sa.Column('encrypted_token', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='account_integrate_pkey'), + sa.UniqueConstraint('account_id', 'provider', name='unique_account_provider'), + sa.UniqueConstraint('provider', 'open_id', name='unique_provider_open_id') + ) + op.create_table('accounts', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('email', sa.String(length=255), nullable=False), + sa.Column('password', sa.String(length=255), nullable=True), + sa.Column('password_salt', sa.String(length=255), nullable=True), + sa.Column('avatar', sa.String(length=255), nullable=True), + sa.Column('interface_language', sa.String(length=255), nullable=True), + sa.Column('interface_theme', sa.String(length=255), nullable=True), + sa.Column('timezone', sa.String(length=255), nullable=True), + sa.Column('last_login_at', sa.DateTime(), nullable=True), + sa.Column('last_login_ip', sa.String(length=255), nullable=True), + sa.Column('status', sa.String(length=16), server_default=sa.text("'active'::character varying"), nullable=False), + sa.Column('initialized_at', sa.DateTime(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='account_pkey') + ) + with op.batch_alter_table('accounts', schema=None) as batch_op: + batch_op.create_index('account_email_idx', ['email'], unique=False) + + op.create_table('api_requests', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('api_token_id', postgresql.UUID(), nullable=False), + sa.Column('path', sa.String(length=255), nullable=False), + sa.Column('request', sa.Text(), nullable=True), + sa.Column('response', sa.Text(), nullable=True), + sa.Column('ip', 
sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='api_request_pkey') + ) + with op.batch_alter_table('api_requests', schema=None) as batch_op: + batch_op.create_index('api_request_token_idx', ['tenant_id', 'api_token_id'], unique=False) + + op.create_table('api_tokens', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=True), + sa.Column('dataset_id', postgresql.UUID(), nullable=True), + sa.Column('type', sa.String(length=16), nullable=False), + sa.Column('token', sa.String(length=255), nullable=False), + sa.Column('last_used_at', sa.DateTime(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='api_token_pkey') + ) + with op.batch_alter_table('api_tokens', schema=None) as batch_op: + batch_op.create_index('api_token_app_id_type_idx', ['app_id', 'type'], unique=False) + batch_op.create_index('api_token_token_idx', ['token', 'type'], unique=False) + + op.create_table('app_dataset_joins', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('dataset_id', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='app_dataset_join_pkey') + ) + with op.batch_alter_table('app_dataset_joins', schema=None) as batch_op: + batch_op.create_index('app_dataset_join_app_dataset_idx', ['dataset_id', 'app_id'], unique=False) + + op.create_table('app_model_configs', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('model_id', sa.String(length=255), nullable=False), + sa.Column('configs', sa.JSON(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('opening_statement', sa.Text(), nullable=True), + sa.Column('suggested_questions', sa.Text(), nullable=True), + sa.Column('suggested_questions_after_answer', sa.Text(), nullable=True), + sa.Column('more_like_this', sa.Text(), nullable=True), + sa.Column('model', sa.Text(), nullable=True), + sa.Column('user_input_form', sa.Text(), nullable=True), + sa.Column('pre_prompt', sa.Text(), nullable=True), + sa.Column('agent_mode', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('id', name='app_model_config_pkey') + ) + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.create_index('app_app_id_idx', ['app_id'], unique=False) + + op.create_table('apps', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('mode', sa.String(length=255), nullable=False), + sa.Column('icon', sa.String(length=255), nullable=True), + sa.Column('icon_background', sa.String(length=255), nullable=True), + sa.Column('app_model_config_id', 
postgresql.UUID(), nullable=True), + sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False), + sa.Column('enable_site', sa.Boolean(), nullable=False), + sa.Column('enable_api', sa.Boolean(), nullable=False), + sa.Column('api_rpm', sa.Integer(), nullable=False), + sa.Column('api_rph', sa.Integer(), nullable=False), + sa.Column('is_demo', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('is_public', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='app_pkey') + ) + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.create_index('app_tenant_id_idx', ['tenant_id'], unique=False) + + op.execute('CREATE SEQUENCE task_id_sequence;') + op.execute('CREATE SEQUENCE taskset_id_sequence;') + + op.create_table('celery_taskmeta', + sa.Column('id', sa.Integer(), nullable=False, + server_default=sa.text('nextval(\'task_id_sequence\')')), + sa.Column('task_id', sa.String(length=155), nullable=True), + sa.Column('status', sa.String(length=50), nullable=True), + sa.Column('result', sa.PickleType(), nullable=True), + sa.Column('date_done', sa.DateTime(), nullable=True), + sa.Column('traceback', sa.Text(), nullable=True), + sa.Column('name', sa.String(length=155), nullable=True), + sa.Column('args', sa.LargeBinary(), nullable=True), + sa.Column('kwargs', sa.LargeBinary(), nullable=True), + sa.Column('worker', sa.String(length=155), nullable=True), + sa.Column('retries', sa.Integer(), nullable=True), + sa.Column('queue', sa.String(length=155), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('task_id') + ) + op.create_table('celery_tasksetmeta', + sa.Column('id', sa.Integer(), nullable=False, + server_default=sa.text('nextval(\'taskset_id_sequence\')')), + sa.Column('taskset_id', sa.String(length=155), nullable=True), + sa.Column('result', sa.PickleType(), nullable=True), + sa.Column('date_done', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('taskset_id') + ) + op.create_table('conversations', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('app_model_config_id', postgresql.UUID(), nullable=False), + sa.Column('model_provider', sa.String(length=255), nullable=False), + sa.Column('override_model_configs', sa.Text(), nullable=True), + sa.Column('model_id', sa.String(length=255), nullable=False), + sa.Column('mode', sa.String(length=255), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('summary', sa.Text(), nullable=True), + sa.Column('inputs', sa.JSON(), nullable=True), + sa.Column('introduction', sa.Text(), nullable=True), + sa.Column('system_instruction', sa.Text(), nullable=True), + sa.Column('system_instruction_tokens', sa.Integer(), server_default=sa.text('0'), nullable=False), + sa.Column('status', sa.String(length=255), nullable=False), + sa.Column('from_source', sa.String(length=255), nullable=False), + sa.Column('from_end_user_id', postgresql.UUID(), nullable=True), + sa.Column('from_account_id', postgresql.UUID(), nullable=True), + sa.Column('read_at', sa.DateTime(), nullable=True), + sa.Column('read_account_id', 
postgresql.UUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='conversation_pkey') + ) + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.create_index('conversation_app_from_user_idx', ['app_id', 'from_source', 'from_end_user_id'], unique=False) + + op.create_table('dataset_keyword_tables', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('dataset_id', postgresql.UUID(), nullable=False), + sa.Column('keyword_table', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('id', name='dataset_keyword_table_pkey'), + sa.UniqueConstraint('dataset_id') + ) + with op.batch_alter_table('dataset_keyword_tables', schema=None) as batch_op: + batch_op.create_index('dataset_keyword_table_dataset_id_idx', ['dataset_id'], unique=False) + + op.create_table('dataset_process_rules', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('dataset_id', postgresql.UUID(), nullable=False), + sa.Column('mode', sa.String(length=255), server_default=sa.text("'automatic'::character varying"), nullable=False), + sa.Column('rules', sa.Text(), nullable=True), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='dataset_process_rule_pkey') + ) + with op.batch_alter_table('dataset_process_rules', schema=None) as batch_op: + batch_op.create_index('dataset_process_rule_dataset_id_idx', ['dataset_id'], unique=False) + + op.create_table('dataset_queries', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('dataset_id', postgresql.UUID(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('source', sa.String(length=255), nullable=False), + sa.Column('source_app_id', postgresql.UUID(), nullable=True), + sa.Column('created_by_role', sa.String(), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='dataset_query_pkey') + ) + with op.batch_alter_table('dataset_queries', schema=None) as batch_op: + batch_op.create_index('dataset_query_dataset_id_idx', ['dataset_id'], unique=False) + + op.create_table('datasets', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('provider', sa.String(length=255), server_default=sa.text("'vendor'::character varying"), nullable=False), + sa.Column('permission', sa.String(length=255), server_default=sa.text("'only_me'::character varying"), nullable=False), + sa.Column('data_source_type', sa.String(length=255), nullable=True), + sa.Column('indexing_technique', sa.String(length=255), nullable=True), + sa.Column('index_struct', sa.Text(), nullable=True), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), 
nullable=False), + sa.Column('updated_by', postgresql.UUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='dataset_pkey') + ) + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.create_index('dataset_tenant_idx', ['tenant_id'], unique=False) + + op.create_table('dify_setups', + sa.Column('version', sa.String(length=255), nullable=False), + sa.Column('setup_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('version', name='dify_setup_pkey') + ) + op.create_table('document_segments', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('dataset_id', postgresql.UUID(), nullable=False), + sa.Column('document_id', postgresql.UUID(), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('word_count', sa.Integer(), nullable=False), + sa.Column('tokens', sa.Integer(), nullable=False), + sa.Column('keywords', sa.JSON(), nullable=True), + sa.Column('index_node_id', sa.String(length=255), nullable=True), + sa.Column('index_node_hash', sa.String(length=255), nullable=True), + sa.Column('hit_count', sa.Integer(), nullable=False), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('disabled_at', sa.DateTime(), nullable=True), + sa.Column('disabled_by', postgresql.UUID(), nullable=True), + sa.Column('status', sa.String(length=255), server_default=sa.text("'waiting'::character varying"), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('indexing_at', sa.DateTime(), nullable=True), + sa.Column('completed_at', sa.DateTime(), nullable=True), + sa.Column('error', sa.Text(), nullable=True), + sa.Column('stopped_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id', name='document_segment_pkey') + ) + with op.batch_alter_table('document_segments', schema=None) as batch_op: + batch_op.create_index('document_segment_dataset_id_idx', ['dataset_id'], unique=False) + batch_op.create_index('document_segment_dataset_node_idx', ['dataset_id', 'index_node_id'], unique=False) + batch_op.create_index('document_segment_document_id_idx', ['document_id'], unique=False) + batch_op.create_index('document_segment_tenant_dataset_idx', ['dataset_id', 'tenant_id'], unique=False) + batch_op.create_index('document_segment_tenant_document_idx', ['document_id', 'tenant_id'], unique=False) + + op.create_table('documents', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('dataset_id', postgresql.UUID(), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('data_source_type', sa.String(length=255), nullable=False), + sa.Column('data_source_info', sa.Text(), nullable=True), + sa.Column('dataset_process_rule_id', postgresql.UUID(), nullable=True), + sa.Column('batch', sa.String(length=255), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('created_from', sa.String(length=255), nullable=False), + sa.Column('created_by', postgresql.UUID(), 
nullable=False), + sa.Column('created_api_request_id', postgresql.UUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('processing_started_at', sa.DateTime(), nullable=True), + sa.Column('file_id', sa.Text(), nullable=True), + sa.Column('word_count', sa.Integer(), nullable=True), + sa.Column('parsing_completed_at', sa.DateTime(), nullable=True), + sa.Column('cleaning_completed_at', sa.DateTime(), nullable=True), + sa.Column('splitting_completed_at', sa.DateTime(), nullable=True), + sa.Column('tokens', sa.Integer(), nullable=True), + sa.Column('indexing_latency', sa.Float(), nullable=True), + sa.Column('completed_at', sa.DateTime(), nullable=True), + sa.Column('is_paused', sa.Boolean(), server_default=sa.text('false'), nullable=True), + sa.Column('paused_by', postgresql.UUID(), nullable=True), + sa.Column('paused_at', sa.DateTime(), nullable=True), + sa.Column('error', sa.Text(), nullable=True), + sa.Column('stopped_at', sa.DateTime(), nullable=True), + sa.Column('indexing_status', sa.String(length=255), server_default=sa.text("'waiting'::character varying"), nullable=False), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('disabled_at', sa.DateTime(), nullable=True), + sa.Column('disabled_by', postgresql.UUID(), nullable=True), + sa.Column('archived', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('archived_reason', sa.String(length=255), nullable=True), + sa.Column('archived_by', postgresql.UUID(), nullable=True), + sa.Column('archived_at', sa.DateTime(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('doc_type', sa.String(length=40), nullable=True), + sa.Column('doc_metadata', sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint('id', name='document_pkey') + ) + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.create_index('document_dataset_id_idx', ['dataset_id'], unique=False) + batch_op.create_index('document_is_paused_idx', ['is_paused'], unique=False) + + op.create_table('embeddings', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('hash', sa.String(length=64), nullable=False), + sa.Column('embedding', sa.LargeBinary(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='embedding_pkey'), + sa.UniqueConstraint('hash', name='embedding_hash_idx') + ) + op.create_table('end_users', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=True), + sa.Column('type', sa.String(length=255), nullable=False), + sa.Column('external_user_id', sa.String(length=255), nullable=True), + sa.Column('name', sa.String(length=255), nullable=True), + sa.Column('is_anonymous', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('session_id', sa.String(length=255), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='end_user_pkey') + ) + with 
op.batch_alter_table('end_users', schema=None) as batch_op: + batch_op.create_index('end_user_session_id_idx', ['session_id', 'type'], unique=False) + batch_op.create_index('end_user_tenant_session_id_idx', ['tenant_id', 'session_id', 'type'], unique=False) + + op.create_table('installed_apps', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('app_owner_tenant_id', postgresql.UUID(), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('is_pinned', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('last_used_at', sa.DateTime(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='installed_app_pkey'), + sa.UniqueConstraint('tenant_id', 'app_id', name='unique_tenant_app') + ) + with op.batch_alter_table('installed_apps', schema=None) as batch_op: + batch_op.create_index('installed_app_app_id_idx', ['app_id'], unique=False) + batch_op.create_index('installed_app_tenant_id_idx', ['tenant_id'], unique=False) + + op.create_table('invitation_codes', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('batch', sa.String(length=255), nullable=False), + sa.Column('code', sa.String(length=32), nullable=False), + sa.Column('status', sa.String(length=16), server_default=sa.text("'unused'::character varying"), nullable=False), + sa.Column('used_at', sa.DateTime(), nullable=True), + sa.Column('used_by_tenant_id', postgresql.UUID(), nullable=True), + sa.Column('used_by_account_id', postgresql.UUID(), nullable=True), + sa.Column('deprecated_at', sa.DateTime(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='invitation_code_pkey') + ) + with op.batch_alter_table('invitation_codes', schema=None) as batch_op: + batch_op.create_index('invitation_codes_batch_idx', ['batch'], unique=False) + batch_op.create_index('invitation_codes_code_idx', ['code', 'status'], unique=False) + + op.create_table('message_agent_thoughts', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('message_id', postgresql.UUID(), nullable=False), + sa.Column('message_chain_id', postgresql.UUID(), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('thought', sa.Text(), nullable=True), + sa.Column('tool', sa.Text(), nullable=True), + sa.Column('tool_input', sa.Text(), nullable=True), + sa.Column('observation', sa.Text(), nullable=True), + sa.Column('tool_process_data', sa.Text(), nullable=True), + sa.Column('message', sa.Text(), nullable=True), + sa.Column('message_token', sa.Integer(), nullable=True), + sa.Column('message_unit_price', sa.Numeric(), nullable=True), + sa.Column('answer', sa.Text(), nullable=True), + sa.Column('answer_token', sa.Integer(), nullable=True), + sa.Column('answer_unit_price', sa.Numeric(), nullable=True), + sa.Column('tokens', sa.Integer(), nullable=True), + sa.Column('total_price', sa.Numeric(), nullable=True), + sa.Column('currency', sa.String(), nullable=True), + sa.Column('latency', sa.Float(), nullable=True), + sa.Column('created_by_role', sa.String(), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', 
sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='message_agent_thought_pkey') + ) + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.create_index('message_agent_thought_message_chain_id_idx', ['message_chain_id'], unique=False) + batch_op.create_index('message_agent_thought_message_id_idx', ['message_id'], unique=False) + + op.create_table('message_chains', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('message_id', postgresql.UUID(), nullable=False), + sa.Column('type', sa.String(length=255), nullable=False), + sa.Column('input', sa.Text(), nullable=True), + sa.Column('output', sa.Text(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='message_chain_pkey') + ) + with op.batch_alter_table('message_chains', schema=None) as batch_op: + batch_op.create_index('message_chain_message_id_idx', ['message_id'], unique=False) + + op.create_table('message_feedbacks', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('conversation_id', postgresql.UUID(), nullable=False), + sa.Column('message_id', postgresql.UUID(), nullable=False), + sa.Column('rating', sa.String(length=255), nullable=False), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('from_source', sa.String(length=255), nullable=False), + sa.Column('from_end_user_id', postgresql.UUID(), nullable=True), + sa.Column('from_account_id', postgresql.UUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='message_feedback_pkey') + ) + with op.batch_alter_table('message_feedbacks', schema=None) as batch_op: + batch_op.create_index('message_feedback_app_idx', ['app_id'], unique=False) + batch_op.create_index('message_feedback_conversation_idx', ['conversation_id', 'from_source', 'rating'], unique=False) + batch_op.create_index('message_feedback_message_idx', ['message_id', 'from_source'], unique=False) + + op.create_table('operation_logs', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('account_id', postgresql.UUID(), nullable=False), + sa.Column('action', sa.String(length=255), nullable=False), + sa.Column('content', sa.JSON(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('created_ip', sa.String(length=255), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='operation_log_pkey') + ) + with op.batch_alter_table('operation_logs', schema=None) as batch_op: + batch_op.create_index('operation_log_account_action_idx', ['tenant_id', 'account_id', 'action'], unique=False) + + op.create_table('pinned_conversations', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('conversation_id', postgresql.UUID(), 
nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='pinned_conversation_pkey') + ) + with op.batch_alter_table('pinned_conversations', schema=None) as batch_op: + batch_op.create_index('pinned_conversation_conversation_idx', ['app_id', 'conversation_id', 'created_by'], unique=False) + + op.create_table('providers', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('provider_name', sa.String(length=40), nullable=False), + sa.Column('provider_type', sa.String(length=40), nullable=False, server_default=sa.text("'custom'::character varying")), + sa.Column('encrypted_config', sa.Text(), nullable=True), + sa.Column('is_valid', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('last_used', sa.DateTime(), nullable=True), + sa.Column('quota_type', sa.String(length=40), nullable=True, server_default=sa.text("''::character varying")), + sa.Column('quota_limit', sa.Integer(), nullable=True), + sa.Column('quota_used', sa.Integer(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='provider_pkey'), + sa.UniqueConstraint('tenant_id', 'provider_name', 'provider_type', 'quota_type', name='unique_provider_name_type_quota') + ) + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.create_index('provider_tenant_id_provider_idx', ['tenant_id', 'provider_name'], unique=False) + + op.create_table('recommended_apps', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('description', sa.JSON(), nullable=False), + sa.Column('copyright', sa.String(length=255), nullable=False), + sa.Column('privacy_policy', sa.String(length=255), nullable=False), + sa.Column('category', sa.String(length=255), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('is_listed', sa.Boolean(), nullable=False), + sa.Column('install_count', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='recommended_app_pkey') + ) + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.create_index('recommended_app_app_id_idx', ['app_id'], unique=False) + batch_op.create_index('recommended_app_is_listed_idx', ['is_listed'], unique=False) + + op.create_table('saved_messages', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('message_id', postgresql.UUID(), nullable=False), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='saved_message_pkey') + ) + with op.batch_alter_table('saved_messages', schema=None) as batch_op: + 
batch_op.create_index('saved_message_message_idx', ['app_id', 'message_id', 'created_by'], unique=False) + + op.create_table('sessions', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('session_id', sa.String(length=255), nullable=True), + sa.Column('data', sa.LargeBinary(), nullable=True), + sa.Column('expiry', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('session_id') + ) + op.create_table('sites', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('title', sa.String(length=255), nullable=False), + sa.Column('icon', sa.String(length=255), nullable=True), + sa.Column('icon_background', sa.String(length=255), nullable=True), + sa.Column('description', sa.String(length=255), nullable=True), + sa.Column('default_language', sa.String(length=255), nullable=False), + sa.Column('copyright', sa.String(length=255), nullable=True), + sa.Column('privacy_policy', sa.String(length=255), nullable=True), + sa.Column('customize_domain', sa.String(length=255), nullable=True), + sa.Column('customize_token_strategy', sa.String(length=255), nullable=False), + sa.Column('prompt_public', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('code', sa.String(length=255), nullable=True), + sa.PrimaryKeyConstraint('id', name='site_pkey') + ) + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.create_index('site_app_id_idx', ['app_id'], unique=False) + batch_op.create_index('site_code_idx', ['code', 'status'], unique=False) + + op.create_table('tenant_account_joins', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('account_id', postgresql.UUID(), nullable=False), + sa.Column('role', sa.String(length=16), server_default='normal', nullable=False), + sa.Column('invited_by', postgresql.UUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tenant_account_join_pkey'), + sa.UniqueConstraint('tenant_id', 'account_id', name='unique_tenant_account_join') + ) + with op.batch_alter_table('tenant_account_joins', schema=None) as batch_op: + batch_op.create_index('tenant_account_join_account_id_idx', ['account_id'], unique=False) + batch_op.create_index('tenant_account_join_tenant_id_idx', ['tenant_id'], unique=False) + + op.create_table('tenants', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('encrypt_public_key', sa.Text(), nullable=True), + sa.Column('plan', sa.String(length=255), server_default=sa.text("'basic'::character varying"), nullable=False), + sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False), + sa.Column('created_at', sa.DateTime(), 
server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tenant_pkey') + ) + op.create_table('upload_files', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('storage_type', sa.String(length=255), nullable=False), + sa.Column('key', sa.String(length=255), nullable=False), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('size', sa.Integer(), nullable=False), + sa.Column('extension', sa.String(length=255), nullable=False), + sa.Column('mime_type', sa.String(length=255), nullable=True), + sa.Column('created_by', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('used', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('used_by', postgresql.UUID(), nullable=True), + sa.Column('used_at', sa.DateTime(), nullable=True), + sa.Column('hash', sa.String(length=255), nullable=True), + sa.PrimaryKeyConstraint('id', name='upload_file_pkey') + ) + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.create_index('upload_file_tenant_idx', ['tenant_id'], unique=False) + + op.create_table('message_annotations', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('conversation_id', postgresql.UUID(), nullable=False), + sa.Column('message_id', postgresql.UUID(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('account_id', postgresql.UUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='message_annotation_pkey') + ) + with op.batch_alter_table('message_annotations', schema=None) as batch_op: + batch_op.create_index('message_annotation_app_idx', ['app_id'], unique=False) + batch_op.create_index('message_annotation_conversation_idx', ['conversation_id'], unique=False) + batch_op.create_index('message_annotation_message_idx', ['message_id'], unique=False) + + op.create_table('messages', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('app_id', postgresql.UUID(), nullable=False), + sa.Column('model_provider', sa.String(length=255), nullable=False), + sa.Column('model_id', sa.String(length=255), nullable=False), + sa.Column('override_model_configs', sa.Text(), nullable=True), + sa.Column('conversation_id', postgresql.UUID(), nullable=False), + sa.Column('inputs', sa.JSON(), nullable=True), + sa.Column('query', sa.Text(), nullable=False), + sa.Column('message', sa.JSON(), nullable=False), + sa.Column('message_tokens', sa.Integer(), server_default=sa.text('0'), nullable=False), + sa.Column('message_unit_price', sa.Numeric(precision=10, scale=4), nullable=False), + sa.Column('answer', sa.Text(), nullable=False), + sa.Column('answer_tokens', sa.Integer(), server_default=sa.text('0'), nullable=False), + sa.Column('answer_unit_price', sa.Numeric(precision=10, scale=4), nullable=False), + sa.Column('provider_response_latency', sa.Float(), 
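+    # precision split: message/answer unit prices above are Numeric(10, 4), while
+    # total_price just below is Numeric(10, 7)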
server_default=sa.text('0'), nullable=False), + sa.Column('total_price', sa.Numeric(precision=10, scale=7), nullable=True), + sa.Column('currency', sa.String(length=255), nullable=False), + sa.Column('from_source', sa.String(length=255), nullable=False), + sa.Column('from_end_user_id', postgresql.UUID(), nullable=True), + sa.Column('from_account_id', postgresql.UUID(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('agent_based', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.PrimaryKeyConstraint('id', name='message_pkey') + ) + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.create_index('message_account_idx', ['app_id', 'from_source', 'from_account_id'], unique=False) + batch_op.create_index('message_app_id_idx', ['app_id', 'created_at'], unique=False) + batch_op.create_index('message_conversation_id_idx', ['conversation_id'], unique=False) + batch_op.create_index('message_end_user_idx', ['app_id', 'from_source', 'from_end_user_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('messages', schema=None) as batch_op: + batch_op.drop_index('message_end_user_idx') + batch_op.drop_index('message_conversation_id_idx') + batch_op.drop_index('message_app_id_idx') + batch_op.drop_index('message_account_idx') + + op.drop_table('messages') + with op.batch_alter_table('message_annotations', schema=None) as batch_op: + batch_op.drop_index('message_annotation_message_idx') + batch_op.drop_index('message_annotation_conversation_idx') + batch_op.drop_index('message_annotation_app_idx') + + op.drop_table('message_annotations') + with op.batch_alter_table('upload_files', schema=None) as batch_op: + batch_op.drop_index('upload_file_tenant_idx') + + op.drop_table('upload_files') + op.drop_table('tenants') + with op.batch_alter_table('tenant_account_joins', schema=None) as batch_op: + batch_op.drop_index('tenant_account_join_tenant_id_idx') + batch_op.drop_index('tenant_account_join_account_id_idx') + + op.drop_table('tenant_account_joins') + with op.batch_alter_table('sites', schema=None) as batch_op: + batch_op.drop_index('site_code_idx') + batch_op.drop_index('site_app_id_idx') + + op.drop_table('sites') + op.drop_table('sessions') + with op.batch_alter_table('saved_messages', schema=None) as batch_op: + batch_op.drop_index('saved_message_message_idx') + + op.drop_table('saved_messages') + with op.batch_alter_table('recommended_apps', schema=None) as batch_op: + batch_op.drop_index('recommended_app_is_listed_idx') + batch_op.drop_index('recommended_app_app_id_idx') + + op.drop_table('recommended_apps') + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.drop_index('provider_tenant_id_provider_idx') + + op.drop_table('providers') + with op.batch_alter_table('pinned_conversations', schema=None) as batch_op: + batch_op.drop_index('pinned_conversation_conversation_idx') + + op.drop_table('pinned_conversations') + with op.batch_alter_table('operation_logs', schema=None) as batch_op: + batch_op.drop_index('operation_log_account_action_idx') + + op.drop_table('operation_logs') + with op.batch_alter_table('message_feedbacks', schema=None) as batch_op: + batch_op.drop_index('message_feedback_message_idx') + 
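        # downgrade mirrors upgrade in reverse: each table's indexes are dropped before the table itself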
batch_op.drop_index('message_feedback_conversation_idx') + batch_op.drop_index('message_feedback_app_idx') + + op.drop_table('message_feedbacks') + with op.batch_alter_table('message_chains', schema=None) as batch_op: + batch_op.drop_index('message_chain_message_id_idx') + + op.drop_table('message_chains') + with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op: + batch_op.drop_index('message_agent_thought_message_id_idx') + batch_op.drop_index('message_agent_thought_message_chain_id_idx') + + op.drop_table('message_agent_thoughts') + with op.batch_alter_table('invitation_codes', schema=None) as batch_op: + batch_op.drop_index('invitation_codes_code_idx') + batch_op.drop_index('invitation_codes_batch_idx') + + op.drop_table('invitation_codes') + with op.batch_alter_table('installed_apps', schema=None) as batch_op: + batch_op.drop_index('installed_app_tenant_id_idx') + batch_op.drop_index('installed_app_app_id_idx') + + op.drop_table('installed_apps') + with op.batch_alter_table('end_users', schema=None) as batch_op: + batch_op.drop_index('end_user_tenant_session_id_idx') + batch_op.drop_index('end_user_session_id_idx') + + op.drop_table('end_users') + op.drop_table('embeddings') + with op.batch_alter_table('documents', schema=None) as batch_op: + batch_op.drop_index('document_is_paused_idx') + batch_op.drop_index('document_dataset_id_idx') + + op.drop_table('documents') + with op.batch_alter_table('document_segments', schema=None) as batch_op: + batch_op.drop_index('document_segment_tenant_document_idx') + batch_op.drop_index('document_segment_tenant_dataset_idx') + batch_op.drop_index('document_segment_document_id_idx') + batch_op.drop_index('document_segment_dataset_node_idx') + batch_op.drop_index('document_segment_dataset_id_idx') + + op.drop_table('document_segments') + op.drop_table('dify_setups') + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.drop_index('dataset_tenant_idx') + + op.drop_table('datasets') + with op.batch_alter_table('dataset_queries', schema=None) as batch_op: + batch_op.drop_index('dataset_query_dataset_id_idx') + + op.drop_table('dataset_queries') + with op.batch_alter_table('dataset_process_rules', schema=None) as batch_op: + batch_op.drop_index('dataset_process_rule_dataset_id_idx') + + op.drop_table('dataset_process_rules') + with op.batch_alter_table('dataset_keyword_tables', schema=None) as batch_op: + batch_op.drop_index('dataset_keyword_table_dataset_id_idx') + + op.drop_table('dataset_keyword_tables') + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.drop_index('conversation_app_from_user_idx') + + op.drop_table('conversations') + op.drop_table('celery_tasksetmeta') + op.drop_table('celery_taskmeta') + + op.execute('DROP SEQUENCE taskset_id_sequence;') + op.execute('DROP SEQUENCE task_id_sequence;') + with op.batch_alter_table('apps', schema=None) as batch_op: + batch_op.drop_index('app_tenant_id_idx') + + op.drop_table('apps') + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_index('app_app_id_idx') + + op.drop_table('app_model_configs') + with op.batch_alter_table('app_dataset_joins', schema=None) as batch_op: + batch_op.drop_index('app_dataset_join_app_dataset_idx') + + op.drop_table('app_dataset_joins') + with op.batch_alter_table('api_tokens', schema=None) as batch_op: + batch_op.drop_index('api_token_token_idx') + batch_op.drop_index('api_token_app_id_type_idx') + + op.drop_table('api_tokens') + with 
op.batch_alter_table('api_requests', schema=None) as batch_op:
+        batch_op.drop_index('api_request_token_idx')
+
+    op.drop_table('api_requests')
+    with op.batch_alter_table('accounts', schema=None) as batch_op:
+        batch_op.drop_index('account_email_idx')
+
+    op.drop_table('accounts')
+    op.drop_table('account_integrates')
+    # ### end Alembic commands ###
diff --git a/api/models/__init__.py b/api/models/__init__.py
new file mode 100644
index 0000000000..44d37d3052
--- /dev/null
+++ b/api/models/__init__.py
@@ -0,0 +1 @@
+# -*- coding:utf-8 -*-
\ No newline at end of file
diff --git a/api/models/account.py b/api/models/account.py
new file mode 100644
index 0000000000..de2d3bd71f
--- /dev/null
+++ b/api/models/account.py
@@ -0,0 +1,180 @@
+import enum
+from typing import List
+
+from flask_login import UserMixin
+from extensions.ext_database import db
+from sqlalchemy.dialects.postgresql import UUID
+
+
+class AccountStatus(str, enum.Enum):
+    PENDING = 'pending'
+    UNINITIALIZED = 'uninitialized'
+    ACTIVE = 'active'
+    BANNED = 'banned'
+    CLOSED = 'closed'
+
+
+class Account(UserMixin, db.Model):
+    __tablename__ = 'accounts'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='account_pkey'),
+        db.Index('account_email_idx', 'email')
+    )
+
+    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    name = db.Column(db.String(255), nullable=False)
+    email = db.Column(db.String(255), nullable=False)
+    password = db.Column(db.String(255), nullable=True)
+    password_salt = db.Column(db.String(255), nullable=True)
+    avatar = db.Column(db.String(255))
+    interface_language = db.Column(db.String(255))
+    interface_theme = db.Column(db.String(255))
+    timezone = db.Column(db.String(255))
+    last_login_at = db.Column(db.DateTime)
+    last_login_ip = db.Column(db.String(255))
+    status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying"))
+    initialized_at = db.Column(db.DateTime)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+
+    _current_tenant: db.Model = None
+
+    @property
+    def current_tenant(self):
+        return self._current_tenant
+
+    @current_tenant.setter
+    def current_tenant(self, value):
+        tenant = value
+        ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=self.id).first()
+        if ta:
+            tenant.current_role = ta.role
+        else:
+            tenant = None
+        self._current_tenant = tenant
+
+    @property
+    def current_tenant_id(self):
+        return self._current_tenant.id
+
+    @current_tenant_id.setter
+    def current_tenant_id(self, value):
+        try:
+            tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \
+                .filter(Tenant.id == value) \
+                .filter(TenantAccountJoin.tenant_id == Tenant.id) \
+                .filter(TenantAccountJoin.account_id == self.id) \
+                .one_or_none()
+
+            if tenant_account_join:
+                tenant, ta = tenant_account_join
+                tenant.current_role = ta.role
+            else:
+                tenant = None
+        except Exception:
+            # fall back to "no tenant" on any lookup error
+            tenant = None
+
+        self._current_tenant = tenant
+
+    def get_status(self) -> AccountStatus:
+        status_str = self.status
+        return AccountStatus(status_str)
+
+    @classmethod
+    def get_by_openid(cls, provider: str, open_id: str) -> db.Model:
+        account_integrate = db.session.query(AccountIntegrate). \
+            filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id). \
+            one_or_none()
+        if account_integrate:
+            return db.session.query(Account). \
+                filter(Account.id == account_integrate.account_id). \
+                one_or_none()
+        return None
+
+    def get_integrates(self) -> List[db.Model]:
+        # all third-party auth bindings for this account
+        return db.session.query(AccountIntegrate).filter(
+            AccountIntegrate.account_id == self.id
+        ).all()
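+
+# Usage sketch (illustrative only, not part of this commit): switching an account's
+# active workspace via the properties above. Assumes a Flask app context with the
+# database initialized; `some_tenant_id` is a placeholder value.
+#
+#     account = Account.query.filter_by(email='user@example.com').first()
+#     account.current_tenant_id = some_tenant_id
+#     if account.current_tenant:
+#         print(account.current_tenant.current_role)  # 'owner', 'admin' or 'normal'
+#     else:
+#         print('account is not a member of that tenant')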
+
+
+class Tenant(db.Model):
+    __tablename__ = 'tenants'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='tenant_pkey'),
+    )
+
+    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    name = db.Column(db.String(255), nullable=False)
+    encrypt_public_key = db.Column(db.Text)
+    plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying"))
+    status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+
+    def get_accounts(self) -> List[db.Model]:
+        # all accounts joined to this tenant, resolved through TenantAccountJoin
+        return db.session.query(Account).filter(
+            Account.id == TenantAccountJoin.account_id,
+            TenantAccountJoin.tenant_id == self.id
+        ).all()
+
+
+class TenantAccountJoinRole(enum.Enum):
+    OWNER = 'owner'
+    ADMIN = 'admin'
+    NORMAL = 'normal'
+
+
+class TenantAccountJoin(db.Model):
+    __tablename__ = 'tenant_account_joins'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='tenant_account_join_pkey'),
+        db.Index('tenant_account_join_account_id_idx', 'account_id'),
+        db.Index('tenant_account_join_tenant_id_idx', 'tenant_id'),
+        db.UniqueConstraint('tenant_id', 'account_id', name='unique_tenant_account_join')
+    )
+
+    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(UUID, nullable=False)
+    account_id = db.Column(UUID, nullable=False)
+    role = db.Column(db.String(16), nullable=False, server_default='normal')
+    invited_by = db.Column(UUID, nullable=True)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+
+
+class AccountIntegrate(db.Model):
+    __tablename__ = 'account_integrates'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='account_integrate_pkey'),
+        db.UniqueConstraint('account_id', 'provider', name='unique_account_provider'),
+        db.UniqueConstraint('provider', 'open_id', name='unique_provider_open_id')
+    )
+
+    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    account_id = db.Column(UUID, nullable=False)
+    provider = db.Column(db.String(16), nullable=False)
+    open_id = db.Column(db.String(255), nullable=False)
+    encrypted_token = db.Column(db.String(255), nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+
+
+class InvitationCode(db.Model):
+    __tablename__ = 'invitation_codes'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='invitation_code_pkey'),
+        db.Index('invitation_codes_batch_idx', 'batch'),
+        db.Index('invitation_codes_code_idx', 'code', 'status')
+    )
+
+    id = db.Column(db.Integer, nullable=False)
+    batch = db.Column(db.String(255), nullable=False)
+    code = db.Column(db.String(32), nullable=False)
+    status = db.Column(db.String(16), nullable=False, server_default=db.text("'unused'::character varying"))
+    used_at = db.Column(db.DateTime)
+    used_by_tenant_id = db.Column(UUID)
+    used_by_account_id = db.Column(UUID)
+    deprecated_at = db.Column(db.DateTime)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
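+
+# Membership sketch (illustrative only): tenant membership is modelled by explicit
+# TenantAccountJoin rows rather than a SQLAlchemy relationship. Assumes an existing
+# `account` object; Postgres populates tenant.id via its server default on flush.
+#
+#     tenant = Tenant(name='Acme')
+#     db.session.add(tenant)
+#     db.session.flush()  # INSERT ... RETURNING id populates tenant.id
+#     db.session.add(TenantAccountJoin(tenant_id=tenant.id, account_id=account.id,
+#                                      role=TenantAccountJoinRole.OWNER.value))
+#     db.session.commit()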
diff --git a/api/models/dataset.py b/api/models/dataset.py
new file mode 100644
index 0000000000..29588c1f38
--- /dev/null
+++ b/api/models/dataset.py
@@ -0,0 +1,415 @@
+import json
+import pickle
+from json import JSONDecodeError
+
+from sqlalchemy import func
+from sqlalchemy.dialects.postgresql import UUID
+
+from extensions.ext_database import db
+from models.account import Account
+from models.model import App, UploadFile
+
+
+class Dataset(db.Model):
+    __tablename__ = 'datasets'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='dataset_pkey'),
+        db.Index('dataset_tenant_idx', 'tenant_id'),
+    )
+
+    INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy']
+
+    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(UUID, nullable=False)
+    name = db.Column(db.String(255), nullable=False)
+    description = db.Column(db.Text, nullable=True)
+    provider = db.Column(db.String(255), nullable=False,
+                         server_default=db.text("'vendor'::character varying"))
+    permission = db.Column(db.String(255), nullable=False,
+                           server_default=db.text("'only_me'::character varying"))
+    data_source_type = db.Column(db.String(255))
+    indexing_technique = db.Column(db.String(255), nullable=True)
+    index_struct = db.Column(db.Text, nullable=True)
+    created_by = db.Column(UUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False,
+                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    updated_by = db.Column(UUID, nullable=True)
+    updated_at = db.Column(db.DateTime, nullable=False,
+                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
+
+    @property
+    def dataset_keyword_table(self):
+        dataset_keyword_table = db.session.query(DatasetKeywordTable).filter(
+            DatasetKeywordTable.dataset_id == self.id).first()
+        if dataset_keyword_table:
+            return dataset_keyword_table
+
+        return None
+
+    @property
+    def index_struct_dict(self):
+        return json.loads(self.index_struct) if self.index_struct else None
+
+    @property
+    def created_by_account(self):
+        return Account.query.get(self.created_by)
+
+    @property
+    def latest_process_rule(self):
+        return DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) \
+            .order_by(DatasetProcessRule.created_at.desc()).first()
+
+    @property
+    def app_count(self):
+        return db.session.query(func.count(AppDatasetJoin.id)).filter(AppDatasetJoin.dataset_id == self.id).scalar()
+
+    @property
+    def document_count(self):
+        return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()
+
+    @property
+    def word_count(self):
+        # coalesce the sum to 0 so an empty dataset reports 0 rather than None
+        return Document.query.with_entities(func.coalesce(func.sum(Document.word_count), 0)) \
+            .filter(Document.dataset_id == self.id).scalar()
+
+
+class DatasetProcessRule(db.Model):
+    __tablename__ = 'dataset_process_rules'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='dataset_process_rule_pkey'),
+        db.Index('dataset_process_rule_dataset_id_idx', 'dataset_id'),
+    )
+
+    id = db.Column(UUID, nullable=False,
+                   server_default=db.text('uuid_generate_v4()'))
+    dataset_id = db.Column(UUID, nullable=False)
+    mode = db.Column(db.String(255), nullable=False,
+                     server_default=db.text("'automatic'::character varying"))
+    rules = db.Column(db.Text, nullable=True)
+    created_by = db.Column(UUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False,
+                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
+
+    MODES = ['automatic', 'custom']
+    PRE_PROCESSING_RULES
= ['remove_stopwords', 'remove_extra_spaces', 'remove_urls_emails'] + AUTOMATIC_RULES = { + 'pre_processing_rules': [ + {'id': 'remove_extra_spaces', 'enabled': True}, + {'id': 'remove_urls_emails', 'enabled': False} + ], + 'segmentation': { + 'delimiter': '\n', + 'max_tokens': 1000 + } + } + + def to_dict(self): + return { + 'id': self.id, + 'dataset_id': self.dataset_id, + 'mode': self.mode, + 'rules': self.rules_dict, + 'created_by': self.created_by, + 'created_at': self.created_at, + } + + @property + def rules_dict(self): + try: + return json.loads(self.rules) if self.rules else None + except JSONDecodeError: + return None + + +class Document(db.Model): + __tablename__ = 'documents' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='document_pkey'), + db.Index('document_dataset_id_idx', 'dataset_id'), + db.Index('document_is_paused_idx', 'is_paused'), + ) + + # initial fields + id = db.Column(UUID, nullable=False, + server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + dataset_id = db.Column(UUID, nullable=False) + position = db.Column(db.Integer, nullable=False) + data_source_type = db.Column(db.String(255), nullable=False) + data_source_info = db.Column(db.Text, nullable=True) + dataset_process_rule_id = db.Column(UUID, nullable=True) + batch = db.Column(db.String(255), nullable=False) + name = db.Column(db.String(255), nullable=False) + created_from = db.Column(db.String(255), nullable=False) + created_by = db.Column(UUID, nullable=False) + created_api_request_id = db.Column(UUID, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, + server_default=db.text('CURRENT_TIMESTAMP(0)')) + + # start processing + processing_started_at = db.Column(db.DateTime, nullable=True) + + # parsing + file_id = db.Column(db.Text, nullable=True) + word_count = db.Column(db.Integer, nullable=True) + parsing_completed_at = db.Column(db.DateTime, nullable=True) + + # cleaning + cleaning_completed_at = db.Column(db.DateTime, nullable=True) + + # split + splitting_completed_at = db.Column(db.DateTime, nullable=True) + + # indexing + tokens = db.Column(db.Integer, nullable=True) + indexing_latency = db.Column(db.Float, nullable=True) + completed_at = db.Column(db.DateTime, nullable=True) + + # pause + is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text('false')) + paused_by = db.Column(UUID, nullable=True) + paused_at = db.Column(db.DateTime, nullable=True) + + # error + error = db.Column(db.Text, nullable=True) + stopped_at = db.Column(db.DateTime, nullable=True) + + # basic fields + indexing_status = db.Column(db.String( + 255), nullable=False, server_default=db.text("'waiting'::character varying")) + enabled = db.Column(db.Boolean, nullable=False, + server_default=db.text('true')) + disabled_at = db.Column(db.DateTime, nullable=True) + disabled_by = db.Column(UUID, nullable=True) + archived = db.Column(db.Boolean, nullable=False, + server_default=db.text('false')) + archived_reason = db.Column(db.String(255), nullable=True) + archived_by = db.Column(UUID, nullable=True) + archived_at = db.Column(db.DateTime, nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, + server_default=db.text('CURRENT_TIMESTAMP(0)')) + doc_type = db.Column(db.String(40), nullable=True) + doc_metadata = db.Column(db.JSON, nullable=True) + + DATA_SOURCES = ['upload_file'] + + @property + def display_status(self): + status = None + if self.indexing_status == 'waiting': + status = 'queuing' + elif self.indexing_status not in 
['completed', 'error', 'waiting'] and self.is_paused: + status = 'paused' + elif self.indexing_status in ['parsing', 'cleaning', 'splitting', 'indexing']: + status = 'indexing' + elif self.indexing_status == 'error': + status = 'error' + elif self.indexing_status == 'completed' and not self.archived and self.enabled: + status = 'available' + elif self.indexing_status == 'completed' and not self.archived and not self.enabled: + status = 'disabled' + elif self.indexing_status == 'completed' and self.archived: + status = 'archived' + return status + + @property + def data_source_info_dict(self): + if self.data_source_info: + try: + data_source_info_dict = json.loads(self.data_source_info) + except JSONDecodeError: + data_source_info_dict = {} + + return data_source_info_dict + return None + + @property + def data_source_detail_dict(self): + if self.data_source_info: + if self.data_source_type == 'upload_file': + data_source_info_dict = json.loads(self.data_source_info) + file_detail = db.session.query(UploadFile). \ + filter(UploadFile.id == data_source_info_dict['upload_file_id']). \ + one_or_none() + if file_detail: + return { + 'upload_file': { + 'id': file_detail.id, + 'name': file_detail.name, + 'size': file_detail.size, + 'extension': file_detail.extension, + 'mime_type': file_detail.mime_type, + 'created_by': file_detail.created_by, + 'created_at': file_detail.created_at.timestamp() + } + } + return {} + + @property + def average_segment_length(self): + if self.word_count and self.word_count != 0 and self.segment_count and self.segment_count != 0: + return self.word_count//self.segment_count + return 0 + + @property + def dataset_process_rule(self): + if self.dataset_process_rule_id: + return DatasetProcessRule.query.get(self.dataset_process_rule_id) + return None + + @property + def dataset(self): + return Dataset.query.get(self.dataset_id) + + @property + def segment_count(self): + return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count() + + @property + def hit_count(self): + return DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) \ + .filter(DocumentSegment.document_id == self.id).scalar() + + +class DocumentSegment(db.Model): + __tablename__ = 'document_segments' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='document_segment_pkey'), + db.Index('document_segment_dataset_id_idx', 'dataset_id'), + db.Index('document_segment_document_id_idx', 'document_id'), + db.Index('document_segment_tenant_dataset_idx', 'dataset_id', 'tenant_id'), + db.Index('document_segment_tenant_document_idx', 'document_id', 'tenant_id'), + db.Index('document_segment_dataset_node_idx', 'dataset_id', 'index_node_id'), + ) + + # initial fields + id = db.Column(UUID, nullable=False, + server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + dataset_id = db.Column(UUID, nullable=False) + document_id = db.Column(UUID, nullable=False) + position = db.Column(db.Integer, nullable=False) + content = db.Column(db.Text, nullable=False) + word_count = db.Column(db.Integer, nullable=False) + tokens = db.Column(db.Integer, nullable=False) + + # indexing fields + keywords = db.Column(db.JSON, nullable=True) + index_node_id = db.Column(db.String(255), nullable=True) + index_node_hash = db.Column(db.String(255), nullable=True) + + # basic fields + hit_count = db.Column(db.Integer, nullable=False, default=0) + enabled = db.Column(db.Boolean, nullable=False, + server_default=db.text('true')) + disabled_at = 
db.Column(db.DateTime, nullable=True) + disabled_by = db.Column(UUID, nullable=True) + status = db.Column(db.String(255), nullable=False, + server_default=db.text("'waiting'::character varying")) + created_by = db.Column(UUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, + server_default=db.text('CURRENT_TIMESTAMP(0)')) + indexing_at = db.Column(db.DateTime, nullable=True) + completed_at = db.Column(db.DateTime, nullable=True) + error = db.Column(db.Text, nullable=True) + stopped_at = db.Column(db.DateTime, nullable=True) + + @property + def dataset(self): + return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first() + + @property + def document(self): + return db.session.query(Document).filter(Document.id == self.document_id).first() + + @property + def embedding(self): + embedding = db.session.query(Embedding).filter(Embedding.hash == self.index_node_hash).first() \ + if self.index_node_hash else None + + if embedding: + return embedding.embedding + + return None + + @property + def previous_segment(self): + return db.session.query(DocumentSegment).filter( + DocumentSegment.document_id == self.document_id, + DocumentSegment.position == self.position - 1 + ).first() + + @property + def next_segment(self): + return db.session.query(DocumentSegment).filter( + DocumentSegment.document_id == self.document_id, + DocumentSegment.position == self.position + 1 + ).first() + + +class AppDatasetJoin(db.Model): + __tablename__ = 'app_dataset_joins' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='app_dataset_join_pkey'), + db.Index('app_dataset_join_app_dataset_idx', 'dataset_id', 'app_id'), + ) + + id = db.Column(UUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(UUID, nullable=False) + dataset_id = db.Column(UUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + + @property + def app(self): + return App.query.get(self.app_id) + + +class DatasetQuery(db.Model): + __tablename__ = 'dataset_queries' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='dataset_query_pkey'), + db.Index('dataset_query_dataset_id_idx', 'dataset_id'), + ) + + id = db.Column(UUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()')) + dataset_id = db.Column(UUID, nullable=False) + content = db.Column(db.Text, nullable=False) + source = db.Column(db.String(255), nullable=False) + source_app_id = db.Column(UUID, nullable=True) + created_by_role = db.Column(db.String, nullable=False) + created_by = db.Column(UUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + + +class DatasetKeywordTable(db.Model): + __tablename__ = 'dataset_keyword_tables' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='dataset_keyword_table_pkey'), + db.Index('dataset_keyword_table_dataset_id_idx', 'dataset_id'), + ) + + id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + dataset_id = db.Column(UUID, nullable=False, unique=True) + keyword_table = db.Column(db.Text, nullable=False) + + @property + def keyword_table_dict(self): + return json.loads(self.keyword_table) if self.keyword_table else None + + +class Embedding(db.Model): + __tablename__ = 'embeddings' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='embedding_pkey'), + db.UniqueConstraint('hash', name='embedding_hash_idx') + ) + + id = db.Column(UUID, primary_key=True, 
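+    # vectors are deduplicated by content hash (embedding_hash_idx above) and stored
+    # as pickled bytes in the LargeBinary column via set_embedding()/get_embedding()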
server_default=db.text('uuid_generate_v4()')) + hash = db.Column(db.String(64), nullable=False) + embedding = db.Column(db.LargeBinary, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + def set_embedding(self, embedding_data: list[float]): + self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) + + def get_embedding(self) -> list[float]: + return pickle.loads(self.embedding) diff --git a/api/models/model.py b/api/models/model.py new file mode 100644 index 0000000000..9b4e2a38bc --- /dev/null +++ b/api/models/model.py @@ -0,0 +1,622 @@ +import json + +from flask import current_app +from flask_login import UserMixin +from sqlalchemy.dialects.postgresql import UUID + +from libs.helper import generate_string +from extensions.ext_database import db +from .account import Account, Tenant + + +class DifySetup(db.Model): + __tablename__ = 'dify_setups' + __table_args__ = ( + db.PrimaryKeyConstraint('version', name='dify_setup_pkey'), + ) + + version = db.Column(db.String(255), nullable=False) + setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + +class App(db.Model): + __tablename__ = 'apps' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='app_pkey'), + db.Index('app_tenant_id_idx', 'tenant_id') + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + name = db.Column(db.String(255), nullable=False) + mode = db.Column(db.String(255), nullable=False) + icon = db.Column(db.String(255)) + icon_background = db.Column(db.String(255)) + app_model_config_id = db.Column(UUID, nullable=True) + status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + enable_site = db.Column(db.Boolean, nullable=False) + enable_api = db.Column(db.Boolean, nullable=False) + api_rpm = db.Column(db.Integer, nullable=False) + api_rph = db.Column(db.Integer, nullable=False) + is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + is_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + @property + def site(self): + site = db.session.query(Site).filter(Site.app_id == self.id).first() + return site + + @property + def app_model_config(self): + app_model_config = db.session.query(AppModelConfig).filter( + AppModelConfig.id == self.app_model_config_id).first() + return app_model_config + + @property + def api_base_url(self): + return current_app.config['API_URL'] + '/v1' + + @property + def tenant(self): + tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + return tenant + + +class AppModelConfig(db.Model): + __tablename__ = 'app_model_configs' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='app_model_config_pkey'), + db.Index('app_app_id_idx', 'app_id') + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(UUID, nullable=False) + provider = db.Column(db.String(255), nullable=False) + model_id = db.Column(db.String(255), nullable=False) + configs = db.Column(db.JSON, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, 
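+    # the Text columns below hold JSON-serialized feature configs; read them through
+    # the *_dict / *_list properties defined further down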
nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + opening_statement = db.Column(db.Text) + suggested_questions = db.Column(db.Text) + suggested_questions_after_answer = db.Column(db.Text) + more_like_this = db.Column(db.Text) + model = db.Column(db.Text) + user_input_form = db.Column(db.Text) + pre_prompt = db.Column(db.Text) + agent_mode = db.Column(db.Text) + + @property + def app(self): + app = db.session.query(App).filter(App.id == self.app_id).first() + return app + + @property + def model_dict(self) -> dict: + return json.loads(self.model) if self.model else None + + @property + def suggested_questions_list(self) -> list: + return json.loads(self.suggested_questions) if self.suggested_questions else [] + + @property + def suggested_questions_after_answer_dict(self) -> dict: + return json.loads(self.suggested_questions_after_answer) if self.suggested_questions_after_answer \ + else {"enabled": False} + + @property + def more_like_this_dict(self) -> dict: + return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} + + @property + def user_input_form_list(self) -> dict: + return json.loads(self.user_input_form) if self.user_input_form else [] + + @property + def agent_mode_dict(self) -> dict: + return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "tools": []} + + +class RecommendedApp(db.Model): + __tablename__ = 'recommended_apps' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='recommended_app_pkey'), + db.Index('recommended_app_app_id_idx', 'app_id'), + db.Index('recommended_app_is_listed_idx', 'is_listed') + ) + + id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(UUID, nullable=False) + description = db.Column(db.JSON, nullable=False) + copyright = db.Column(db.String(255), nullable=False) + privacy_policy = db.Column(db.String(255), nullable=False) + category = db.Column(db.String(255), nullable=False) + position = db.Column(db.Integer, nullable=False, default=0) + is_listed = db.Column(db.Boolean, nullable=False, default=True) + install_count = db.Column(db.Integer, nullable=False, default=0) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + @property + def app(self): + app = db.session.query(App).filter(App.id == self.app_id).first() + return app + + # def set_description(self, lang, desc): + # if self.description is None: + # self.description = {} + # self.description[lang] = desc + + def get_description(self, lang): + if self.description and lang in self.description: + return self.description[lang] + else: + return self.description.get('en') + + +class InstalledApp(db.Model): + __tablename__ = 'installed_apps' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='installed_app_pkey'), + db.Index('installed_app_tenant_id_idx', 'tenant_id'), + db.Index('installed_app_app_id_idx', 'app_id'), + db.UniqueConstraint('tenant_id', 'app_id', name='unique_tenant_app') + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + app_id = db.Column(UUID, nullable=False) + app_owner_tenant_id = db.Column(UUID, nullable=False) + position = db.Column(db.Integer, nullable=False, default=0) + is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + last_used_at = db.Column(db.DateTime, nullable=True) + created_at 
= db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + @property + def app(self): + app = db.session.query(App).filter(App.id == self.app_id).first() + return app + + @property + def tenant(self): + tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + return tenant + + +class Conversation(db.Model): + __tablename__ = 'conversations' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='conversation_pkey'), + db.Index('conversation_app_from_user_idx', 'app_id', 'from_source', 'from_end_user_id') + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(UUID, nullable=False) + app_model_config_id = db.Column(UUID, nullable=False) + model_provider = db.Column(db.String(255), nullable=False) + override_model_configs = db.Column(db.Text) + model_id = db.Column(db.String(255), nullable=False) + mode = db.Column(db.String(255), nullable=False) + name = db.Column(db.String(255), nullable=False) + summary = db.Column(db.Text) + inputs = db.Column(db.JSON) + introduction = db.Column(db.Text) + system_instruction = db.Column(db.Text) + system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + status = db.Column(db.String(255), nullable=False) + from_source = db.Column(db.String(255), nullable=False) + from_end_user_id = db.Column(UUID) + from_account_id = db.Column(UUID) + read_at = db.Column(db.DateTime) + read_account_id = db.Column(UUID) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all") + message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all") + + @property + def model_config(self): + model_config = {} + if self.override_model_configs: + override_model_configs = json.loads(self.override_model_configs) + + if 'model' in override_model_configs: + model_config['model'] = override_model_configs['model'] + model_config['pre_prompt'] = override_model_configs['pre_prompt'] + model_config['agent_mode'] = override_model_configs['agent_mode'] + model_config['opening_statement'] = override_model_configs['opening_statement'] + model_config['suggested_questions'] = override_model_configs['suggested_questions'] + model_config['suggested_questions_after_answer'] = override_model_configs[ + 'suggested_questions_after_answer'] \ + if 'suggested_questions_after_answer' in override_model_configs else {"enabled": False} + model_config['more_like_this'] = override_model_configs['more_like_this'] \ + if 'more_like_this' in override_model_configs else {"enabled": False} + model_config['user_input_form'] = override_model_configs['user_input_form'] + else: + model_config['configs'] = override_model_configs + else: + app_model_config = db.session.query(AppModelConfig).filter( + AppModelConfig.id == self.app_model_config_id).first() + + model_config['configs'] = app_model_config.configs + model_config['model'] = app_model_config.model_dict + model_config['pre_prompt'] = app_model_config.pre_prompt + model_config['agent_mode'] = app_model_config.agent_mode_dict + model_config['opening_statement'] = app_model_config.opening_statement + model_config['suggested_questions'] = app_model_config.suggested_questions_list + model_config['suggested_questions_after_answer'] 
= app_model_config.suggested_questions_after_answer_dict + model_config['more_like_this'] = app_model_config.more_like_this_dict + model_config['user_input_form'] = app_model_config.user_input_form_list + + model_config['model_id'] = self.model_id + model_config['provider'] = self.model_provider + + return model_config + + @property + def summary_or_query(self): + if self.summary: + return self.summary + else: + first_message = self.first_message + if first_message: + return first_message.query + else: + return '' + + @property + def annotated(self): + return db.session.query(MessageAnnotation).filter(MessageAnnotation.conversation_id == self.id).count() > 0 + + @property + def annotation(self): + return db.session.query(MessageAnnotation).filter(MessageAnnotation.conversation_id == self.id).first() + + @property + def message_count(self): + return db.session.query(Message).filter(Message.conversation_id == self.id).count() + + @property + def user_feedback_stats(self): + like = db.session.query(MessageFeedback) \ + .filter(MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == 'user', + MessageFeedback.rating == 'like').count() + + dislike = db.session.query(MessageFeedback) \ + .filter(MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == 'user', + MessageFeedback.rating == 'dislike').count() + + return {'like': like, 'dislike': dislike} + + @property + def admin_feedback_stats(self): + like = db.session.query(MessageFeedback) \ + .filter(MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == 'admin', + MessageFeedback.rating == 'like').count() + + dislike = db.session.query(MessageFeedback) \ + .filter(MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == 'admin', + MessageFeedback.rating == 'dislike').count() + + return {'like': like, 'dislike': dislike} + + @property + def first_message(self): + return db.session.query(Message).filter(Message.conversation_id == self.id).first() + + @property + def app(self): + return db.session.query(App).filter(App.id == self.app_id).first() + + +class Message(db.Model): + __tablename__ = 'messages' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='message_pkey'), + db.Index('message_app_id_idx', 'app_id', 'created_at'), + db.Index('message_conversation_id_idx', 'conversation_id'), + db.Index('message_end_user_idx', 'app_id', 'from_source', 'from_end_user_id'), + db.Index('message_account_idx', 'app_id', 'from_source', 'from_account_id'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(UUID, nullable=False) + model_provider = db.Column(db.String(255), nullable=False) + model_id = db.Column(db.String(255), nullable=False) + override_model_configs = db.Column(db.Text) + conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False) + inputs = db.Column(db.JSON) + query = db.Column(db.Text, nullable=False) + message = db.Column(db.JSON, nullable=False) + message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) + answer = db.Column(db.Text, nullable=False) + answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0')) + answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) + provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text('0')) + total_price = db.Column(db.Numeric(10, 7)) + currency = db.Column(db.String(255), 
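+    # from_source below records whether the message came from a web-app end user or a
+    # console account; the matching from_end_user_id / from_account_id is set accordingly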
nullable=False) + from_source = db.Column(db.String(255), nullable=False) + from_end_user_id = db.Column(UUID) + from_account_id = db.Column(UUID) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + + @property + def user_feedback(self): + feedback = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id, + MessageFeedback.from_source == 'user').first() + return feedback + + @property + def admin_feedback(self): + feedback = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id, + MessageFeedback.from_source == 'admin').first() + return feedback + + @property + def feedbacks(self): + feedbacks = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id).all() + return feedbacks + + @property + def annotation(self): + annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == self.id).first() + return annotation + + @property + def app_model_config(self): + conversation = db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first() + if conversation: + return db.session.query(AppModelConfig).filter( + AppModelConfig.id == conversation.app_model_config_id).first() + + return None + + +class MessageFeedback(db.Model): + __tablename__ = 'message_feedbacks' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='message_feedback_pkey'), + db.Index('message_feedback_app_idx', 'app_id'), + db.Index('message_feedback_message_idx', 'message_id', 'from_source'), + db.Index('message_feedback_conversation_idx', 'conversation_id', 'from_source', 'rating') + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(UUID, nullable=False) + conversation_id = db.Column(UUID, nullable=False) + message_id = db.Column(UUID, nullable=False) + rating = db.Column(db.String(255), nullable=False) + content = db.Column(db.Text) + from_source = db.Column(db.String(255), nullable=False) + from_end_user_id = db.Column(UUID) + from_account_id = db.Column(UUID) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + @property + def from_account(self): + account = db.session.query(Account).filter(Account.id == self.from_account_id).first() + return account + + +class MessageAnnotation(db.Model): + __tablename__ = 'message_annotations' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='message_annotation_pkey'), + db.Index('message_annotation_app_idx', 'app_id'), + db.Index('message_annotation_conversation_idx', 'conversation_id'), + db.Index('message_annotation_message_idx', 'message_id') + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(UUID, nullable=False) + conversation_id = db.Column(UUID, db.ForeignKey('conversations.id'), nullable=False) + message_id = db.Column(UUID, nullable=False) + content = db.Column(db.Text, nullable=False) + account_id = db.Column(UUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + @property + def account(self): + 
account = db.session.query(Account).filter(Account.id == self.account_id).first() + return account + + +class OperationLog(db.Model): + __tablename__ = 'operation_logs' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='operation_log_pkey'), + db.Index('operation_log_account_action_idx', 'tenant_id', 'account_id', 'action') + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + account_id = db.Column(UUID, nullable=False) + action = db.Column(db.String(255), nullable=False) + content = db.Column(db.JSON) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + created_ip = db.Column(db.String(255), nullable=False) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + +class EndUser(UserMixin, db.Model): + __tablename__ = 'end_users' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='end_user_pkey'), + db.Index('end_user_session_id_idx', 'session_id', 'type'), + db.Index('end_user_tenant_session_id_idx', 'tenant_id', 'session_id', 'type'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + app_id = db.Column(UUID, nullable=True) + type = db.Column(db.String(255), nullable=False) + external_user_id = db.Column(db.String(255), nullable=True) + name = db.Column(db.String(255)) + is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text('true')) + session_id = db.Column(db.String(255), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + +class Site(db.Model): + __tablename__ = 'sites' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='site_pkey'), + db.Index('site_app_id_idx', 'app_id'), + db.Index('site_code_idx', 'code', 'status') + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(UUID, nullable=False) + title = db.Column(db.String(255), nullable=False) + icon = db.Column(db.String(255)) + icon_background = db.Column(db.String(255)) + description = db.Column(db.String(255)) + default_language = db.Column(db.String(255), nullable=False) + copyright = db.Column(db.String(255)) + privacy_policy = db.Column(db.String(255)) + customize_domain = db.Column(db.String(255)) + customize_token_strategy = db.Column(db.String(255), nullable=False) + prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + code = db.Column(db.String(255)) + + @staticmethod + def generate_code(n): + while True: + result = generate_string(n) + while db.session.query(Site).filter(Site.code == result).count() > 0: + result = generate_string(n) + + return result + + @property + def app_base_url(self): + return current_app.config['APP_URL'] + + +class ApiToken(db.Model): + __tablename__ = 'api_tokens' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='api_token_pkey'), + db.Index('api_token_app_id_type_idx', 'app_id', 'type'), + db.Index('api_token_token_idx', 'token', 'type') + ) + + id = 
db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(UUID, nullable=True) + dataset_id = db.Column(UUID, nullable=True) + type = db.Column(db.String(16), nullable=False) + token = db.Column(db.String(255), nullable=False) + last_used_at = db.Column(db.DateTime, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + @staticmethod + def generate_api_key(prefix, n): + while True: + result = prefix + generate_string(n) + while db.session.query(ApiToken).filter(ApiToken.token == result).count() > 0: + result = prefix + generate_string(n) + + return result + + +class UploadFile(db.Model): + __tablename__ = 'upload_files' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='upload_file_pkey'), + db.Index('upload_file_tenant_idx', 'tenant_id') + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + storage_type = db.Column(db.String(255), nullable=False) + key = db.Column(db.String(255), nullable=False) + name = db.Column(db.String(255), nullable=False) + size = db.Column(db.Integer, nullable=False) + extension = db.Column(db.String(255), nullable=False) + mime_type = db.Column(db.String(255), nullable=True) + created_by = db.Column(UUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + used = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + used_by = db.Column(UUID, nullable=True) + used_at = db.Column(db.DateTime, nullable=True) + hash = db.Column(db.String(255), nullable=True) + + +class ApiRequest(db.Model): + __tablename__ = 'api_requests' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='api_request_pkey'), + db.Index('api_request_token_idx', 'tenant_id', 'api_token_id') + ) + + id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + api_token_id = db.Column(UUID, nullable=False) + path = db.Column(db.String(255), nullable=False) + request = db.Column(db.Text, nullable=True) + response = db.Column(db.Text, nullable=True) + ip = db.Column(db.String(255), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + +class MessageChain(db.Model): + __tablename__ = 'message_chains' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='message_chain_pkey'), + db.Index('message_chain_message_id_idx', 'message_id') + ) + + id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + message_id = db.Column(UUID, nullable=False) + type = db.Column(db.String(255), nullable=False) + input = db.Column(db.Text, nullable=True) + output = db.Column(db.Text, nullable=True) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) + + +class MessageAgentThought(db.Model): + __tablename__ = 'message_agent_thoughts' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='message_agent_thought_pkey'), + db.Index('message_agent_thought_message_id_idx', 'message_id'), + db.Index('message_agent_thought_message_chain_id_idx', 'message_chain_id'), + ) + + id = db.Column(UUID, nullable=False, server_default=db.text('uuid_generate_v4()')) + message_id = db.Column(UUID, nullable=False) + message_chain_id = db.Column(UUID, nullable=False) + position = db.Column(db.Integer, nullable=False) + thought = db.Column(db.Text, nullable=True) + tool = 
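A hedged sketch of minting an app-scoped key with `ApiToken.generate_api_key`; the `'app-'` prefix, the 24-character length, and the `'app'` type value are illustrative assumptions, not taken from this diff:

```python
# Illustrative only: generate a unique token and persist it for an app.
api_token = ApiToken(
    app_id=app.id,                                # assumed App instance
    type='app',                                   # assumed type value
    token=ApiToken.generate_api_key('app-', 24),  # assumed prefix/length
)
db.session.add(api_token)
db.session.commit()
```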
db.Column(db.Text, nullable=True)
+    tool_input = db.Column(db.Text, nullable=True)
+    observation = db.Column(db.Text, nullable=True)
+    # plugin_id = db.Column(UUID, nullable=True)  ## for future design
+    tool_process_data = db.Column(db.Text, nullable=True)
+    message = db.Column(db.Text, nullable=True)
+    message_token = db.Column(db.Integer, nullable=True)
+    message_unit_price = db.Column(db.Numeric, nullable=True)
+    answer = db.Column(db.Text, nullable=True)
+    answer_token = db.Column(db.Integer, nullable=True)
+    answer_unit_price = db.Column(db.Numeric, nullable=True)
+    tokens = db.Column(db.Integer, nullable=True)
+    total_price = db.Column(db.Numeric, nullable=True)
+    currency = db.Column(db.String, nullable=True)
+    latency = db.Column(db.Float, nullable=True)
+    created_by_role = db.Column(db.String, nullable=False)
+    created_by = db.Column(UUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
diff --git a/api/models/provider.py b/api/models/provider.py
new file mode 100644
index 0000000000..e4ecfa1241
--- /dev/null
+++ b/api/models/provider.py
@@ -0,0 +1,77 @@
+from enum import Enum
+
+from sqlalchemy.dialects.postgresql import UUID
+
+from extensions.ext_database import db
+
+
+class ProviderType(Enum):
+    CUSTOM = 'custom'
+    SYSTEM = 'system'
+
+
+class ProviderName(Enum):
+    OPENAI = 'openai'
+    AZURE_OPENAI = 'azure_openai'
+    ANTHROPIC = 'anthropic'
+    COHERE = 'cohere'
+    HUGGINGFACEHUB = 'huggingfacehub'
+
+    @staticmethod
+    def value_of(value):
+        for member in ProviderName:
+            if member.value == value:
+                return member
+        raise ValueError(f"No matching enum found for value '{value}'")
+
+
+class ProviderQuotaType(Enum):
+    MONTHLY = 'monthly'
+    TRIAL = 'trial'
+
+
+class Provider(db.Model):
+    """
+    Provider model representing the API providers and their configurations.
+    """
+    __tablename__ = 'providers'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='provider_pkey'),
+        db.Index('provider_tenant_id_provider_idx', 'tenant_id', 'provider_name'),
+        db.UniqueConstraint('tenant_id', 'provider_name', 'provider_type', 'quota_type', name='unique_provider_name_type_quota')
+    )
+
+    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(UUID, nullable=False)
+    provider_name = db.Column(db.String(40), nullable=False)
+    provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying"))
+    encrypted_config = db.Column(db.Text, nullable=True)
+    is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
+    last_used = db.Column(db.DateTime, nullable=True)
+
+    quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying"))
+    quota_limit = db.Column(db.Integer, nullable=True)
+    quota_used = db.Column(db.Integer, default=0)
+
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+
+    def __repr__(self):
+        return f"<Provider(id={self.id}, provider_name='{self.provider_name}', provider_type='{self.provider_type}')>"
+
+    @property
+    def token_is_set(self):
+        """
+        Returns True if the encrypted_config is not None, indicating that the token is set.
+        """
+        return self.encrypted_config is not None
+
+    @property
+    def is_enabled(self):
+        """
+        Returns True if the provider is enabled.
+ """ + if self.provider_type == ProviderType.SYSTEM.value: + return self.is_valid + else: + return self.is_valid and self.token_is_set diff --git a/api/models/task.py b/api/models/task.py new file mode 100644 index 0000000000..d85cf16d7c --- /dev/null +++ b/api/models/task.py @@ -0,0 +1,37 @@ +from extensions.ext_database import db +from celery import states +from datetime import datetime + + +class CeleryTask(db.Model): + """Task result/status.""" + + __tablename__ = 'celery_taskmeta' + + id = db.Column(db.Integer, db.Sequence('task_id_sequence'), + primary_key=True, autoincrement=True) + task_id = db.Column(db.String(155), unique=True) + status = db.Column(db.String(50), default=states.PENDING) + result = db.Column(db.PickleType, nullable=True) + date_done = db.Column(db.DateTime, default=datetime.utcnow, + onupdate=datetime.utcnow, nullable=True) + traceback = db.Column(db.Text, nullable=True) + name = db.Column(db.String(155), nullable=True) + args = db.Column(db.LargeBinary, nullable=True) + kwargs = db.Column(db.LargeBinary, nullable=True) + worker = db.Column(db.String(155), nullable=True) + retries = db.Column(db.Integer, nullable=True) + queue = db.Column(db.String(155), nullable=True) + + +class CeleryTaskSet(db.Model): + """TaskSet result.""" + + __tablename__ = 'celery_tasksetmeta' + + id = db.Column(db.Integer, db.Sequence('taskset_id_sequence'), + autoincrement=True, primary_key=True) + taskset_id = db.Column(db.String(155), unique=True) + result = db.Column(db.PickleType, nullable=True) + date_done = db.Column(db.DateTime, default=datetime.utcnow, + nullable=True) diff --git a/api/models/web.py b/api/models/web.py new file mode 100644 index 0000000000..1580ce74c9 --- /dev/null +++ b/api/models/web.py @@ -0,0 +1,36 @@ +from sqlalchemy.dialects.postgresql import UUID + +from extensions.ext_database import db +from models.model import Message + + +class SavedMessage(db.Model): + __tablename__ = 'saved_messages' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='saved_message_pkey'), + db.Index('saved_message_message_idx', 'app_id', 'message_id', 'created_by'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(UUID, nullable=False) + message_id = db.Column(UUID, nullable=False) + created_by = db.Column(UUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + @property + def message(self): + return db.session.query(Message).filter(Message.id == self.message_id).first() + + +class PinnedConversation(db.Model): + __tablename__ = 'pinned_conversations' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='pinned_conversation_pkey'), + db.Index('pinned_conversation_conversation_idx', 'app_id', 'conversation_id', 'created_by'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + app_id = db.Column(UUID, nullable=False) + conversation_id = db.Column(UUID, nullable=False) + created_by = db.Column(UUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) diff --git a/api/requirements.txt b/api/requirements.txt new file mode 100644 index 0000000000..511625ccfd --- /dev/null +++ b/api/requirements.txt @@ -0,0 +1,32 @@ +coverage~=7.2.4 +beautifulsoup4==4.12.2 +flask~=2.3.2 +Flask-SQLAlchemy~=3.0.3 +flask-login==0.6.2 +flask-migrate~=4.0.4 +flask-restful==0.3.9 +flask-session2==1.3.1 +flask-cors==3.0.10 +gunicorn~=20.1.0 +gevent~=22.10.2 +langchain==0.0.142 
+llama-index==0.5.27 +openai~=0.27.5 +psycopg2-binary~=2.9.6 +pycryptodome==3.17 +python-dotenv==1.0.0 +pytest~=7.3.1 +tiktoken==0.3.3 +Authlib==1.2.0 +boto3~=1.26.123 +tenacity==8.2.2 +cachetools~=5.3.0 +weaviate-client~=3.16.2 +qdrant_client~=1.1.6 +mailchimp-transactional~=1.0.50 +scikit-learn==1.2.2 +sentry-sdk[flask]~=1.21.1 +jieba==0.42.1 +celery==5.2.7 +redis~=4.5.4 +pypdf==3.8.1 \ No newline at end of file diff --git a/api/services/__init__.py b/api/services/__init__.py new file mode 100644 index 0000000000..36a7704385 --- /dev/null +++ b/api/services/__init__.py @@ -0,0 +1,2 @@ +# -*- coding:utf-8 -*- +import services.errors diff --git a/api/services/account_service.py b/api/services/account_service.py new file mode 100644 index 0000000000..8442e0eab8 --- /dev/null +++ b/api/services/account_service.py @@ -0,0 +1,382 @@ +# -*- coding:utf-8 -*- +import base64 +import logging +import secrets +from datetime import datetime +from typing import Optional + +from flask import session +from sqlalchemy import func + +from events.tenant_event import tenant_was_created +from services.errors.account import AccountLoginError, CurrentPasswordIncorrectError, LinkAccountIntegrateError, \ + TenantNotFound, AccountNotLinkTenantError, InvalidActionError, CannotOperateSelfError, MemberNotInTenantError, \ + RoleAlreadyAssignedError, NoPermissionError, AccountRegisterError, AccountAlreadyInTenantError +from libs.helper import get_remote_ip +from libs.password import compare_password, hash_password +from libs.rsa import generate_key_pair +from models.account import * + + +class AccountService: + + @staticmethod + def load_user(account_id: int) -> Account: + # todo: used by flask_login + pass + + @staticmethod + def authenticate(email: str, password: str) -> Account: + """authenticate account with email and password""" + + account = Account.query.filter_by(email=email).first() + if not account: + raise AccountLoginError('Invalid email or password.') + + if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: + raise AccountLoginError('Account is banned or closed.') + + if account.status == AccountStatus.PENDING.value: + account.status = AccountStatus.ACTIVE.value + account.initialized_at = datetime.utcnow() + db.session.commit() + + if account.password is None or not compare_password(password, account.password, account.password_salt): + raise AccountLoginError('Invalid email or password.') + return account + + @staticmethod + def update_account_password(account, password, new_password): + """update account password""" + # todo: split validation and update + if account.password and not compare_password(password, account.password, account.password_salt): + raise CurrentPasswordIncorrectError("Current password is incorrect.") + password_hashed = hash_password(new_password, account.password_salt) + base64_password_hashed = base64.b64encode(password_hashed).decode() + account.password = base64_password_hashed + db.session.commit() + return account + + @staticmethod + def create_account(email: str, name: str, password: str = None, + interface_language: str = 'en-US', interface_theme: str = 'light', + timezone: str = 'America/New_York', ) -> Account: + """create account""" + account = Account() + account.email = email + account.name = name + + if password: + # generate password salt + salt = secrets.token_bytes(16) + base64_salt = base64.b64encode(salt).decode() + + # encrypt password with salt + password_hashed = hash_password(password, salt) + base64_password_hashed = 
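`create_account` below stores both the salt and the password hash base64-encoded, and `authenticate` above verifies a login by passing those stored strings straight to `compare_password`. A hedged round-trip sketch, assuming the `libs.password` helpers accept the base64-encoded stored values as `authenticate` implies:

```python
import base64
import secrets

# Sketch of the storage scheme used by create_account.
salt = secrets.token_bytes(16)
stored_salt = base64.b64encode(salt).decode()
stored_hash = base64.b64encode(hash_password('s3cret!', salt)).decode()

# authenticate() later checks the plaintext against the stored values:
assert compare_password('s3cret!', stored_hash, stored_salt)
```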
base64.b64encode(password_hashed).decode() + + account.password = base64_password_hashed + account.password_salt = base64_salt + + account.interface_language = interface_language + account.interface_theme = interface_theme + + if interface_language == 'zh-Hans': + account.timezone = 'Asia/Shanghai' + else: + account.timezone = timezone + + db.session.add(account) + db.session.commit() + return account + + @staticmethod + def link_account_integrate(provider: str, open_id: str, account: Account) -> None: + """Link account integrate""" + try: + # Query whether there is an existing binding record for the same provider + account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(account_id=account.id, + provider=provider).first() + + if account_integrate: + # If it exists, update the record + account_integrate.open_id = open_id + account_integrate.encrypted_token = "" # todo + account_integrate.updated_at = datetime.utcnow() + else: + # If it does not exist, create a new record + account_integrate = AccountIntegrate(account_id=account.id, provider=provider, open_id=open_id, + encrypted_token="") + db.session.add(account_integrate) + + db.session.commit() + logging.info(f'Account {account.id} linked {provider} account {open_id}.') + except Exception as e: + logging.exception(f'Failed to link {provider} account {open_id} to Account {account.id}') + raise LinkAccountIntegrateError('Failed to link account.') from e + + @staticmethod + def close_account(account: Account) -> None: + """todo: Close account""" + account.status = AccountStatus.CLOSED.value + db.session.commit() + + @staticmethod + def update_account(account, **kwargs): + """Update account fields""" + for field, value in kwargs.items(): + if hasattr(account, field): + setattr(account, field, value) + else: + raise AttributeError(f"Invalid field: {field}") + + db.session.commit() + return account + + @staticmethod + def update_last_login(account: Account, request) -> None: + """Update last login time and ip""" + account.last_login_at = datetime.utcnow() + account.last_login_ip = get_remote_ip(request) + db.session.add(account) + db.session.commit() + logging.info(f'Account {account.id} logged in successfully.') + + +class TenantService: + + @staticmethod + def create_tenant(name: str) -> Tenant: + """Create tenant""" + tenant = Tenant(name=name) + + db.session.add(tenant) + db.session.commit() + + tenant.encrypt_public_key = generate_key_pair(tenant.id) + db.session.commit() + return tenant + + @staticmethod + def create_tenant_member(tenant: Tenant, account: Account, role: str = 'normal') -> TenantAccountJoin: + """Create tenant member""" + if role == TenantAccountJoinRole.OWNER.value: + if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]): + logging.error(f'Tenant {tenant.id} has already an owner.') + raise Exception('Tenant already has an owner.') + + ta = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role + ) + db.session.add(ta) + db.session.commit() + return ta + + @staticmethod + def get_join_tenants(account: Account) -> List[Tenant]: + """Get account join tenants""" + return db.session.query(Tenant).join( + TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id + ).filter(TenantAccountJoin.account_id == account.id).all() + + @staticmethod + def get_current_tenant_by_account(account: Account): + """Get tenant by account and add the role""" + tenant = account.current_tenant + if not tenant: + raise TenantNotFound("Tenant not found.") + + ta = 
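`RegisterService.register` further down stitches these tenant helpers together; the essential bootstrap sequence is:

```python
# Sketch of the tenant bootstrap performed at registration time.
tenant = TenantService.create_tenant("Ada's Workspace")  # also generates the tenant's RSA key pair
TenantService.create_tenant_member(tenant, account, role='owner')
account.current_tenant = tenant
```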
TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() + if ta: + tenant.role = ta.role + else: + raise TenantNotFound("Tenant not found for the account.") + return tenant + + @staticmethod + def switch_tenant(account: Account, tenant_id: int = None) -> None: + """Switch the current workspace for the account""" + if not tenant_id: + tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id).first() + else: + tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first() + + # Check if the tenant exists and the account is a member of the tenant + if not tenant_account_join: + raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") + + # Set the current tenant for the account + account.current_tenant_id = tenant_account_join.tenant_id + session['workspace_id'] = account.current_tenant.id + + @staticmethod + def get_tenant_members(tenant: Tenant) -> List[Account]: + """Get tenant members""" + query = ( + db.session.query(Account, TenantAccountJoin.role) + .select_from(Account) + .join( + TenantAccountJoin, Account.id == TenantAccountJoin.account_id + ) + .filter(TenantAccountJoin.tenant_id == tenant.id) + ) + + # Initialize an empty list to store the updated accounts + updated_accounts = [] + + for account, role in query: + account.role = role + updated_accounts.append(account) + + return updated_accounts + + @staticmethod + def has_roles(tenant: Tenant, roles: List[TenantAccountJoinRole]) -> bool: + """Check if user has any of the given roles for a tenant""" + if not all(isinstance(role, TenantAccountJoinRole) for role in roles): + raise ValueError('all roles must be TenantAccountJoinRole') + + return db.session.query(TenantAccountJoin).filter( + TenantAccountJoin.tenant_id == tenant.id, + TenantAccountJoin.role.in_([role.value for role in roles]) + ).first() is not None + + @staticmethod + def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]: + """Get the role of the current account for a given tenant""" + join = db.session.query(TenantAccountJoin).filter( + TenantAccountJoin.tenant_id == tenant.id, + TenantAccountJoin.account_id == account.id + ).first() + return join.role if join else None + + @staticmethod + def get_tenant_count() -> int: + """Get tenant count""" + return db.session.query(func.count(Tenant.id)).scalar() + + @staticmethod + def check_member_permission(tenant: Tenant, operator: Account, member: Account, action: str) -> None: + """Check member permission""" + perms = { + 'add': ['owner', 'admin'], + 'remove': ['owner'], + 'update': ['owner'] + } + if action not in ['add', 'remove', 'update']: + raise InvalidActionError("Invalid action.") + + if operator.id == member.id: + raise CannotOperateSelfError("Cannot operate self.") + + ta_operator = TenantAccountJoin.query.filter_by( + tenant_id=tenant.id, + account_id=operator.id + ).first() + + if not ta_operator or ta_operator.role not in perms[action]: + raise NoPermissionError(f'No permission to {action} member.') + + @staticmethod + def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None: + """Remove member from tenant""" + # todo: check permission + + if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, 'remove'): + raise CannotOperateSelfError("Cannot operate self.") + + ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() + if not ta: + raise 
MemberNotInTenantError("Member not in tenant.") + + db.session.delete(ta) + db.session.commit() + + @staticmethod + def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None: + """Update member role""" + TenantService.check_member_permission(tenant, operator, member, 'update') + + target_member_join = TenantAccountJoin.query.filter_by( + tenant_id=tenant.id, + account_id=member.id + ).first() + + if target_member_join.role == new_role: + raise RoleAlreadyAssignedError("The provided role is already assigned to the member.") + + if new_role == 'owner': + # Find the current owner and change their role to 'admin' + current_owner_join = TenantAccountJoin.query.filter_by( + tenant_id=tenant.id, + role='owner' + ).first() + current_owner_join.role = 'admin' + + # Update the role of the target member + target_member_join.role = new_role + db.session.commit() + + @staticmethod + def dissolve_tenant(tenant: Tenant, operator: Account) -> None: + """Dissolve tenant""" + if not TenantService.check_member_permission(tenant, operator, operator, 'remove'): + raise NoPermissionError('No permission to dissolve tenant.') + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() + db.session.delete(tenant) + db.session.commit() + + +class RegisterService: + + @staticmethod + def register(email, name, password: str = None, open_id: str = None, provider: str = None) -> Account: + db.session.begin_nested() + """Register account""" + try: + account = AccountService.create_account(email, name, password) + account.status = AccountStatus.ACTIVE.value + account.initialized_at = datetime.utcnow() + + if open_id is not None or provider is not None: + AccountService.link_account_integrate(provider, open_id, account) + + tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + + TenantService.create_tenant_member(tenant, account, role='owner') + account.current_tenant = tenant + + db.session.commit() + except Exception as e: + db.session.rollback() # todo: do not work + logging.error(f'Register failed: {e}') + raise AccountRegisterError(f'Registration failed: {e}') from e + + tenant_was_created.send(tenant) + + return account + + @staticmethod + def invite_new_member(tenant: Tenant, email: str, role: str = 'normal', + inviter: Account = None) -> TenantAccountJoin: + """Invite new member""" + account = Account.query.filter_by(email=email).first() + + if not account: + name = email.split('@')[0] + account = AccountService.create_account(email, name) + account.status = AccountStatus.PENDING.value + db.session.commit() + else: + TenantService.check_member_permission(tenant, inviter, account, 'add') + ta = TenantAccountJoin.query.filter_by( + tenant_id=tenant.id, + account_id=account.id + ).first() + if ta: + raise AccountAlreadyInTenantError("Account already in tenant.") + + ta = TenantService.create_tenant_member(tenant, account, role) + return ta diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py new file mode 100644 index 0000000000..70b63fbe19 --- /dev/null +++ b/api/services/app_model_config_service.py @@ -0,0 +1,292 @@ +import re +import uuid + +from core.constant import llm_constant +from models.account import Account +from services.dataset_service import DatasetService +from services.errors.account import NoPermissionError + + +class AppModelConfigService: + @staticmethod + def is_dataset_exists(account: Account, dataset_id: str) -> bool: + # verify if the dataset ID exists + dataset = 
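Two guards above cannot behave as written, because `check_member_permission` returns `None` and raises on failure: the `operator.id == account.id and TenantService.check_member_permission(...)` condition in `remove_member_from_tenant` never fires, and `if not TenantService.check_member_permission(...)` in `dissolve_tenant` is always true; worse, passing the operator as both parties there trips the self-operation guard. A sketch of the likely intent (note that `get_user_role` returns the raw role string despite its `TenantAccountJoinRole` annotation):

```python
# remove_member_from_tenant: reject self-removal first, then let the
# permission check raise NoPermissionError on its own.
if operator.id == account.id:
    raise CannotOperateSelfError("Cannot operate self.")
TenantService.check_member_permission(tenant, operator, account, 'remove')

# dissolve_tenant: an owner-only check that does not compare the operator
# against itself.
if TenantService.get_user_role(operator, tenant) != 'owner':
    raise NoPermissionError('Only the workspace owner can dissolve it.')
```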
DatasetService.get_dataset(dataset_id) + + if not dataset: + return False + + if dataset.tenant_id != account.current_tenant_id: + return False + + return True + + @staticmethod + def validate_model_completion_params(cp: dict, model_name: str) -> dict: + # 6. model.completion_params + if not isinstance(cp, dict): + raise ValueError("model.completion_params must be of object type") + + # max_tokens + if 'max_tokens' not in cp: + cp["max_tokens"] = 512 + + if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \ + llm_constant.max_context_token_length[model_name]: + raise ValueError( + "max_tokens must be an integer greater than 0 and not exceeding the maximum value of the corresponding model") + + # temperature + if 'temperature' not in cp: + cp["temperature"] = 1 + + if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2: + raise ValueError("temperature must be a float between 0 and 2") + + # top_p + if 'top_p' not in cp: + cp["top_p"] = 1 + + if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2: + raise ValueError("top_p must be a float between 0 and 2") + + # presence_penalty + if 'presence_penalty' not in cp: + cp["presence_penalty"] = 0 + + if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2: + raise ValueError("presence_penalty must be a float between -2 and 2") + + # presence_penalty + if 'frequency_penalty' not in cp: + cp["frequency_penalty"] = 0 + + if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2: + raise ValueError("frequency_penalty must be a float between -2 and 2") + + # Filter out extra parameters + filtered_cp = { + "max_tokens": cp["max_tokens"], + "temperature": cp["temperature"], + "top_p": cp["top_p"], + "presence_penalty": cp["presence_penalty"], + "frequency_penalty": cp["frequency_penalty"] + } + + return filtered_cp + + @staticmethod + def validate_configuration(account: Account, config: dict, mode: str) -> dict: + # opening_statement + if 'opening_statement' not in config or not config["opening_statement"]: + config["opening_statement"] = "" + + if not isinstance(config["opening_statement"], str): + raise ValueError("opening_statement must be of string type") + + # suggested_questions + if 'suggested_questions' not in config or not config["suggested_questions"]: + config["suggested_questions"] = [] + + if not isinstance(config["suggested_questions"], list): + raise ValueError("suggested_questions must be of list type") + + for question in config["suggested_questions"]: + if not isinstance(question, str): + raise ValueError("Elements in suggested_questions list must be of string type") + + # suggested_questions_after_answer + if 'suggested_questions_after_answer' not in config or not config["suggested_questions_after_answer"]: + config["suggested_questions_after_answer"] = { + "enabled": False + } + + if not isinstance(config["suggested_questions_after_answer"], dict): + raise ValueError("suggested_questions_after_answer must be of dict type") + + if "enabled" not in config["suggested_questions_after_answer"] or not config["suggested_questions_after_answer"]["enabled"]: + config["suggested_questions_after_answer"]["enabled"] = False + + if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool): + raise ValueError("enabled in suggested_questions_after_answer must be of boolean type") + + # more_like_this + 
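`validate_model_completion_params` both validates and normalizes: it fills defaults, range-checks each field, and drops unknown keys (incidentally, the second `# presence_penalty` comment above heads the `frequency_penalty` block). A hedged example, assuming `'gpt-3.5-turbo'` appears in `llm_constant.max_context_token_length`:

```python
params = AppModelConfigService.validate_model_completion_params(
    {'temperature': 0.7, 'foo': 'bar'},  # 'foo' is silently dropped
    model_name='gpt-3.5-turbo',          # assumed to be a known model name
)
# -> {'max_tokens': 512, 'temperature': 0.7, 'top_p': 1,
#     'presence_penalty': 0, 'frequency_penalty': 0}
```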
if 'more_like_this' not in config or not config["more_like_this"]: + config["more_like_this"] = { + "enabled": False + } + + if not isinstance(config["more_like_this"], dict): + raise ValueError("more_like_this must be of dict type") + + if "enabled" not in config["more_like_this"] or not config["more_like_this"]["enabled"]: + config["more_like_this"]["enabled"] = False + + if not isinstance(config["more_like_this"]["enabled"], bool): + raise ValueError("enabled in more_like_this must be of boolean type") + + # model + if 'model' not in config: + raise ValueError("model is required") + + if not isinstance(config["model"], dict): + raise ValueError("model must be of object type") + + # model.provider + if 'provider' not in config["model"] or config["model"]["provider"] != "openai": + raise ValueError("model.provider must be 'openai'") + + # model.name + if 'name' not in config["model"]: + raise ValueError("model.name is required") + + if config["model"]["name"] not in llm_constant.models_by_mode[mode]: + raise ValueError("model.name must be in the specified model list") + + # model.completion_params + if 'completion_params' not in config["model"]: + raise ValueError("model.completion_params is required") + + config["model"]["completion_params"] = AppModelConfigService.validate_model_completion_params( + config["model"]["completion_params"], + config["model"]["name"] + ) + + # user_input_form + if "user_input_form" not in config or not config["user_input_form"]: + config["user_input_form"] = [] + + if not isinstance(config["user_input_form"], list): + raise ValueError("user_input_form must be a list of objects") + + variables = [] + for item in config["user_input_form"]: + key = list(item.keys())[0] + if key not in ["text-input", "select"]: + raise ValueError("Keys in user_input_form list can only be 'text-input' or 'select'") + + form_item = item[key] + if 'label' not in form_item: + raise ValueError("label is required in user_input_form") + + if not isinstance(form_item["label"], str): + raise ValueError("label in user_input_form must be of string type") + + if 'variable' not in form_item: + raise ValueError("variable is required in user_input_form") + + if not isinstance(form_item["variable"], str): + raise ValueError("variable in user_input_form must be of string type") + + pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$") + if pattern.match(form_item["variable"]) is None: + raise ValueError("variable in user_input_form must be a string, " + "and cannot start with a number") + + variables.append(form_item["variable"]) + + if 'required' not in form_item or not form_item["required"]: + form_item["required"] = False + + if not isinstance(form_item["required"], bool): + raise ValueError("required in user_input_form must be of boolean type") + + if key == "select": + if 'options' not in form_item or not form_item["options"]: + form_item["options"] = [] + + if not isinstance(form_item["options"], list): + raise ValueError("options in user_input_form must be a list of strings") + + if "default" in form_item and form_item['default'] \ + and form_item["default"] not in form_item["options"]: + raise ValueError("default value in user_input_form must be in the options list") + + # pre_prompt + if "pre_prompt" not in config or not config["pre_prompt"]: + config["pre_prompt"] = "" + + if not isinstance(config["pre_prompt"], str): + raise ValueError("pre_prompt must be of string type") + + template_vars = re.findall(r"\{\{(\w+)\}\}", 
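A config fragment that passes the `user_input_form` checks above (a sketch; the labels and variable names are illustrative):

```python
config['user_input_form'] = [
    {'text-input': {'label': 'Name', 'variable': 'name', 'required': True}},
    {'select': {'label': 'Plan', 'variable': 'plan', 'required': False,
                'options': ['free', 'pro'], 'default': 'free'}},
]
```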
config["pre_prompt"]) + for var in template_vars: + if var not in variables: + raise ValueError("Template variables in pre_prompt must be defined in user_input_form") + + # agent_mode + if "agent_mode" not in config or not config["agent_mode"]: + config["agent_mode"] = { + "enabled": False, + "tools": [] + } + + if not isinstance(config["agent_mode"], dict): + raise ValueError("agent_mode must be of object type") + + if "enabled" not in config["agent_mode"] or not config["agent_mode"]["enabled"]: + config["agent_mode"]["enabled"] = False + + if not isinstance(config["agent_mode"]["enabled"], bool): + raise ValueError("enabled in agent_mode must be of boolean type") + + if "tools" not in config["agent_mode"] or not config["agent_mode"]["tools"]: + config["agent_mode"]["tools"] = [] + + if not isinstance(config["agent_mode"]["tools"], list): + raise ValueError("tools in agent_mode must be a list of objects") + + for tool in config["agent_mode"]["tools"]: + key = list(tool.keys())[0] + if key not in ["sensitive-word-avoidance", "dataset"]: + raise ValueError("Keys in agent_mode.tools list can only be 'sensitive-word-avoidance' or 'dataset'") + + tool_item = tool[key] + + if "enabled" not in tool_item or not tool_item["enabled"]: + tool_item["enabled"] = False + + if not isinstance(tool_item["enabled"], bool): + raise ValueError("enabled in agent_mode.tools must be of boolean type") + + if key == "sensitive-word-avoidance": + if "words" not in tool_item or not tool_item["words"]: + tool_item["words"] = "" + + if not isinstance(tool_item["words"], str): + raise ValueError("words in sensitive-word-avoidance must be of string type") + + if "canned_response" not in tool_item or not tool_item["canned_response"]: + tool_item["canned_response"] = "" + + if not isinstance(tool_item["canned_response"], str): + raise ValueError("canned_response in sensitive-word-avoidance must be of string type") + elif key == "dataset": + if 'id' not in tool_item: + raise ValueError("id is required in dataset") + + try: + uuid.UUID(tool_item["id"]) + except ValueError: + raise ValueError("id in dataset must be of UUID type") + + if not AppModelConfigService.is_dataset_exists(account, tool_item["id"]): + raise ValueError("Dataset ID does not exist, please check your permission.") + + # Filter out extra parameters + filtered_config = { + "opening_statement": config["opening_statement"], + "suggested_questions": config["suggested_questions"], + "suggested_questions_after_answer": config["suggested_questions_after_answer"], + "more_like_this": config["more_like_this"], + "model": { + "provider": config["model"]["provider"], + "name": config["model"]["name"], + "completion_params": config["model"]["completion_params"] + }, + "user_input_form": config["user_input_form"], + "pre_prompt": config["pre_prompt"], + "agent_mode": config["agent_mode"] + } + + return filtered_config diff --git a/api/services/completion_service.py b/api/services/completion_service.py new file mode 100644 index 0000000000..f94cd23d3e --- /dev/null +++ b/api/services/completion_service.py @@ -0,0 +1,506 @@ +import json +import logging +import threading +import time +import uuid +from typing import Generator, Union, Any + +from flask import current_app, Flask +from redis.client import PubSub +from sqlalchemy import and_ + +from core.completion import Completion +from core.conversation_message_task import PubHandler, ConversationTaskStoppedException +from core.llm.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, 
LLMRateLimitError, \ + LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message +from services.app_model_config_service import AppModelConfigService +from services.errors.app import MoreLikeThisDisabledError +from services.errors.app_model_config import AppModelConfigBrokenError +from services.errors.completion import CompletionStoppedError +from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError +from services.errors.message import MessageNotExistsError + + +class CompletionService: + + @classmethod + def completion(cls, app_model: App, user: Union[Account | EndUser], args: Any, + from_source: str, streaming: bool = True, + is_model_config_override: bool = False) -> Union[dict | Generator]: + # is streaming mode + inputs = args['inputs'] + query = args['query'] + conversation_id = args['conversation_id'] if 'conversation_id' in args else None + + conversation = None + if conversation_id: + conversation_filter = [ + Conversation.id == args['conversation_id'], + Conversation.app_id == app_model.id, + Conversation.status == 'normal' + ] + + if from_source == 'console': + conversation_filter.append(Conversation.from_account_id == user.id) + else: + conversation_filter.append(Conversation.from_end_user_id == user.id if user else None) + + conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first() + + if not conversation: + raise ConversationNotExistsError() + + if conversation.status != 'normal': + raise ConversationCompletedError() + + if not conversation.override_model_configs: + app_model_config = db.session.query(AppModelConfig).get(conversation.app_model_config_id) + + if not app_model_config: + raise AppModelConfigBrokenError() + else: + conversation_override_model_configs = json.loads(conversation.override_model_configs) + app_model_config = AppModelConfig( + id=conversation.app_model_config_id, + app_id=app_model.id, + provider="", + model_id="", + configs="", + opening_statement=conversation_override_model_configs['opening_statement'], + suggested_questions=json.dumps(conversation_override_model_configs['suggested_questions']), + model=json.dumps(conversation_override_model_configs['model']), + user_input_form=json.dumps(conversation_override_model_configs['user_input_form']), + pre_prompt=conversation_override_model_configs['pre_prompt'], + agent_mode=json.dumps(conversation_override_model_configs['agent_mode']), + ) + + if is_model_config_override: + # build new app model config + if 'model' not in args['model_config']: + raise ValueError('model_config.model is required') + + if 'completion_params' not in args['model_config']['model']: + raise ValueError('model_config.model.completion_params is required') + + completion_params = AppModelConfigService.validate_model_completion_params( + cp=args['model_config']['model']['completion_params'], + model_name=app_model_config.model_dict["name"] + ) + + app_model_config_model = app_model_config.model_dict + app_model_config_model['completion_params'] = completion_params + + app_model_config = AppModelConfig( + id=app_model_config.id, + app_id=app_model.id, + provider="", + model_id="", + configs="", + opening_statement=app_model_config.opening_statement, + suggested_questions=app_model_config.suggested_questions, + model=json.dumps(app_model_config_model), + 
user_input_form=app_model_config.user_input_form, + pre_prompt=app_model_config.pre_prompt, + agent_mode=app_model_config.agent_mode, + ) + else: + if app_model.app_model_config_id is None: + raise AppModelConfigBrokenError() + + app_model_config = app_model.app_model_config + + if not app_model_config: + raise AppModelConfigBrokenError() + + if is_model_config_override: + if not isinstance(user, Account): + raise Exception("Only account can override model config") + + # validate config + model_config = AppModelConfigService.validate_configuration( + account=user, + config=args['model_config'], + mode=app_model.mode + ) + + app_model_config = AppModelConfig( + id=app_model_config.id, + app_id=app_model.id, + provider="", + model_id="", + configs="", + opening_statement=model_config['opening_statement'], + suggested_questions=json.dumps(model_config['suggested_questions']), + suggested_questions_after_answer=json.dumps(model_config['suggested_questions_after_answer']), + more_like_this=json.dumps(model_config['more_like_this']), + model=json.dumps(model_config['model']), + user_input_form=json.dumps(model_config['user_input_form']), + pre_prompt=model_config['pre_prompt'], + agent_mode=json.dumps(model_config['agent_mode']), + ) + + # clean input by app_model_config form rules + inputs = cls.get_cleaned_inputs(inputs, app_model_config) + + generate_task_id = str(uuid.uuid4()) + + pubsub = redis_client.pubsub() + pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id)) + + user = cls.get_real_user_instead_of_proxy_obj(user) + + generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'generate_task_id': generate_task_id, + 'app_model': app_model, + 'app_model_config': app_model_config, + 'query': query, + 'inputs': inputs, + 'user': user, + 'conversation': conversation, + 'streaming': streaming, + 'is_model_config_override': is_model_config_override + }) + + generate_worker_thread.start() + + # wait for 5 minutes to close the thread + cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id) + + return cls.compact_response(pubsub, streaming) + + @classmethod + def get_real_user_instead_of_proxy_obj(cls, user: Union[Account, EndUser]): + if isinstance(user, Account): + user = db.session.query(Account).get(user.id) + elif isinstance(user, EndUser): + user = db.session.query(EndUser).get(user.id) + else: + raise Exception("Unknown user type") + + return user + + @classmethod + def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, app_model_config: AppModelConfig, + query: str, inputs: dict, user: Union[Account, EndUser], + conversation: Conversation, streaming: bool, is_model_config_override: bool): + with flask_app.app_context(): + try: + if conversation: + # fixed the state of the conversation object when it detached from the original session + conversation = db.session.query(Conversation).filter_by(id=conversation.id).first() + + # run + Completion.generate( + task_id=generate_task_id, + app=app_model, + app_model_config=app_model_config, + query=query, + inputs=inputs, + user=user, + conversation=conversation, + streaming=streaming, + is_override=is_model_config_override, + ) + except ConversationTaskStoppedException: + pass + except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, + LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, + ModelCurrentlyNotSupportError) as e: + db.session.rollback() + 
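`completion` hands back whatever `compact_response` produces: a dict in blocking mode, or a generator of server-sent-event frames when `streaming=True`. A hedged end-to-end sketch of the orchestration above:

```python
# In order: subscribe to a per-task Redis channel, run Completion.generate
# on a worker thread, arm a 300-second watchdog, then replay channel
# events to the caller.
response = CompletionService.completion(
    app_model, user,
    args={'inputs': {'name': 'Ada'}, 'query': 'Hello!', 'conversation_id': None},
    from_source='console',
    streaming=True,
)
for frame in response:  # frames look like 'data: {"event": ...}\n\n'
    print(frame, end='')
```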
PubHandler.pub_error(user, generate_task_id, e) + except LLMAuthorizationError: + db.session.rollback() + PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided')) + except Exception as e: + db.session.rollback() + logging.exception("Unknown Error in completion") + PubHandler.pub_error(user, generate_task_id, e) + + @classmethod + def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread: + # wait for 5 minutes to close the thread + timeout = 300 + + def close_pubsub(): + sleep_iterations = 0 + while sleep_iterations < timeout and worker_thread.is_alive(): + time.sleep(1) + sleep_iterations += 1 + + if worker_thread.is_alive(): + PubHandler.stop(user, generate_task_id) + try: + pubsub.close() + except: + pass + + countdown_thread = threading.Thread(target=close_pubsub) + countdown_thread.start() + + return countdown_thread + + @classmethod + def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser], + message_id: str, streaming: bool = True) -> Union[dict | Generator]: + if not user: + raise ValueError('user cannot be None') + + message = db.session.query(Message).filter( + Message.id == message_id, + Message.app_id == app_model.id, + Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), + Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Message.from_account_id == (user.id if isinstance(user, Account) else None), + ).first() + + if not message: + raise MessageNotExistsError() + + current_app_model_config = app_model.app_model_config + more_like_this = current_app_model_config.more_like_this_dict + + if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False: + raise MoreLikeThisDisabledError() + + app_model_config = message.app_model_config + + if message.override_model_configs: + override_model_configs = json.loads(message.override_model_configs) + pre_prompt = override_model_configs.get("pre_prompt", '') + elif app_model_config: + pre_prompt = app_model_config.pre_prompt + else: + raise AppModelConfigBrokenError() + + generate_task_id = str(uuid.uuid4()) + + pubsub = redis_client.pubsub() + pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id)) + + user = cls.get_real_user_instead_of_proxy_obj(user) + + generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={ + 'flask_app': current_app._get_current_object(), + 'generate_task_id': generate_task_id, + 'app_model': app_model, + 'app_model_config': app_model_config, + 'message': message, + 'pre_prompt': pre_prompt, + 'user': user, + 'streaming': streaming + }) + + generate_worker_thread.start() + + cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id) + + return cls.compact_response(pubsub, streaming) + + @classmethod + def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, + app_model_config: AppModelConfig, message: Message, pre_prompt: str, + user: Union[Account, EndUser], streaming: bool): + with flask_app.app_context(): + try: + # run + Completion.generate_more_like_this( + task_id=generate_task_id, + app=app_model, + user=user, + message=message, + pre_prompt=pre_prompt, + app_model_config=app_model_config, + streaming=streaming + ) + except ConversationTaskStoppedException: + pass + except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, + LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, + 
ModelCurrentlyNotSupportError) as e: + db.session.rollback() + PubHandler.pub_error(user, generate_task_id, e) + except LLMAuthorizationError: + db.session.rollback() + PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided')) + except Exception as e: + db.session.rollback() + logging.exception("Unknown Error in completion") + PubHandler.pub_error(user, generate_task_id, e) + + @classmethod + def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig): + if user_inputs is None: + user_inputs = {} + + filtered_inputs = {} + + # Filter input variables from form configuration, handle required fields, default values, and option values + input_form_config = app_model_config.user_input_form_list + for config in input_form_config: + input_config = list(config.values())[0] + variable = input_config["variable"] + + input_type = list(config.keys())[0] + + if variable not in user_inputs or not user_inputs[variable]: + if "required" in input_config and input_config["required"]: + raise ValueError(f"{variable} is required in input form") + else: + filtered_inputs[variable] = input_config["default"] if "default" in input_config else "" + continue + + value = user_inputs[variable] + + if input_type == "select": + options = input_config["options"] if "options" in input_config else [] + if value not in options: + raise ValueError(f"{variable} in input form must be one of the following: {options}") + else: + if 'max_length' in variable: + max_length = variable['max_length'] + if len(value) > max_length: + raise ValueError(f'{variable} in input form must be less than {max_length} characters') + + filtered_inputs[variable] = value + + return filtered_inputs + + @classmethod + def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict | Generator]: + generate_channel = list(pubsub.channels.keys())[0].decode('utf-8') + if not streaming: + try: + for message in pubsub.listen(): + if message["type"] == "message": + result = message["data"].decode('utf-8') + result = json.loads(result) + if result.get('error'): + cls.handle_error(result) + + return cls.get_message_response_data(result.get('data')) + except ValueError as e: + if e.args[0] != "I/O operation on closed file.": # ignore this error + raise CompletionStoppedError() + else: + logging.exception(e) + raise + finally: + try: + pubsub.unsubscribe(generate_channel) + except ConnectionError: + pass + else: + def generate() -> Generator: + try: + for message in pubsub.listen(): + if message["type"] == "message": + result = message["data"].decode('utf-8') + result = json.loads(result) + if result.get('error'): + cls.handle_error(result) + + event = result.get('event') + if event == "end": + logging.debug("{} finished".format(generate_channel)) + break + + if event == 'message': + yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n" + elif event == 'chain': + yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n" + elif event == 'agent_thought': + yield "data: " + json.dumps(cls.get_agent_thought_response_data(result.get('data'))) + "\n\n" + except ValueError as e: + if e.args[0] != "I/O operation on closed file.": # ignore this error + logging.exception(e) + raise + finally: + try: + pubsub.unsubscribe(generate_channel) + except ConnectionError: + pass + + return generate() + + @classmethod + def get_message_response_data(cls, data: dict): + response_data = { + 'event': 'message', + 'task_id': data.get('task_id'), + 
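One fix worth flagging in `get_cleaned_inputs` above: the length check tests `'max_length' in variable`, but `variable` is the variable *name* (a string), so that is a substring test, and `variable['max_length']` would raise `TypeError` if it ever matched. The lookup belongs on the form item's config; a corrected sketch of that branch:

```python
# Corrected: max_length lives on the form item config dict, not on the
# variable name string.
if 'max_length' in input_config:
    max_length = input_config['max_length']
    if len(value) > max_length:
        raise ValueError(f'{variable} in input form must be less than {max_length} characters')
```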
'id': data.get('message_id'), + 'answer': data.get('text'), + 'created_at': int(time.time()) + } + + if data.get('mode') == 'chat': + response_data['conversation_id'] = data.get('conversation_id') + + return response_data + + @classmethod + def get_chain_response_data(cls, data: dict): + response_data = { + 'event': 'chain', + 'id': data.get('chain_id'), + 'task_id': data.get('task_id'), + 'message_id': data.get('message_id'), + 'type': data.get('type'), + 'input': data.get('input'), + 'output': data.get('output'), + 'created_at': int(time.time()) + } + + if data.get('mode') == 'chat': + response_data['conversation_id'] = data.get('conversation_id') + + return response_data + + @classmethod + def get_agent_thought_response_data(cls, data: dict): + response_data = { + 'event': 'agent_thought', + 'id': data.get('agent_thought_id'), + 'chain_id': data.get('chain_id'), + 'task_id': data.get('task_id'), + 'message_id': data.get('message_id'), + 'position': data.get('position'), + 'thought': data.get('thought'), + 'tool': data.get('tool'), # todo use real dataset obj replace it + 'tool_input': data.get('tool_input'), + 'observation': data.get('observation'), + 'answer': data.get('answer') if not data.get('thought') else '', + 'created_at': int(time.time()) + } + + if data.get('mode') == 'chat': + response_data['conversation_id'] = data.get('conversation_id') + + return response_data + + @classmethod + def handle_error(cls, result: dict): + logging.debug("error: %s", result) + error = result.get('error') + description = result.get('description') + + # handle errors + llm_errors = { + 'LLMBadRequestError': LLMBadRequestError, + 'LLMAPIConnectionError': LLMAPIConnectionError, + 'LLMAPIUnavailableError': LLMAPIUnavailableError, + 'LLMRateLimitError': LLMRateLimitError, + 'ProviderTokenNotInitError': ProviderTokenNotInitError, + 'QuotaExceededError': QuotaExceededError, + 'ModelCurrentlyNotSupportError': ModelCurrentlyNotSupportError + } + + if error in llm_errors: + raise llm_errors[error](description) + elif error == 'LLMAuthorizationError': + raise LLMAuthorizationError('Incorrect API key provided') + else: + raise Exception(description) diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py new file mode 100644 index 0000000000..968209c71e --- /dev/null +++ b/api/services/conversation_service.py @@ -0,0 +1,94 @@ +from typing import Union, Optional + +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from extensions.ext_database import db +from models.account import Account +from models.model import Conversation, App, EndUser +from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError + + +class ConversationService: + @classmethod + def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]], + last_id: Optional[str], limit: int, + include_ids: Optional[list] = None, exclude_ids: Optional[list] = None) -> InfiniteScrollPagination: + if not user: + return InfiniteScrollPagination(data=[], limit=limit, has_more=False) + + base_query = db.session.query(Conversation).filter( + Conversation.app_id == app_model.id, + Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'), + Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Conversation.from_account_id == (user.id if isinstance(user, Account) else None), + ) + + if include_ids is not None: + base_query = base_query.filter(Conversation.id.in_(include_ids)) + + if exclude_ids 
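`pagination_by_last_id` here implements keyset (infinite-scroll) pagination ordered by `created_at` descending; a usage sketch, assuming `InfiniteScrollPagination` exposes `data` and `has_more`:

```python
# First page, then the next page keyed off the last conversation seen.
page = ConversationService.pagination_by_last_id(app_model, user, last_id=None, limit=20)
if page.has_more:
    next_page = ConversationService.pagination_by_last_id(
        app_model, user, last_id=str(page.data[-1].id), limit=20)
```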
is not None: + base_query = base_query.filter(~Conversation.id.in_(exclude_ids)) + + if last_id: + last_conversation = base_query.filter( + Conversation.id == last_id, + ).first() + + if not last_conversation: + raise LastConversationNotExistsError() + + conversations = base_query.filter( + Conversation.created_at < last_conversation.created_at, + Conversation.id != last_conversation.id + ).order_by(Conversation.created_at.desc()).limit(limit).all() + else: + conversations = base_query.order_by(Conversation.created_at.desc()).limit(limit).all() + + has_more = False + if len(conversations) == limit: + current_page_first_conversation = conversations[-1] + rest_count = base_query.filter( + Conversation.created_at < current_page_first_conversation.created_at, + Conversation.id != current_page_first_conversation.id + ).count() + + if rest_count > 0: + has_more = True + + return InfiniteScrollPagination( + data=conversations, + limit=limit, + has_more=has_more + ) + + @classmethod + def rename(cls, app_model: App, conversation_id: str, + user: Optional[Union[Account | EndUser]], name: str): + conversation = cls.get_conversation(app_model, conversation_id, user) + + conversation.name = name + db.session.commit() + + return conversation + + @classmethod + def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account | EndUser]]): + conversation = db.session.query(Conversation) \ + .filter( + Conversation.id == conversation_id, + Conversation.app_id == app_model.id, + Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'), + Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Conversation.from_account_id == (user.id if isinstance(user, Account) else None), + ).first() + + if not conversation: + raise ConversationNotExistsError() + + return conversation + + @classmethod + def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account | EndUser]]): + conversation = cls.get_conversation(app_model, conversation_id, user) + + db.session.delete(conversation) + db.session.commit() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py new file mode 100644 index 0000000000..39004c3437 --- /dev/null +++ b/api/services/dataset_service.py @@ -0,0 +1,521 @@ +import json +import logging +import datetime +import time +import random +from typing import Optional +from extensions.ext_redis import redis_client +from flask_login import current_user + +from core.index.index_builder import IndexBuilder +from events.dataset_event import dataset_was_deleted +from events.document_event import document_was_deleted +from extensions.ext_database import db +from models.account import Account +from models.dataset import Dataset, Document, DatasetQuery, DatasetProcessRule, AppDatasetJoin +from models.model import UploadFile +from services.errors.account import NoPermissionError +from services.errors.dataset import DatasetNameDuplicateError +from services.errors.document import DocumentIndexingError +from services.errors.file import FileNotExistsError +from tasks.document_indexing_task import document_indexing_task + + +class DatasetService: + + @staticmethod + def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None): + if user: + permission_filter = db.or_(Dataset.created_by == user.id, + Dataset.permission == 'all_team_members') + else: + permission_filter = Dataset.permission == 'all_team_members' + datasets = Dataset.query.filter( + db.and_(Dataset.provider == provider, 
Dataset.tenant_id == tenant_id, permission_filter)) \ + .paginate( + page=page, + per_page=per_page, + max_per_page=100, + error_out=False + ) + + return datasets.items, datasets.total + + @staticmethod + def get_process_rules(dataset_id): + # get the latest process rule + dataset_process_rule = db.session.query(DatasetProcessRule). \ + filter(DatasetProcessRule.dataset_id == dataset_id). \ + order_by(DatasetProcessRule.created_at.desc()). \ + limit(1). \ + one_or_none() + if dataset_process_rule: + mode = dataset_process_rule.mode + rules = dataset_process_rule.rules_dict + else: + mode = DocumentService.DEFAULT_RULES['mode'] + rules = DocumentService.DEFAULT_RULES['rules'] + return { + 'mode': mode, + 'rules': rules + } + + @staticmethod + def get_datasets_by_ids(ids, tenant_id): + datasets = Dataset.query.filter(Dataset.id.in_(ids), + Dataset.tenant_id == tenant_id).paginate( + page=1, per_page=len(ids), max_per_page=len(ids), error_out=False) + return datasets.items, datasets.total + + @staticmethod + def create_empty_dataset(tenant_id: str, name: str, indexing_technique: Optional[str], account: Account): + # check if dataset name already exists + if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first(): + raise DatasetNameDuplicateError( + f'Dataset with name {name} already exists.') + + dataset = Dataset(name=name, indexing_technique=indexing_technique, data_source_type='upload_file') + # dataset = Dataset(name=name, provider=provider, config=config) + dataset.created_by = account.id + dataset.updated_by = account.id + dataset.tenant_id = tenant_id + db.session.add(dataset) + db.session.commit() + return dataset + + @staticmethod + def get_dataset(dataset_id): + dataset = Dataset.query.filter_by( + id=dataset_id + ).first() + if dataset is None: + return None + else: + return dataset + + @staticmethod + def update_dataset(dataset_id, data, user): + dataset = DatasetService.get_dataset(dataset_id) + DatasetService.check_dataset_permission(dataset, user) + + filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'} + + filtered_data['updated_by'] = user.id + filtered_data['updated_at'] = datetime.datetime.now() + + dataset.query.filter_by(id=dataset_id).update(filtered_data) + + db.session.commit() + + return dataset + + @staticmethod + def delete_dataset(dataset_id, user): + # todo: cannot delete dataset if it is being processed + + dataset = DatasetService.get_dataset(dataset_id) + + if dataset is None: + return False + + DatasetService.check_dataset_permission(dataset, user) + + dataset_was_deleted.send(dataset) + + db.session.delete(dataset) + db.session.commit() + return True + + @staticmethod + def check_dataset_permission(dataset, user): + if dataset.tenant_id != user.current_tenant_id: + logging.debug( + f'User {user.id} does not have permission to access dataset {dataset.id}') + raise NoPermissionError( + 'You do not have permission to access this dataset.') + if dataset.permission == 'only_me' and dataset.created_by != user.id: + logging.debug( + f'User {user.id} does not have permission to access dataset {dataset.id}') + raise NoPermissionError( + 'You do not have permission to access this dataset.') + + @staticmethod + def get_dataset_queries(dataset_id: str, page: int, per_page: int): + dataset_queries = DatasetQuery.query.filter_by(dataset_id=dataset_id) \ + .order_by(db.desc(DatasetQuery.created_at)) \ + .paginate( + page=page, per_page=per_page, max_per_page=100, error_out=False + ) + return dataset_queries.items, 
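`get_process_rules` returns the most recently stored rule for a dataset, falling back to `DocumentService.DEFAULT_RULES` (defined below) when none exists:

```python
rules = DatasetService.get_process_rules(dataset_id)
# With no stored rule this yields the defaults below:
# {'mode': 'custom',
#  'rules': {'pre_processing_rules': [{'id': 'remove_extra_spaces', 'enabled': True},
#                                     {'id': 'remove_urls_emails', 'enabled': False}],
#            'segmentation': {'delimiter': '\n', 'max_tokens': 500}}}
```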
dataset_queries.total + + @staticmethod + def get_related_apps(dataset_id: str): + return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \ + .order_by(db.desc(AppDatasetJoin.created_at)).all() + + +class DocumentService: + DEFAULT_RULES = { + 'mode': 'custom', + 'rules': { + 'pre_processing_rules': [ + {'id': 'remove_extra_spaces', 'enabled': True}, + {'id': 'remove_urls_emails', 'enabled': False} + ], + 'segmentation': { + 'delimiter': '\n', + 'max_tokens': 500 + } + } + } + + DOCUMENT_METADATA_SCHEMA = { + "book": { + "title": str, + "language": str, + "author": str, + "publisher": str, + "publication_date": str, + "isbn": str, + "category": str, + }, + "web_page": { + "title": str, + "url": str, + "language": str, + "publish_date": str, + "author/publisher": str, + "topic/keywords": str, + "description": str, + }, + "paper": { + "title": str, + "language": str, + "author": str, + "publish_date": str, + "journal/conference_name": str, + "volume/issue/page_numbers": str, + "doi": str, + "topic/keywords": str, + "abstract": str, + }, + "social_media_post": { + "platform": str, + "author/username": str, + "publish_date": str, + "post_url": str, + "topic/tags": str, + }, + "wikipedia_entry": { + "title": str, + "language": str, + "web_page_url": str, + "last_edit_date": str, + "editor/contributor": str, + "summary/introduction": str, + }, + "personal_document": { + "title": str, + "author": str, + "creation_date": str, + "last_modified_date": str, + "document_type": str, + "tags/category": str, + }, + "business_document": { + "title": str, + "author": str, + "creation_date": str, + "last_modified_date": str, + "document_type": str, + "department/team": str, + }, + "im_chat_log": { + "chat_platform": str, + "chat_participants/group_name": str, + "start_date": str, + "end_date": str, + "summary": str, + }, + "synced_from_notion": { + "title": str, + "language": str, + "author/creator": str, + "creation_date": str, + "last_modified_date": str, + "notion_page_link": str, + "category/tags": str, + "description": str, + }, + "synced_from_github": { + "repository_name": str, + "repository_description": str, + "repository_owner/organization": str, + "code_filename": str, + "code_file_path": str, + "programming_language": str, + "github_link": str, + "open_source_license": str, + "commit_date": str, + "commit_author": str + } + } + + @staticmethod + def get_document(dataset_id: str, document_id: str) -> Optional[Document]: + document = db.session.query(Document).filter( + Document.id == document_id, + Document.dataset_id == dataset_id + ).first() + + return document + + @staticmethod + def get_document_file_detail(file_id: str): + file_detail = db.session.query(UploadFile). \ + filter(UploadFile.id == file_id). 
\
+            one_or_none()
+        return file_detail
+
+    @staticmethod
+    def check_archived(document):
+        if document.archived:
+            return True
+        else:
+            return False
+
+    @staticmethod
+    def delete_document(document):
+        if document.indexing_status in ["parsing", "cleaning", "splitting", "indexing"]:
+            raise DocumentIndexingError()
+
+        # trigger document_was_deleted signal
+        document_was_deleted.send(document.id, dataset_id=document.dataset_id)
+
+        db.session.delete(document)
+        db.session.commit()
+
+    @staticmethod
+    def pause_document(document):
+        if document.indexing_status not in ["waiting", "parsing", "cleaning", "splitting", "indexing"]:
+            raise DocumentIndexingError()
+        # mark document as paused
+        document.is_paused = True
+        document.paused_by = current_user.id
+        document.paused_at = datetime.datetime.utcnow()
+
+        db.session.add(document)
+        db.session.commit()
+        # set document paused flag
+        indexing_cache_key = 'document_{}_is_paused'.format(document.id)
+        redis_client.setnx(indexing_cache_key, "True")
+
+    @staticmethod
+    def recover_document(document):
+        if not document.is_paused:
+            raise DocumentIndexingError()
+        # mark document as recovered and clear the pause metadata set by pause_document
+        document.is_paused = False
+        document.paused_by = None
+        document.paused_at = None
+
+        db.session.add(document)
+        db.session.commit()
+        # delete paused flag
+        indexing_cache_key = 'document_{}_is_paused'.format(document.id)
+        redis_client.delete(indexing_cache_key)
+        # trigger async task
+        document_indexing_task.delay(document.dataset_id, document.id)
+
+    @staticmethod
+    def get_documents_position(dataset_id):
+        documents = Document.query.filter_by(dataset_id=dataset_id).all()
+        if documents:
+            return len(documents) + 1
+        else:
+            return 1
+
+    @staticmethod
+    def save_document_with_dataset_id(dataset: Dataset, document_data: dict,
+                                      account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
+                                      created_from: str = 'web'):
+        if not dataset.indexing_technique:
+            if 'indexing_technique' not in document_data \
+                    or document_data['indexing_technique'] not in Dataset.INDEXING_TECHNIQUE_LIST:
+                raise ValueError("Indexing technique is required")
+
+            dataset.indexing_technique = document_data["indexing_technique"]
+
+        if dataset.indexing_technique == 'high_quality':
+            IndexBuilder.get_default_service_context(dataset.tenant_id)
+
+        # save process rule
+        if not dataset_process_rule:
+            process_rule = document_data["process_rule"]
+            if process_rule["mode"] == "custom":
+                dataset_process_rule = DatasetProcessRule(
+                    dataset_id=dataset.id,
+                    mode=process_rule["mode"],
+                    rules=json.dumps(process_rule["rules"]),
+                    created_by=account.id
+                )
+            elif process_rule["mode"] == "automatic":
+                dataset_process_rule = DatasetProcessRule(
+                    dataset_id=dataset.id,
+                    mode=process_rule["mode"],
+                    rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
+                    created_by=account.id
+                )
+            db.session.add(dataset_process_rule)
+            db.session.commit()
+
+        file_name = ''
+        data_source_info = {}
+        if document_data["data_source"]["type"] == "upload_file":
+            file_id = document_data["data_source"]["info"]
+            file = db.session.query(UploadFile).filter(
+                UploadFile.tenant_id == dataset.tenant_id,
+                UploadFile.id == file_id
+            ).first()
+
+            # raise error if file not found
+            if not file:
+                raise FileNotExistsError()
+
+            file_name = file.name
+            data_source_info = {
+                "upload_file_id": file_id,
+            }
+
+        # save document
+        position = DocumentService.get_documents_position(dataset.id)
+        document = Document(
+            tenant_id=dataset.tenant_id,
+            dataset_id=dataset.id, +
position=position, + data_source_type=document_data["data_source"]["type"], + data_source_info=json.dumps(data_source_info), + dataset_process_rule_id=dataset_process_rule.id, + batch=time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)), + name=file_name, + created_from=created_from, + created_by=account.id, + # created_api_request_id = db.Column(UUID, nullable=True) + ) + + db.session.add(document) + db.session.commit() + + # trigger async task + document_indexing_task.delay(document.dataset_id, document.id) + + return document + + @staticmethod + def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account): + # save dataset + dataset = Dataset( + tenant_id=tenant_id, + name='', + data_source_type=document_data["data_source"]["type"], + indexing_technique=document_data["indexing_technique"], + created_by=account.id + ) + + db.session.add(dataset) + db.session.flush() + + document = DocumentService.save_document_with_dataset_id(dataset, document_data, account) + + cut_length = 18 + cut_name = document.name[:cut_length] + dataset.name = cut_name + '...' if len(document.name) > cut_length else cut_name + dataset.description = 'useful for when you want to answer queries about the ' + document.name + db.session.commit() + + return dataset, document + + @classmethod + def document_create_args_validate(cls, args: dict): + if 'data_source' not in args or not args['data_source']: + raise ValueError("Data source is required") + + if not isinstance(args['data_source'], dict): + raise ValueError("Data source is invalid") + + if 'type' not in args['data_source'] or not args['data_source']['type']: + raise ValueError("Data source type is required") + + if args['data_source']['type'] not in Document.DATA_SOURCES: + raise ValueError("Data source type is invalid") + + if args['data_source']['type'] == 'upload_file': + if 'info' not in args['data_source'] or not args['data_source']['info']: + raise ValueError("Data source info is required") + + if 'process_rule' not in args or not args['process_rule']: + raise ValueError("Process rule is required") + + if not isinstance(args['process_rule'], dict): + raise ValueError("Process rule is invalid") + + if 'mode' not in args['process_rule'] or not args['process_rule']['mode']: + raise ValueError("Process rule mode is required") + + if args['process_rule']['mode'] not in DatasetProcessRule.MODES: + raise ValueError("Process rule mode is invalid") + + if args['process_rule']['mode'] == 'automatic': + args['process_rule']['rules'] = {} + else: + if 'rules' not in args['process_rule'] or not args['process_rule']['rules']: + raise ValueError("Process rule rules is required") + + if not isinstance(args['process_rule']['rules'], dict): + raise ValueError("Process rule rules is invalid") + + if 'pre_processing_rules' not in args['process_rule']['rules'] \ + or args['process_rule']['rules']['pre_processing_rules'] is None: + raise ValueError("Process rule pre_processing_rules is required") + + if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list): + raise ValueError("Process rule pre_processing_rules is invalid") + + unique_pre_processing_rule_dicts = {} + for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']: + if 'id' not in pre_processing_rule or not pre_processing_rule['id']: + raise ValueError("Process rule pre_processing_rules id is required") + + if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES: + raise ValueError("Process rule 
pre_processing_rules id is invalid") + + if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None: + raise ValueError("Process rule pre_processing_rules enabled is required") + + if not isinstance(pre_processing_rule['enabled'], bool): + raise ValueError("Process rule pre_processing_rules enabled is invalid") + + unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule + + args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values()) + + if 'segmentation' not in args['process_rule']['rules'] \ + or args['process_rule']['rules']['segmentation'] is None: + raise ValueError("Process rule segmentation is required") + + if not isinstance(args['process_rule']['rules']['segmentation'], dict): + raise ValueError("Process rule segmentation is invalid") + + if 'separator' not in args['process_rule']['rules']['segmentation'] \ + or not args['process_rule']['rules']['segmentation']['separator']: + raise ValueError("Process rule segmentation separator is required") + + if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str): + raise ValueError("Process rule segmentation separator is invalid") + + if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \ + or not args['process_rule']['rules']['segmentation']['max_tokens']: + raise ValueError("Process rule segmentation max_tokens is required") + + if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int): + raise ValueError("Process rule segmentation max_tokens is invalid") diff --git a/api/services/errors/__init__.py b/api/services/errors/__init__.py new file mode 100644 index 0000000000..fe77ca86b6 --- /dev/null +++ b/api/services/errors/__init__.py @@ -0,0 +1,7 @@ +# -*- coding:utf-8 -*- +__all__ = [ + 'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset', + 'app', 'completion' +] + +from . 
import * diff --git a/api/services/errors/account.py b/api/services/errors/account.py new file mode 100644 index 0000000000..14612eed75 --- /dev/null +++ b/api/services/errors/account.py @@ -0,0 +1,53 @@ +from services.errors.base import BaseServiceError + + +class AccountNotFound(BaseServiceError): + pass + + +class AccountRegisterError(BaseServiceError): + pass + + +class AccountLoginError(BaseServiceError): + pass + + +class AccountNotLinkTenantError(BaseServiceError): + pass + + +class CurrentPasswordIncorrectError(BaseServiceError): + pass + + +class LinkAccountIntegrateError(BaseServiceError): + pass + + +class TenantNotFound(BaseServiceError): + pass + + +class AccountAlreadyInTenantError(BaseServiceError): + pass + + +class InvalidActionError(BaseServiceError): + pass + + +class CannotOperateSelfError(BaseServiceError): + pass + + +class NoPermissionError(BaseServiceError): + pass + + +class MemberNotInTenantError(BaseServiceError): + pass + + +class RoleAlreadyAssignedError(BaseServiceError): + pass diff --git a/api/services/errors/app.py b/api/services/errors/app.py new file mode 100644 index 0000000000..7c4ca99c2a --- /dev/null +++ b/api/services/errors/app.py @@ -0,0 +1,2 @@ +class MoreLikeThisDisabledError(Exception): + pass diff --git a/api/services/errors/app_model_config.py b/api/services/errors/app_model_config.py new file mode 100644 index 0000000000..c0669ed231 --- /dev/null +++ b/api/services/errors/app_model_config.py @@ -0,0 +1,5 @@ +from services.errors.base import BaseServiceError + + +class AppModelConfigBrokenError(BaseServiceError): + pass diff --git a/api/services/errors/base.py b/api/services/errors/base.py new file mode 100644 index 0000000000..f5d41e17f1 --- /dev/null +++ b/api/services/errors/base.py @@ -0,0 +1,3 @@ +class BaseServiceError(Exception): + def __init__(self, description: str = None): + self.description = description \ No newline at end of file diff --git a/api/services/errors/completion.py b/api/services/errors/completion.py new file mode 100644 index 0000000000..7fc50a588e --- /dev/null +++ b/api/services/errors/completion.py @@ -0,0 +1,5 @@ +from services.errors.base import BaseServiceError + + +class CompletionStoppedError(BaseServiceError): + pass diff --git a/api/services/errors/conversation.py b/api/services/errors/conversation.py new file mode 100644 index 0000000000..139dd9a70a --- /dev/null +++ b/api/services/errors/conversation.py @@ -0,0 +1,13 @@ +from services.errors.base import BaseServiceError + + +class LastConversationNotExistsError(BaseServiceError): + pass + + +class ConversationNotExistsError(BaseServiceError): + pass + + +class ConversationCompletedError(Exception): + pass diff --git a/api/services/errors/dataset.py b/api/services/errors/dataset.py new file mode 100644 index 0000000000..254a70ffe3 --- /dev/null +++ b/api/services/errors/dataset.py @@ -0,0 +1,5 @@ +from services.errors.base import BaseServiceError + + +class DatasetNameDuplicateError(BaseServiceError): + pass diff --git a/api/services/errors/document.py b/api/services/errors/document.py new file mode 100644 index 0000000000..7327b9d032 --- /dev/null +++ b/api/services/errors/document.py @@ -0,0 +1,5 @@ +from services.errors.base import BaseServiceError + + +class DocumentIndexingError(BaseServiceError): + pass diff --git a/api/services/errors/file.py b/api/services/errors/file.py new file mode 100644 index 0000000000..3674eca3e7 --- /dev/null +++ b/api/services/errors/file.py @@ -0,0 +1,5 @@ +from services.errors.base import BaseServiceError + + +class 
FileNotExistsError(BaseServiceError): + pass diff --git a/api/services/errors/index.py b/api/services/errors/index.py new file mode 100644 index 0000000000..8513b6a55d --- /dev/null +++ b/api/services/errors/index.py @@ -0,0 +1,5 @@ +from services.errors.base import BaseServiceError + + +class IndexNotInitializedError(BaseServiceError): + pass diff --git a/api/services/errors/message.py b/api/services/errors/message.py new file mode 100644 index 0000000000..969447df9f --- /dev/null +++ b/api/services/errors/message.py @@ -0,0 +1,17 @@ +from services.errors.base import BaseServiceError + + +class FirstMessageNotExistsError(BaseServiceError): + pass + + +class LastMessageNotExistsError(BaseServiceError): + pass + + +class MessageNotExistsError(BaseServiceError): + pass + + +class SuggestedQuestionsAfterAnswerDisabledError(BaseServiceError): + pass diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py new file mode 100644 index 0000000000..619df1b873 --- /dev/null +++ b/api/services/hit_testing_service.py @@ -0,0 +1,130 @@ +import logging +import time +from typing import List + +import numpy as np +from llama_index.data_structs.node_v2 import NodeWithScore +from llama_index.indices.query.schema import QueryBundle +from llama_index.indices.vector_store import GPTVectorStoreIndexQuery +from sklearn.manifold import TSNE + +from core.docstore.empty_docstore import EmptyDocumentStore +from core.index.vector_index import VectorIndex +from extensions.ext_database import db +from models.account import Account +from models.dataset import Dataset, DocumentSegment, DatasetQuery +from services.errors.index import IndexNotInitializedError + + +class HitTestingService: + @classmethod + def retrieve(cls, dataset: Dataset, query: str, account: Account, limit: int = 10) -> dict: + index = VectorIndex(dataset=dataset).query_index + + if not index: + raise IndexNotInitializedError() + + index_query = GPTVectorStoreIndexQuery( + index_struct=index.index_struct, + service_context=index.service_context, + vector_store=index.query_context.get('vector_store'), + docstore=EmptyDocumentStore(), + response_synthesizer=None, + similarity_top_k=limit + ) + + query_bundle = QueryBundle( + query_str=query, + custom_embedding_strs=[query], + ) + + query_bundle.embedding = index.service_context.embed_model.get_agg_embedding_from_queries( + query_bundle.embedding_strs + ) + + start = time.perf_counter() + nodes = index_query.retrieve(query_bundle=query_bundle) + end = time.perf_counter() + logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") + + dataset_query = DatasetQuery( + dataset_id=dataset.id, + content=query, + source='hit_testing', + created_by_role='account', + created_by=account.id + ) + + db.session.add(dataset_query) + db.session.commit() + + return cls.compact_retrieve_response(dataset, query_bundle, nodes) + + @classmethod + def compact_retrieve_response(cls, dataset: Dataset, query_bundle: QueryBundle, nodes: List[NodeWithScore]): + embeddings = [ + query_bundle.embedding + ] + + for node in nodes: + embeddings.append(node.node.embedding) + + tsne_position_data = cls.get_tsne_positions_from_embeddings(embeddings) + + query_position = tsne_position_data.pop(0) + + i = 0 + records = [] + for node in nodes: + index_node_id = node.node.doc_id + + segment = db.session.query(DocumentSegment).filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.enabled == True, + DocumentSegment.status == 'completed', + DocumentSegment.index_node_id == index_node_id + 
).first() + + if not segment: + i += 1 + continue + + record = { + "segment": segment, + "score": node.score, + "tsne_position": tsne_position_data[i] + } + + records.append(record) + + i += 1 + + return { + "query": { + "content": query_bundle.query_str, + "tsne_position": query_position, + }, + "records": records + } + + @classmethod + def get_tsne_positions_from_embeddings(cls, embeddings: list): + embedding_length = len(embeddings) + if embedding_length <= 1: + return [{'x': 0, 'y': 0}] + + concatenate_data = np.array(embeddings).reshape(embedding_length, -1) + # concatenate_data = np.concatenate(embeddings) + + perplexity = embedding_length / 2 + 1 + if perplexity >= embedding_length: + perplexity = max(embedding_length - 1, 1) + + tsne = TSNE(n_components=2, perplexity=perplexity, early_exaggeration=12.0) + data_tsne = tsne.fit_transform(concatenate_data) + + tsne_position_data = [] + for i in range(len(data_tsne)): + tsne_position_data.append({'x': float(data_tsne[i][0]), 'y': float(data_tsne[i][1])}) + + return tsne_position_data diff --git a/api/services/message_service.py b/api/services/message_service.py new file mode 100644 index 0000000000..b59fb0f10c --- /dev/null +++ b/api/services/message_service.py @@ -0,0 +1,212 @@ +from typing import Optional, Union, List + +from core.completion import Completion +from core.generator.llm_generator import LLMGenerator +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from extensions.ext_database import db +from models.account import Account +from models.model import App, EndUser, Message, MessageFeedback +from services.conversation_service import ConversationService +from services.errors.message import FirstMessageNotExistsError, MessageNotExistsError, LastMessageNotExistsError, \ + SuggestedQuestionsAfterAnswerDisabledError + + +class MessageService: + @classmethod + def pagination_by_first_id(cls, app_model: App, user: Optional[Union[Account | EndUser]], + conversation_id: str, first_id: Optional[str], limit: int) -> InfiniteScrollPagination: + if not user: + return InfiniteScrollPagination(data=[], limit=limit, has_more=False) + + if not conversation_id: + return InfiniteScrollPagination(data=[], limit=limit, has_more=False) + + conversation = ConversationService.get_conversation( + app_model=app_model, + user=user, + conversation_id=conversation_id + ) + + if first_id: + first_message = db.session.query(Message) \ + .filter(Message.conversation_id == conversation.id, Message.id == first_id).first() + + if not first_message: + raise FirstMessageNotExistsError() + + history_messages = db.session.query(Message).filter( + Message.conversation_id == conversation.id, + Message.created_at < first_message.created_at, + Message.id != first_message.id + ) \ + .order_by(Message.created_at.desc()).limit(limit).all() + else: + history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \ + .order_by(Message.created_at.desc()).limit(limit).all() + + has_more = False + if len(history_messages) == limit: + current_page_first_message = history_messages[-1] + rest_count = db.session.query(Message).filter( + Message.conversation_id == conversation.id, + Message.created_at < current_page_first_message.created_at, + Message.id != current_page_first_message.id + ).count() + + if rest_count > 0: + has_more = True + + history_messages = list(reversed(history_messages)) + + return InfiniteScrollPagination( + data=history_messages, + limit=limit, + has_more=has_more + ) + + @classmethod + def 
pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]], + last_id: Optional[str], limit: int, conversation_id: Optional[str] = None, + include_ids: Optional[list] = None) -> InfiniteScrollPagination: + if not user: + return InfiniteScrollPagination(data=[], limit=limit, has_more=False) + + base_query = db.session.query(Message) + + if conversation_id is not None: + conversation = ConversationService.get_conversation( + app_model=app_model, + user=user, + conversation_id=conversation_id + ) + + base_query = base_query.filter(Message.conversation_id == conversation.id) + + if include_ids is not None: + base_query = base_query.filter(Message.id.in_(include_ids)) + + if last_id: + last_message = base_query.filter(Message.id == last_id).first() + + if not last_message: + raise LastMessageNotExistsError() + + history_messages = base_query.filter( + Message.created_at < last_message.created_at, + Message.id != last_message.id + ).order_by(Message.created_at.desc()).limit(limit).all() + else: + history_messages = base_query.order_by(Message.created_at.desc()).limit(limit).all() + + has_more = False + if len(history_messages) == limit: + current_page_first_message = history_messages[-1] + rest_count = base_query.filter( + Message.created_at < current_page_first_message.created_at, + Message.id != current_page_first_message.id + ).count() + + if rest_count > 0: + has_more = True + + return InfiniteScrollPagination( + data=history_messages, + limit=limit, + has_more=has_more + ) + + @classmethod + def create_feedback(cls, app_model: App, message_id: str, user: Optional[Union[Account | EndUser]], + rating: Optional[str]) -> MessageFeedback: + if not user: + raise ValueError('user cannot be None') + + message = cls.get_message( + app_model=app_model, + user=user, + message_id=message_id + ) + + feedback = message.user_feedback + + if not rating and feedback: + db.session.delete(feedback) + elif rating and feedback: + feedback.rating = rating + elif not rating and not feedback: + raise ValueError('rating cannot be None when feedback not exists') + else: + feedback = MessageFeedback( + app_id=app_model.id, + conversation_id=message.conversation_id, + message_id=message.id, + rating=rating, + from_source=('user' if isinstance(user, EndUser) else 'admin'), + from_end_user_id=(user.id if isinstance(user, EndUser) else None), + from_account_id=(user.id if isinstance(user, Account) else None), + ) + db.session.add(feedback) + + db.session.commit() + + return feedback + + @classmethod + def get_message(cls, app_model: App, user: Optional[Union[Account | EndUser]], message_id: str): + message = db.session.query(Message).filter( + Message.id == message_id, + Message.app_id == app_model.id, + Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), + Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Message.from_account_id == (user.id if isinstance(user, Account) else None), + ).first() + + if not message: + raise MessageNotExistsError() + + return message + + @classmethod + def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Union[Account | EndUser]], + message_id: str, check_enabled: bool = True) -> List[Message]: + if not user: + raise ValueError('user cannot be None') + + app_model_config = app_model.app_model_config + suggested_questions_after_answer = app_model_config.suggested_questions_after_answer_dict + + if check_enabled and suggested_questions_after_answer.get("enabled", False) is False: + raise 
SuggestedQuestionsAfterAnswerDisabledError() + + message = cls.get_message( + app_model=app_model, + user=user, + message_id=message_id + ) + + conversation = ConversationService.get_conversation( + app_model=app_model, + conversation_id=message.conversation_id, + user=user + ) + + # get memory of conversation (read-only) + memory = Completion.get_memory_from_conversation( + tenant_id=app_model.tenant_id, + app_model_config=app_model.app_model_config, + conversation=conversation, + max_token_limit=3000, + message_limit=3, + return_messages=False, + memory_key="histories" + ) + + external_context = memory.load_memory_variables({}) + + questions = LLMGenerator.generate_suggested_questions_after_answer( + tenant_id=app_model.tenant_id, + **external_context + ) + + return questions + diff --git a/api/services/provider_service.py b/api/services/provider_service.py new file mode 100644 index 0000000000..7f6c7c9303 --- /dev/null +++ b/api/services/provider_service.py @@ -0,0 +1,96 @@ +from typing import Union + +from flask import current_app + +from core.llm.provider.llm_provider_service import LLMProviderService +from models.account import Tenant +from models.provider import * + + +class ProviderService: + + @staticmethod + def init_supported_provider(tenant, edition): + """Initialize the model provider, check whether the supported provider has a record""" + + providers = Provider.query.filter_by(tenant_id=tenant.id).all() + + openai_provider_exists = False + azure_openai_provider_exists = False + + # TODO: The cloud version needs to construct the data of the SYSTEM type + + for provider in providers: + if provider.provider_name == ProviderName.OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value: + openai_provider_exists = True + if provider.provider_name == ProviderName.AZURE_OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value: + azure_openai_provider_exists = True + + # Initialize the model provider, check whether the supported provider has a record + + # Create default providers if they don't exist + if not openai_provider_exists: + openai_provider = Provider( + tenant_id=tenant.id, + provider_name=ProviderName.OPENAI.value, + provider_type=ProviderType.CUSTOM.value, + is_valid=False + ) + db.session.add(openai_provider) + + if not azure_openai_provider_exists: + azure_openai_provider = Provider( + tenant_id=tenant.id, + provider_name=ProviderName.AZURE_OPENAI.value, + provider_type=ProviderType.CUSTOM.value, + is_valid=False + ) + db.session.add(azure_openai_provider) + + if not openai_provider_exists or not azure_openai_provider_exists: + db.session.commit() + + @staticmethod + def get_obfuscated_api_key(tenant, provider_name: ProviderName): + llm_provider_service = LLMProviderService(tenant.id, provider_name.value) + return llm_provider_service.get_provider_configs(obfuscated=True) + + @staticmethod + def get_token_type(tenant, provider_name: ProviderName): + llm_provider_service = LLMProviderService(tenant.id, provider_name.value) + return llm_provider_service.get_token_type() + + @staticmethod + def validate_provider_configs(tenant, provider_name: ProviderName, configs: Union[dict | str]): + llm_provider_service = LLMProviderService(tenant.id, provider_name.value) + return llm_provider_service.config_validate(configs) + + @staticmethod + def get_encrypted_token(tenant, provider_name: ProviderName, configs: Union[dict | str]): + llm_provider_service = LLMProviderService(tenant.id, provider_name.value) + return 
llm_provider_service.get_encrypted_token(configs) + + @staticmethod + def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, + is_valid: bool = True): + if current_app.config['EDITION'] != 'CLOUD': + return + + provider = db.session.query(Provider).filter( + Provider.tenant_id == tenant.id, + Provider.provider_name == provider_name, + Provider.provider_type == ProviderType.SYSTEM.value + ).one_or_none() + + if not provider: + provider = Provider( + tenant_id=tenant.id, + provider_name=provider_name, + provider_type=ProviderType.SYSTEM.value, + quota_type=ProviderQuotaType.TRIAL.value, + quota_limit=200, + encrypted_config='', + is_valid=is_valid, + ) + db.session.add(provider) + db.session.commit() diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py new file mode 100644 index 0000000000..1a68a1ba34 --- /dev/null +++ b/api/services/saved_message_service.py @@ -0,0 +1,66 @@ +from typing import Optional + +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from extensions.ext_database import db +from models.model import App, EndUser +from models.web import SavedMessage +from services.message_service import MessageService + + +class SavedMessageService: + @classmethod + def pagination_by_last_id(cls, app_model: App, end_user: Optional[EndUser], + last_id: Optional[str], limit: int) -> InfiniteScrollPagination: + saved_messages = db.session.query(SavedMessage).filter( + SavedMessage.app_id == app_model.id, + SavedMessage.created_by == end_user.id + ).order_by(SavedMessage.created_at.desc()).all() + message_ids = [sm.message_id for sm in saved_messages] + + return MessageService.pagination_by_last_id( + app_model=app_model, + user=end_user, + last_id=last_id, + limit=limit, + include_ids=message_ids + ) + + @classmethod + def save(cls, app_model: App, user: Optional[EndUser], message_id: str): + saved_message = db.session.query(SavedMessage).filter( + SavedMessage.app_id == app_model.id, + SavedMessage.message_id == message_id, + SavedMessage.created_by == user.id + ).first() + + if saved_message: + return + + message = MessageService.get_message( + app_model=app_model, + user=user, + message_id=message_id + ) + + saved_message = SavedMessage( + app_id=app_model.id, + message_id=message.id, + created_by=user.id + ) + + db.session.add(saved_message) + db.session.commit() + + @classmethod + def delete(cls, app_model: App, user: Optional[EndUser], message_id: str): + saved_message = db.session.query(SavedMessage).filter( + SavedMessage.app_id == app_model.id, + SavedMessage.message_id == message_id, + SavedMessage.created_by == user.id + ).first() + + if not saved_message: + return + + db.session.delete(saved_message) + db.session.commit() diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py new file mode 100644 index 0000000000..5cfab25006 --- /dev/null +++ b/api/services/web_conversation_service.py @@ -0,0 +1,74 @@ +from typing import Optional, Union + +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from extensions.ext_database import db +from models.model import App, EndUser +from models.web import PinnedConversation +from services.conversation_service import ConversationService + + +class WebConversationService: + @classmethod + def pagination_by_last_id(cls, app_model: App, end_user: Optional[EndUser], + last_id: Optional[str], limit: int, pinned: Optional[bool] = None) -> InfiniteScrollPagination: + include_ids = None + exclude_ids = 
None + if pinned is not None: + pinned_conversations = db.session.query(PinnedConversation).filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.created_by == end_user.id + ).order_by(PinnedConversation.created_at.desc()).all() + pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations] + if pinned: + include_ids = pinned_conversation_ids + else: + exclude_ids = pinned_conversation_ids + + return ConversationService.pagination_by_last_id( + app_model=app_model, + user=end_user, + last_id=last_id, + limit=limit, + include_ids=include_ids, + exclude_ids=exclude_ids + ) + + @classmethod + def pin(cls, app_model: App, conversation_id: str, user: Optional[EndUser]): + pinned_conversation = db.session.query(PinnedConversation).filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.conversation_id == conversation_id, + PinnedConversation.created_by == user.id + ).first() + + if pinned_conversation: + return + + conversation = ConversationService.get_conversation( + app_model=app_model, + conversation_id=conversation_id, + user=user + ) + + pinned_conversation = PinnedConversation( + app_id=app_model.id, + conversation_id=conversation.id, + created_by=user.id + ) + + db.session.add(pinned_conversation) + db.session.commit() + + @classmethod + def unpin(cls, app_model: App, conversation_id: str, user: Optional[EndUser]): + pinned_conversation = db.session.query(PinnedConversation).filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.conversation_id == conversation_id, + PinnedConversation.created_by == user.id + ).first() + + if not pinned_conversation: + return + + db.session.delete(pinned_conversation) + db.session.commit() diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py new file mode 100644 index 0000000000..92227ffa7a --- /dev/null +++ b/api/services/workspace_service.py @@ -0,0 +1,49 @@ +from extensions.ext_database import db +from models.account import Tenant +from models.provider import Provider, ProviderType + + +class WorkspaceService: + @classmethod + def get_tenant_info(cls, tenant: Tenant): + tenant_info = { + 'id': tenant.id, + 'name': tenant.name, + 'plan': tenant.plan, + 'status': tenant.status, + 'created_at': tenant.created_at, + 'providers': [], + 'in_trail': False, + 'trial_end_reason': 'using_custom' + } + + # Get providers + providers = db.session.query(Provider).filter( + Provider.tenant_id == tenant.id + ).all() + + # Add providers to the tenant info + tenant_info['providers'] = providers + + custom_provider = None + system_provider = None + + for provider in providers: + if provider.provider_type == ProviderType.CUSTOM.value: + if provider.is_valid and provider.encrypted_config: + custom_provider = provider + elif provider.provider_type == ProviderType.SYSTEM.value: + if provider.is_valid: + system_provider = provider + + if system_provider and not custom_provider: + quota_used = system_provider.quota_used if system_provider.quota_used is not None else 0 + quota_limit = system_provider.quota_limit if system_provider.quota_limit is not None else 0 + + if quota_used >= quota_limit: + tenant_info['trial_end_reason'] = 'trial_exceeded' + else: + tenant_info['in_trail'] = True + tenant_info['trial_end_reason'] = None + + return tenant_info diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py new file mode 100644 index 0000000000..9ea259227e --- /dev/null +++ b/api/tasks/add_document_to_index_task.py @@ -0,0 +1,99 @@ +import 
datetime +import logging +import time + +import click +from celery import shared_task +from llama_index.data_structs import Node +from llama_index.data_structs.node_v2 import DocumentRelationship +from werkzeug.exceptions import NotFound + +from core.index.keyword_table_index import KeywordTableIndex +from core.index.vector_index import VectorIndex +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import DocumentSegment, Document + + +@shared_task +def add_document_to_index_task(document_id: str): + """ + Async Add document to index + :param document_id: + + Usage: add_document_to_index.delay(document_id) + """ + logging.info(click.style('Start add document to index: {}'.format(document_id), fg='green')) + start_at = time.perf_counter() + + document = db.session.query(Document).filter(Document.id == document_id).first() + if not document: + raise NotFound('Document not found') + + if document.indexing_status != 'completed': + return + + indexing_cache_key = 'document_{}_indexing'.format(document.id) + + try: + segments = db.session.query(DocumentSegment).filter( + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == True + ) \ + .order_by(DocumentSegment.position.asc()).all() + + nodes = [] + previous_node = None + for segment in segments: + relationships = { + DocumentRelationship.SOURCE: document.id + } + + if previous_node: + relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id + + previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id + + node = Node( + doc_id=segment.index_node_id, + doc_hash=segment.index_node_hash, + text=segment.content, + extra_info=None, + node_info=None, + relationships=relationships + ) + + previous_node = node + + nodes.append(node) + + dataset = document.dataset + + if not dataset: + raise Exception('Document has no dataset') + + vector_index = VectorIndex(dataset=dataset) + keyword_table_index = KeywordTableIndex(dataset=dataset) + + # save vector index + if dataset.indexing_technique == "high_quality": + vector_index.add_nodes( + nodes=nodes, + duplicate_check=True + ) + + # save keyword index + keyword_table_index.add_nodes(nodes) + + end_at = time.perf_counter() + logging.info( + click.style('Document added to index: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) + except Exception as e: + logging.exception("add document to index failed") + document.enabled = False + document.disabled_at = datetime.datetime.utcnow() + document.status = 'error' + document.error = str(e) + db.session.commit() + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/add_segment_to_index_task.py b/api/tasks/add_segment_to_index_task.py new file mode 100644 index 0000000000..bd3cadfd3c --- /dev/null +++ b/api/tasks/add_segment_to_index_task.py @@ -0,0 +1,88 @@ +import datetime +import logging +import time + +import click +from celery import shared_task +from llama_index.data_structs import Node +from llama_index.data_structs.node_v2 import DocumentRelationship +from werkzeug.exceptions import NotFound + +from core.index.keyword_table_index import KeywordTableIndex +from core.index.vector_index import VectorIndex +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import DocumentSegment + + +@shared_task +def add_segment_to_index_task(segment_id: str): + """ + Async Add segment to index + :param segment_id: + + Usage: add_segment_to_index.delay(segment_id) + """ + 
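    # Hedged caller-side sketch (the enqueueing call sites are not shown in this
+    # diff): once a segment row is persisted and reaches status 'completed', a
+    # caller enqueues this task through Celery, e.g.
+    #
+    #     add_segment_to_index_task.delay(segment.id) +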
logging.info(click.style('Start add segment to index: {}'.format(segment_id), fg='green')) + start_at = time.perf_counter() + + segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() + if not segment: + raise NotFound('Segment not found') + + if segment.status != 'completed': + return + + indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + + try: + relationships = { + DocumentRelationship.SOURCE: segment.document_id, + } + + previous_segment = segment.previous_segment + if previous_segment: + relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id + + next_segment = segment.next_segment + if next_segment: + relationships[DocumentRelationship.NEXT] = next_segment.index_node_id + + node = Node( + doc_id=segment.index_node_id, + doc_hash=segment.index_node_hash, + text=segment.content, + extra_info=None, + node_info=None, + relationships=relationships + ) + + dataset = segment.dataset + + if not dataset: + raise Exception('Segment has no dataset') + + vector_index = VectorIndex(dataset=dataset) + keyword_table_index = KeywordTableIndex(dataset=dataset) + + # save vector index + if dataset.indexing_technique == "high_quality": + vector_index.add_nodes( + nodes=[node], + duplicate_check=True + ) + + # save keyword index + keyword_table_index.add_nodes([node]) + + end_at = time.perf_counter() + logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) + except Exception as e: + logging.exception("add segment to index failed") + segment.enabled = False + segment.disabled_at = datetime.datetime.utcnow() + segment.status = 'error' + segment.error = str(e) + db.session.commit() + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py new file mode 100644 index 0000000000..3c5ea8eb95 --- /dev/null +++ b/api/tasks/clean_dataset_task.py @@ -0,0 +1,77 @@ +import logging +import time + +import click +from celery import shared_task + +from core.index.keyword_table_index import KeywordTableIndex +from core.index.vector_index import VectorIndex +from extensions.ext_database import db +from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \ + AppDatasetJoin + + +@shared_task +def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, index_struct: str): + """ + Clean dataset when dataset deleted. 
+    :param dataset_id: dataset id
+    :param tenant_id: tenant id
+    :param indexing_technique: indexing technique
+    :param index_struct: index struct dict
+
+    Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct)
+    """
+    logging.info(click.style('Start clean dataset when dataset deleted: {}'.format(dataset_id), fg='green'))
+    start_at = time.perf_counter()
+
+    try:
+        dataset = Dataset(
+            id=dataset_id,
+            tenant_id=tenant_id,
+            indexing_technique=indexing_technique,
+            index_struct=index_struct
+        )
+
+        vector_index = VectorIndex(dataset=dataset)
+        keyword_table_index = KeywordTableIndex(dataset=dataset)
+
+        segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
+        # derive the parent document ids from the segments: the vector index is
+        # cleaned per document, the keyword index per index node
+        index_doc_ids = list({segment.document_id for segment in segments})
+        index_node_ids = [segment.index_node_id for segment in segments]
+
+        # delete from vector index
+        if dataset.indexing_technique == "high_quality":
+            for index_doc_id in index_doc_ids:
+                try:
+                    vector_index.del_doc(index_doc_id)
+                except Exception:
+                    logging.exception("Delete doc index failed when dataset deleted.")
+                    continue
+
+        # delete from keyword index
+        if index_node_ids:
+            try:
+                keyword_table_index.del_nodes(index_node_ids)
+            except Exception:
+                logging.exception("Delete nodes index failed when dataset deleted.")
+
+        for segment in segments:
+            db.session.delete(segment)
+
+        db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == dataset_id).delete()
+        db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete()
+        db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete()
+        db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete()
+
+        db.session.commit()
+
+        end_at = time.perf_counter()
+        logging.info(
+            click.style('Cleaned dataset when dataset deleted: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
+    except Exception:
+        logging.exception("clean dataset failed when dataset deleted")
diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py
new file mode 100644
index 0000000000..5ca7f2d5c2
--- /dev/null
+++ b/api/tasks/clean_document_task.py
@@ -0,0 +1,52 @@
+import logging
+import time
+
+import click
+from celery import shared_task
+
+from core.index.keyword_table_index import KeywordTableIndex
+from core.index.vector_index import VectorIndex
+from extensions.ext_database import db
+from models.dataset import DocumentSegment, Dataset
+
+
+@shared_task
+def clean_document_task(document_id: str, dataset_id: str):
+    """
+    Clean document when document deleted.
+ :param document_id: document id + :param dataset_id: dataset id + + Usage: clean_document_task.delay(document_id, dataset_id) + """ + logging.info(click.style('Start clean document when document deleted: {}'.format(document_id), fg='green')) + start_at = time.perf_counter() + + try: + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + + if not dataset: + raise Exception('Document has no dataset') + + vector_index = VectorIndex(dataset=dataset) + keyword_table_index = KeywordTableIndex(dataset=dataset) + + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() + index_node_ids = [segment.index_node_id for segment in segments] + + # delete from vector index + if dataset.indexing_technique == "high_quality": + vector_index.del_nodes(index_node_ids) + + # delete from keyword index + if index_node_ids: + keyword_table_index.del_nodes(index_node_ids) + + for segment in segments: + db.session.delete(segment) + + end_at = time.perf_counter() + logging.info( + click.style('Cleaned document when document deleted: {} latency: {}'.format(document_id, end_at - start_at), fg='green')) + except Exception: + logging.exception("Cleaned document when document deleted failed") diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py new file mode 100644 index 0000000000..59bbd4dc98 --- /dev/null +++ b/api/tasks/document_indexing_task.py @@ -0,0 +1,56 @@ +import datetime +import logging +import time + +import click +from celery import shared_task +from werkzeug.exceptions import NotFound + +from core.indexing_runner import IndexingRunner, DocumentIsPausedException +from core.llm.error import ProviderTokenNotInitError +from extensions.ext_database import db +from models.dataset import Document + + +@shared_task +def document_indexing_task(dataset_id: str, document_id: str): + """ + Async process document + :param dataset_id: + :param document_id: + + Usage: document_indexing_task.delay(dataset_id, document_id) + """ + logging.info(click.style('Start process document: {}'.format(document_id), fg='green')) + start_at = time.perf_counter() + + document = db.session.query(Document).filter( + Document.id == document_id, + Document.dataset_id == dataset_id + ).first() + + if not document: + raise NotFound('Document not found') + + document.indexing_status = 'parsing' + document.processing_started_at = datetime.datetime.utcnow() + db.session.commit() + + try: + indexing_runner = IndexingRunner() + indexing_runner.run(document) + end_at = time.perf_counter() + logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) + except DocumentIsPausedException: + logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow')) + except ProviderTokenNotInitError as e: + document.indexing_status = 'error' + document.error = str(e.description) + document.stopped_at = datetime.datetime.utcnow() + db.session.commit() + except Exception as e: + logging.exception("consume document failed") + document.indexing_status = 'error' + document.error = str(e) + document.stopped_at = datetime.datetime.utcnow() + db.session.commit() diff --git a/api/tasks/generate_conversation_summary_task.py b/api/tasks/generate_conversation_summary_task.py new file mode 100644 index 0000000000..b19576f6fc --- /dev/null +++ b/api/tasks/generate_conversation_summary_task.py @@ -0,0 +1,46 @@ +import logging +import time + +import click +from celery import shared_task 
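+# Hedged usage sketch (assumed caller, not part of this diff): the app layer can
+# enqueue this task after a new message is persisted, e.g.
+#     generate_conversation_summary_task.delay(conversation.id)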
+from werkzeug.exceptions import NotFound + +from core.generator.llm_generator import LLMGenerator +from extensions.ext_database import db +from models.model import Conversation, Message + + +@shared_task +def generate_conversation_summary_task(conversation_id: str): + """ + Async Generate conversation summary + :param conversation_id: + + Usage: generate_conversation_summary_task.delay(conversation_id) + """ + logging.info(click.style('Start generate conversation summary: {}'.format(conversation_id), fg='green')) + start_at = time.perf_counter() + + conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() + if not conversation: + raise NotFound('Conversation not found') + + try: + # get conversation messages count + history_message_count = conversation.message_count + if history_message_count >= 5: + app_model = conversation.app + if not app_model: + return + + history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \ + .order_by(Message.created_at.asc()).all() + + conversation.summary = LLMGenerator.generate_conversation_summary(app_model.tenant_id, history_messages) + db.session.add(conversation) + db.session.commit() + + end_at = time.perf_counter() + logging.info(click.style('Conversation summary generated: {} latency: {}'.format(conversation_id, end_at - start_at), fg='green')) + except Exception: + logging.exception("generate conversation summary failed") diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py new file mode 100644 index 0000000000..c1a5d4336c --- /dev/null +++ b/api/tasks/recover_document_indexing_task.py @@ -0,0 +1,51 @@ +import datetime +import logging +import time + +import click +from celery import shared_task +from werkzeug.exceptions import NotFound + +from core.indexing_runner import IndexingRunner, DocumentIsPausedException +from extensions.ext_database import db +from models.dataset import Document + + +@shared_task +def recover_document_indexing_task(dataset_id: str, document_id: str): + """ + Async recover document + :param dataset_id: + :param document_id: + + Usage: recover_document_indexing_task.delay(dataset_id, document_id) + """ + logging.info(click.style('Recover document: {}'.format(document_id), fg='green')) + start_at = time.perf_counter() + + document = db.session.query(Document).filter( + Document.id == document_id, + Document.dataset_id == dataset_id + ).first() + + if not document: + raise NotFound('Document not found') + + try: + indexing_runner = IndexingRunner() + if document.indexing_status in ["waiting", "parsing", "cleaning"]: + indexing_runner.run(document) + elif document.indexing_status == "splitting": + indexing_runner.run_in_splitting_status(document) + elif document.indexing_status == "indexing": + indexing_runner.run_in_indexing_status(document) + end_at = time.perf_counter() + logging.info(click.style('Processed document: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) + except DocumentIsPausedException: + logging.info(click.style('Document paused, document id: {}'.format(document.id), fg='yellow')) + except Exception as e: + logging.exception("consume document failed") + document.indexing_status = 'error' + document.error = str(e) + document.stopped_at = datetime.datetime.utcnow() + db.session.commit() diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py new file mode 100644 index 0000000000..3dc6f9cd77 --- /dev/null +++ 
b/api/tasks/remove_document_from_index_task.py @@ -0,0 +1,63 @@ +import logging +import time + +import click +from celery import shared_task +from werkzeug.exceptions import NotFound + +from core.index.keyword_table_index import KeywordTableIndex +from core.index.vector_index import VectorIndex +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import DocumentSegment, Document + + +@shared_task +def remove_document_from_index_task(document_id: str): + """ + Async Remove document from index + :param document_id: document id + + Usage: remove_document_from_index.delay(document_id) + """ + logging.info(click.style('Start remove document segments from index: {}'.format(document_id), fg='green')) + start_at = time.perf_counter() + + document = db.session.query(Document).filter(Document.id == document_id).first() + if not document: + raise NotFound('Document not found') + + if document.indexing_status != 'completed': + return + + indexing_cache_key = 'document_{}_indexing'.format(document.id) + + try: + dataset = document.dataset + + if not dataset: + raise Exception('Document has no dataset') + + vector_index = VectorIndex(dataset=dataset) + keyword_table_index = KeywordTableIndex(dataset=dataset) + + # delete from vector index + if dataset.indexing_technique == "high_quality": + vector_index.del_doc(document.id) + + # delete from keyword index + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all() + index_node_ids = [segment.index_node_id for segment in segments] + if index_node_ids: + keyword_table_index.del_nodes(index_node_ids) + + end_at = time.perf_counter() + logging.info( + click.style('Document removed from index: {} latency: {}'.format(document.id, end_at - start_at), fg='green')) + except Exception: + logging.exception("remove document from index failed") + if not document.archived: + document.enabled = True + db.session.commit() + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/remove_segment_from_index_task.py b/api/tasks/remove_segment_from_index_task.py new file mode 100644 index 0000000000..48cebfc4d1 --- /dev/null +++ b/api/tasks/remove_segment_from_index_task.py @@ -0,0 +1,58 @@ +import logging +import time + +import click +from celery import shared_task +from werkzeug.exceptions import NotFound + +from core.index.keyword_table_index import KeywordTableIndex +from core.index.vector_index import VectorIndex +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import DocumentSegment + + +@shared_task +def remove_segment_from_index_task(segment_id: str): + """ + Async Remove segment from index + :param segment_id: + + Usage: remove_segment_from_index.delay(segment_id) + """ + logging.info(click.style('Start remove segment from index: {}'.format(segment_id), fg='green')) + start_at = time.perf_counter() + + segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first() + if not segment: + raise NotFound('Segment not found') + + if segment.status != 'completed': + return + + indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + + try: + dataset = segment.dataset + + if not dataset: + raise Exception('Segment has no dataset') + + vector_index = VectorIndex(dataset=dataset) + keyword_table_index = KeywordTableIndex(dataset=dataset) + + # delete from vector index + if dataset.indexing_technique == "high_quality": + vector_index.del_nodes([segment.index_node_id]) 
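+            # Note the asymmetry with remove_document_from_index_task above:
+            # removing a whole document uses vector_index.del_doc(document.id),
+            # while a single segment is removed by its index node id via del_nodes.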
+ + # delete from keyword index + keyword_table_index.del_nodes([segment.index_node_id]) + + end_at = time.perf_counter() + logging.info(click.style('Segment removed from index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green')) + except Exception: + logging.exception("remove segment from index failed") + segment.enabled = True + db.session.commit() + finally: + redis_client.delete(indexing_cache_key) diff --git a/api/tests/__init__.py b/api/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/conftest.py b/api/tests/conftest.py new file mode 100644 index 0000000000..48de037846 --- /dev/null +++ b/api/tests/conftest.py @@ -0,0 +1,50 @@ +# -*- coding:utf-8 -*- + +import pytest +import flask_migrate + +from app import create_app +from extensions.ext_database import db + + +@pytest.fixture(scope='module') +def test_client(): + # Create a Flask app configured for testing + from config import TestConfig + flask_app = create_app(TestConfig()) + flask_app.config.from_object('config.TestingConfig') + + # Create a test client using the Flask application configured for testing + with flask_app.test_client() as testing_client: + # Establish an application context + with flask_app.app_context(): + yield testing_client # this is where the testing happens! + + +@pytest.fixture(scope='module') +def init_database(test_client): + # Initialize the database + with test_client.application.app_context(): + flask_migrate.upgrade() + + yield # this is where the testing happens! + + # Clean up the database + with test_client.application.app_context(): + flask_migrate.downgrade() + + +@pytest.fixture(scope='module') +def db_session(test_client): + with test_client.application.app_context(): + yield db.session + + +@pytest.fixture(scope='function') +def login_default_user(test_client): + + # todo + + yield # this is where the testing happens! 
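+    # Everything after the yield is pytest fixture teardown: it runs once the
+    # test using this fixture has finished, so the logout below undoes the
+    # (to-do) login performed above.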
+ + test_client.get('/logout', follow_redirects=True) \ No newline at end of file diff --git a/api/tests/test_controllers/__init__.py b/api/tests/test_controllers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_controllers/test_account_api.py.bak b/api/tests/test_controllers/test_account_api.py.bak new file mode 100644 index 0000000000..a73c796b78 --- /dev/null +++ b/api/tests/test_controllers/test_account_api.py.bak @@ -0,0 +1,75 @@ +import json +import pytest +from flask import url_for + +from models.model import Account + +# Sample user data for testing +sample_user_data = { + 'name': 'Test User', + 'email': 'test@example.com', + 'interface_language': 'en-US', + 'interface_theme': 'light', + 'timezone': 'America/New_York', + 'password': 'testpassword', + 'new_password': 'newtestpassword', + 'repeat_new_password': 'newtestpassword' +} + +# Create a test user and log them in +@pytest.fixture(scope='function') +def logged_in_user(client, session): + # Create test user and add them to the database + # Replace this with your actual User model and any required fields + + # todo refer to api.controllers.setup.SetupApi.post() to create a user + db_user_data = sample_user_data.copy() + db_user_data['password_salt'] = 'testpasswordsalt' + del db_user_data['new_password'] + del db_user_data['repeat_new_password'] + test_user = Account(**db_user_data) + session.add(test_user) + session.commit() + + # Log in the test user + client.post(url_for('console.loginapi'), data={'email': sample_user_data['email'], 'password': sample_user_data['password']}) + + return test_user + +def test_account_profile(logged_in_user, client): + response = client.get(url_for('console.accountprofileapi')) + assert response.status_code == 200 + assert json.loads(response.data)['name'] == sample_user_data['name'] + +def test_account_name(logged_in_user, client): + new_name = 'New Test User' + response = client.post(url_for('console.accountnameapi'), json={'name': new_name}) + assert response.status_code == 200 + assert json.loads(response.data)['name'] == new_name + +def test_account_interface_language(logged_in_user, client): + new_language = 'zh-CN' + response = client.post(url_for('console.accountinterfacelanguageapi'), json={'interface_language': new_language}) + assert response.status_code == 200 + assert json.loads(response.data)['interface_language'] == new_language + +def test_account_interface_theme(logged_in_user, client): + new_theme = 'dark' + response = client.post(url_for('console.accountinterfacethemeapi'), json={'interface_theme': new_theme}) + assert response.status_code == 200 + assert json.loads(response.data)['interface_theme'] == new_theme + +def test_account_timezone(logged_in_user, client): + new_timezone = 'Asia/Shanghai' + response = client.post(url_for('console.accounttimezoneapi'), json={'timezone': new_timezone}) + assert response.status_code == 200 + assert json.loads(response.data)['timezone'] == new_timezone + +def test_account_password(logged_in_user, client): + response = client.post(url_for('console.accountpasswordapi'), json={ + 'password': sample_user_data['password'], + 'new_password': sample_user_data['new_password'], + 'repeat_new_password': sample_user_data['repeat_new_password'] + }) + assert response.status_code == 200 + assert json.loads(response.data)['result'] == 'success' diff --git a/api/tests/test_controllers/test_login.py b/api/tests/test_controllers/test_login.py new file mode 100644 index 0000000000..559e2f809e --- /dev/null +++ 
b/api/tests/test_controllers/test_login.py @@ -0,0 +1,108 @@ +import pytest +from app import create_app, db +from flask_login import current_user +from models.model import Account, TenantAccountJoin, Tenant + + +@pytest.fixture +def client(test_client, db_session): + app = create_app() + app.config["TESTING"] = True + with app.app_context(): + db.create_all() + yield test_client + db.drop_all() + + +def test_login_api_post(client, db_session): + # create a tenant, account, and tenant account join + tenant = Tenant(name="Test Tenant", status="normal") + account = Account(email="test@test.com", name="Test User") + account.password_salt = "uQ7K0/0wUJ7VPhf3qBzwNQ==" + account.password = "A9YpfzjK7c/tOwzamrvpJg==" + db.session.add_all([tenant, account]) + db.session.flush() + tenant_account_join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, is_tenant_owner=True) + db.session.add(tenant_account_join) + db.session.commit() + + # login with correct credentials + response = client.post("/login", json={ + "email": "test@test.com", + "password": "Abc123456", + "remember_me": True + }) + assert response.status_code == 200 + assert response.json == {"result": "success"} + assert current_user == account + assert 'tenant_id' in client.session + assert client.session['tenant_id'] == tenant.id + + # login with incorrect password + response = client.post("/login", json={ + "email": "test@test.com", + "password": "wrong_password", + "remember_me": True + }) + assert response.status_code == 401 + + # login with non-existent account + response = client.post("/login", json={ + "email": "non_existent_account@test.com", + "password": "Abc123456", + "remember_me": True + }) + assert response.status_code == 401 + + +def test_logout_api_get(client, db_session): + # create a tenant, account, and tenant account join + tenant = Tenant(name="Test Tenant", status="normal") + account = Account(email="test@test.com", name="Test User") + db.session.add_all([tenant, account]) + db.session.flush() + tenant_account_join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, is_tenant_owner=True) + db.session.add(tenant_account_join) + db.session.commit() + + # login and check if session variable and current_user are set + with client.session_transaction() as session: + session['tenant_id'] = tenant.id + client.post("/login", json={ + "email": "test@test.com", + "password": "Abc123456", + "remember_me": True + }) + assert current_user == account + assert 'tenant_id' in client.session + assert client.session['tenant_id'] == tenant.id + + # logout and check if session variable and current_user are unset + response = client.get("/logout") + assert response.status_code == 200 + assert current_user.is_authenticated is False + assert 'tenant_id' not in client.session + + +def test_reset_password_api_get(client, db_session): + # create a tenant, account, and tenant account join + tenant = Tenant(name="Test Tenant", status="normal") + account = Account(email="test@test.com", name="Test User") + db.session.add_all([tenant, account]) + db.session.flush() + tenant_account_join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, is_tenant_owner=True) + db.session.add(tenant_account_join) + db.session.commit() + + # reset password in cloud edition + app = client.application + app.config["CLOUD_EDITION"] = True + response = client.get("/reset_password") + assert response.status_code == 200 + assert response.json == {"result": "success"} + + # reset password in non-cloud edition + 
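+    # Editor's note: both branches assert the same success envelope; the flag is
+    # flipped on the live app config and nothing edition-specific is inspected here.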
app.config["CLOUD_EDITION"] = False + response = client.get("/reset_password") + assert response.status_code == 200 + assert response.json == {"result": "success"} diff --git a/api/tests/test_controllers/test_setup.py b/api/tests/test_controllers/test_setup.py new file mode 100644 index 0000000000..96a9b0911e --- /dev/null +++ b/api/tests/test_controllers/test_setup.py @@ -0,0 +1,80 @@ +import os +import pytest +from models.model import Account, Tenant, TenantAccountJoin + + +def test_setup_api_get(test_client,db_session): + response = test_client.get("/setup") + assert response.status_code == 200 + assert response.json == {"step": "not_start"} + + # create a tenant and check again + tenant = Tenant(name="Test Tenant", status="normal") + db_session.add(tenant) + db_session.commit() + response = test_client.get("/setup") + assert response.status_code == 200 + assert response.json == {"step": "step2"} + + # create setup file and check again + response = test_client.get("/setup") + assert response.status_code == 200 + assert response.json == {"step": "finished"} + + +def test_setup_api_post(test_client): + response = test_client.post("/setup", json={ + "email": "test@test.com", + "name": "Test User", + "password": "Abc123456" + }) + assert response.status_code == 200 + assert response.json == {"result": "success", "next_step": "step2"} + + # check if the tenant, account, and tenant account join records were created + tenant = Tenant.query.first() + assert tenant.name == "Test User's LLM Factory" + assert tenant.status == "normal" + assert tenant.encrypt_public_key + + account = Account.query.first() + assert account.email == "test@test.com" + assert account.name == "Test User" + assert account.password_salt + assert account.password + assert TenantAccountJoin.query.filter_by(account_id=account.id, is_tenant_owner=True).count() == 1 + + # check if password is encrypted correctly + salt = account.password_salt.encode() + password_hashed = account.password.encode() + assert account.password == base64.b64encode(hash_password("Abc123456", salt)).decode() + + +def test_setup_step2_api_post(test_client,db_session): + # create a tenant, account, and setup file + tenant = Tenant(name="Test Tenant", status="normal") + account = Account(email="test@test.com", name="Test User") + db_session.add_all([tenant, account]) + db_session.commit() + + # try to set up with incorrect language + response = test_client.post("/setup/step2", json={ + "interface_language": "invalid_language", + "timezone": "Asia/Shanghai" + }) + assert response.status_code == 400 + + # set up successfully + response = test_client.post("/setup/step2", json={ + "interface_language": "en", + "timezone": "Asia/Shanghai" + }) + assert response.status_code == 200 + assert response.json == {"result": "success", "next_step": "finished"} + + # check if account was updated correctly + account = Account.query.first() + assert account.interface_language == "en" + assert account.timezone == "Asia/Shanghai" + assert account.interface_theme == "light" + assert account.last_login_ip == "127.0.0.1" diff --git a/api/tests/test_factory.py b/api/tests/test_factory.py new file mode 100644 index 0000000000..0d73168b43 --- /dev/null +++ b/api/tests/test_factory.py @@ -0,0 +1,22 @@ +# -*- coding:utf-8 -*- + +import pytest + +from app import create_app + +def test_create_app(): + + # Test Default(CE) Config + app = create_app() + + assert app.config['SECRET_KEY'] is not None + assert app.config['SQLALCHEMY_DATABASE_URI'] is not None + assert 
app.config['EDITION'] == "SELF_HOSTED" + + # Test TestConfig + from config import TestConfig + test_app = create_app(TestConfig()) + + assert test_app.config['SECRET_KEY'] is not None + assert test_app.config['SQLALCHEMY_DATABASE_URI'] is not None + assert test_app.config['TESTING'] is True \ No newline at end of file diff --git a/api/tests/test_helpers/__init__.py b/api/tests/test_helpers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_libs/__init__.py b/api/tests/test_libs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_models/__init__.py b/api/tests/test_models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/test_services/__init__.py b/api/tests/test_services/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docker/Dockerfile.base b/docker/Dockerfile.base new file mode 100644 index 0000000000..ec0e930f17 --- /dev/null +++ b/docker/Dockerfile.base @@ -0,0 +1,256 @@ +FROM nginx:1.22 + +# ensure local python is preferred over distribution python +ENV PATH /usr/local/bin:$PATH + +# http://bugs.python.org/issue19846 +# > At the moment, setting "LANG=C" on a Linux system *fundamentally breaks Python 3*, and that's not OK. +ENV LANG C.UTF-8 + +# runtime dependencies +RUN set -eux; \ + apt-get update; \ + apt-get install -y --no-install-recommends \ + ca-certificates \ + netbase \ + tzdata \ + ; \ + rm -rf /var/lib/apt/lists/* + +ENV GPG_KEY A035C8C19219BA821ECEA86B64E628F8D684696D +ENV PYTHON_VERSION 3.10.10 + +RUN set -eux; \ + \ + savedAptMark="$(apt-mark showmanual)"; \ + apt-get update; \ + apt-get install -y --no-install-recommends \ + dpkg-dev \ + gcc \ + gnupg dirmngr \ + libbluetooth-dev \ + libbz2-dev \ + libc6-dev \ + libdb-dev \ + libexpat1-dev \ + libffi-dev \ + libgdbm-dev \ + liblzma-dev \ + libncursesw5-dev \ + libreadline-dev \ + libsqlite3-dev \ + libssl-dev \ + make \ + tk-dev \ + uuid-dev \ + wget \ + xz-utils \ + zlib1g-dev \ + ; \ + \ + wget -O python.tar.xz "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz"; \ + wget -O python.tar.xz.asc "https://www.python.org/ftp/python/${PYTHON_VERSION%%[a-z]*}/Python-$PYTHON_VERSION.tar.xz.asc"; \ + GNUPGHOME="$(mktemp -d)"; export GNUPGHOME; \ + gpg --batch --keyserver hkps://keys.openpgp.org --recv-keys "$GPG_KEY"; \ + gpg --batch --verify python.tar.xz.asc python.tar.xz; \ + command -v gpgconf > /dev/null && gpgconf --kill all || :; \ + rm -rf "$GNUPGHOME" python.tar.xz.asc; \ + mkdir -p /usr/src/python; \ + tar --extract --directory /usr/src/python --strip-components=1 --file python.tar.xz; \ + rm python.tar.xz; \ + \ + cd /usr/src/python; \ + gnuArch="$(dpkg-architecture --query DEB_BUILD_GNU_TYPE)"; \ + ./configure \ + --build="$gnuArch" \ + --enable-loadable-sqlite-extensions \ + --enable-optimizations \ + --enable-option-checking=fatal \ + --enable-shared \ + --with-lto \ + --with-system-expat \ + --without-ensurepip \ + ; \ + nproc="$(nproc)"; \ + EXTRA_CFLAGS="$(dpkg-buildflags --get CFLAGS)"; \ + LDFLAGS="$(dpkg-buildflags --get LDFLAGS)"; \ + LDFLAGS="${LDFLAGS:--Wl},--strip-all"; \ + make -j "$nproc" \ + "EXTRA_CFLAGS=${EXTRA_CFLAGS:-}" \ + "LDFLAGS=${LDFLAGS:-}" \ + "PROFILE_TASK=${PROFILE_TASK:-}" \ + ; \ +# https://github.com/docker-library/python/issues/784 +# prevent accidental usage of a system installed libpython of the same version + rm python; \ + make -j "$nproc" \ + "EXTRA_CFLAGS=${EXTRA_CFLAGS:-}" \ + 
"LDFLAGS=${LDFLAGS:--Wl},-rpath='\$\$ORIGIN/../lib'" \ + "PROFILE_TASK=${PROFILE_TASK:-}" \ + python \ + ; \ + make install; \ + \ + cd /; \ + rm -rf /usr/src/python; \ + \ + find /usr/local -depth \ + \( \ + \( -type d -a \( -name test -o -name tests -o -name idle_test \) \) \ + -o \( -type f -a \( -name '*.pyc' -o -name '*.pyo' -o -name 'libpython*.a' \) \) \ + \) -exec rm -rf '{}' + \ + ; \ + \ + ldconfig; \ + \ + apt-mark auto '.*' > /dev/null; \ + apt-mark manual $savedAptMark; \ + find /usr/local -type f -executable -not \( -name '*tkinter*' \) -exec ldd '{}' ';' \ + | awk '/=>/ { print $(NF-1) }' \ + | sort -u \ + | xargs -r dpkg-query --search \ + | cut -d: -f1 \ + | sort -u \ + | xargs -r apt-mark manual \ + ; \ + rm -rf /var/lib/apt/lists/*; \ + \ + python3 --version + +# make some useful symlinks that are expected to exist ("/usr/local/bin/python" and friends) +RUN set -eux; \ + for src in idle3 pydoc3 python3 python3-config; do \ + dst="$(echo "$src" | tr -d 3)"; \ + [ -s "/usr/local/bin/$src" ]; \ + [ ! -e "/usr/local/bin/$dst" ]; \ + ln -svT "$src" "/usr/local/bin/$dst"; \ + done + +# if this is called "PIP_VERSION", pip explodes with "ValueError: invalid truth value ''" +ENV PYTHON_PIP_VERSION 22.3.1 +# https://github.com/docker-library/python/issues/365 +ENV PYTHON_SETUPTOOLS_VERSION 65.5.1 +# https://github.com/pypa/get-pip +ENV PYTHON_GET_PIP_URL https://github.com/pypa/get-pip/raw/d5cb0afaf23b8520f1bbcfed521017b4a95f5c01/public/get-pip.py +ENV PYTHON_GET_PIP_SHA256 394be00f13fa1b9aaa47e911bdb59a09c3b2986472130f30aa0bfaf7f3980637 + +RUN set -eux; \ + \ + savedAptMark="$(apt-mark showmanual)"; \ + apt-get update; \ + apt-get install -y --no-install-recommends wget; \ + \ + wget -O get-pip.py "$PYTHON_GET_PIP_URL"; \ + echo "$PYTHON_GET_PIP_SHA256 *get-pip.py" | sha256sum -c -; \ + \ + apt-mark auto '.*' > /dev/null; \ + [ -z "$savedAptMark" ] || apt-mark manual $savedAptMark > /dev/null; \ + rm -rf /var/lib/apt/lists/*; \ + \ + export PYTHONDONTWRITEBYTECODE=1; \ + \ + python get-pip.py \ + --disable-pip-version-check \ + --no-cache-dir \ + --no-compile \ + "pip==$PYTHON_PIP_VERSION" \ + "setuptools==$PYTHON_SETUPTOOLS_VERSION" \ + ; \ + rm -f get-pip.py; \ + \ + pip --version + +RUN groupadd --gid 1000 node \ + && useradd --uid 1000 --gid node --shell /bin/bash --create-home node + +ENV NODE_VERSION 18.15.0 + +RUN ARCH= && dpkgArch="$(dpkg --print-architecture)" \ + && case "${dpkgArch##*-}" in \ + amd64) ARCH='x64';; \ + ppc64el) ARCH='ppc64le';; \ + s390x) ARCH='s390x';; \ + arm64) ARCH='arm64';; \ + armhf) ARCH='armv7l';; \ + i386) ARCH='x86';; \ + *) echo "unsupported architecture"; exit 1 ;; \ + esac \ + && set -ex \ + # libatomic1 for arm + && apt-get update && apt-get install -y ca-certificates curl wget gnupg dirmngr xz-utils libatomic1 --no-install-recommends \ + && rm -rf /var/lib/apt/lists/* \ + && for key in \ + 4ED778F539E3634C779C87C6D7062848A1AB005C \ + 141F07595B7B3FFE74309A937405533BE57C7D57 \ + 74F12602B6F1C4E913FAA37AD3A89613643B6201 \ + DD792F5973C6DE52C432CBDAC77ABFA00DDBF2B7 \ + 61FC681DFB92A079F1685E77973F295594EC4689 \ + 8FCCA13FEF1D0C2E91008E09770F7A9A5AE15600 \ + C4F0DFFF4E8C1A8236409D08E73BC641CC11F4C8 \ + 890C08DB8579162FEE0DF9DB8BEAB4DFCF555EF4 \ + C82FA3AE1CBEDC6BE46B9360C43CEC45C17AB93C \ + 108F52B48DB57BB0CC439B2997B01419BD92F80A \ + ; do \ + gpg --batch --keyserver hkps://keys.openpgp.org --recv-keys "$key" || \ + gpg --batch --keyserver keyserver.ubuntu.com --recv-keys "$key" ; \ + done \ + && curl -fsSLO --compressed 
"https://nodejs.org/dist/v$NODE_VERSION/node-v$NODE_VERSION-linux-$ARCH.tar.xz" \ + && curl -fsSLO --compressed "https://nodejs.org/dist/v$NODE_VERSION/SHASUMS256.txt.asc" \ + && gpg --batch --decrypt --output SHASUMS256.txt SHASUMS256.txt.asc \ + && grep " node-v$NODE_VERSION-linux-$ARCH.tar.xz\$" SHASUMS256.txt | sha256sum -c - \ + && tar -xJf "node-v$NODE_VERSION-linux-$ARCH.tar.xz" -C /usr/local --strip-components=1 --no-same-owner \ + && rm "node-v$NODE_VERSION-linux-$ARCH.tar.xz" SHASUMS256.txt.asc SHASUMS256.txt \ + && apt-mark auto '.*' > /dev/null \ + && find /usr/local -type f -executable -exec ldd '{}' ';' \ + | awk '/=>/ { print $(NF-1) }' \ + | sort -u \ + | xargs -r dpkg-query --search \ + | cut -d: -f1 \ + | sort -u \ + | xargs -r apt-mark manual \ + && ln -s /usr/local/bin/node /usr/local/bin/nodejs \ + # smoke tests + && node --version \ + && npm --version + +ENV YARN_VERSION 1.22.19 + +RUN set -ex \ + && savedAptMark="$(apt-mark showmanual)" \ + && apt-get update && apt-get install -y ca-certificates curl wget gnupg dirmngr --no-install-recommends \ + && rm -rf /var/lib/apt/lists/* \ + && for key in \ + 6A010C5166006599AA17F08146C2130DFD2497F5 \ + ; do \ + gpg --batch --keyserver hkps://keys.openpgp.org --recv-keys "$key" || \ + gpg --batch --keyserver keyserver.ubuntu.com --recv-keys "$key" ; \ + done \ + && curl -fsSLO --compressed "https://yarnpkg.com/downloads/$YARN_VERSION/yarn-v$YARN_VERSION.tar.gz" \ + && curl -fsSLO --compressed "https://yarnpkg.com/downloads/$YARN_VERSION/yarn-v$YARN_VERSION.tar.gz.asc" \ + && gpg --batch --verify yarn-v$YARN_VERSION.tar.gz.asc yarn-v$YARN_VERSION.tar.gz \ + && mkdir -p /opt \ + && tar -xzf yarn-v$YARN_VERSION.tar.gz -C /opt/ \ + && ln -s /opt/yarn-v$YARN_VERSION/bin/yarn /usr/local/bin/yarn \ + && ln -s /opt/yarn-v$YARN_VERSION/bin/yarnpkg /usr/local/bin/yarnpkg \ + && rm yarn-v$YARN_VERSION.tar.gz.asc yarn-v$YARN_VERSION.tar.gz \ + && apt-mark auto '.*' > /dev/null \ + && { [ -z "$savedAptMark" ] || apt-mark manual $savedAptMark > /dev/null; } \ + && find /usr/local -type f -executable -exec ldd '{}' ';' \ + | awk '/=>/ { print $(NF-1) }' \ + | sort -u \ + | xargs -r dpkg-query --search \ + | cut -d: -f1 \ + | sort -u \ + | xargs -r apt-mark manual \ + # smoke test + && yarn --version + + + +RUN apt-get update && \ + apt-get install -y bash curl wget vim gcc g++ python3-dev libc-dev libffi-dev + +RUN pip3 install gunicorn +RUN npm install pm2 -g + +ENTRYPOINT ["/usr/local/bin/pm2-runtime", "start"] \ No newline at end of file diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml new file mode 100644 index 0000000000..09372ca2a3 --- /dev/null +++ b/docker/docker-compose.middleware.yaml @@ -0,0 +1,53 @@ +version: '3.1' +services: + # The postgres database. + db: + image: postgres:15-alpine + restart: always + environment: + # The password for the default postgres user. + POSTGRES_PASSWORD: difyai123456 + # The name of the default postgres database. + POSTGRES_DB: dify + # postgres data directory + PGDATA: /var/lib/postgresql/data/pgdata + volumes: + - ./volumes/db/data:/var/lib/postgresql/data + - ./volumes/db/scripts:/docker-entrypoint-initdb.d/ + ports: + - "5432:5432" + + # The redis cache. + redis: + image: redis:6-alpine + restart: always + volumes: + # Mount the redis data directory to the container. + - ./volumes/redis/data:/data + # Set the redis password when startup redis server. 
+ command: redis-server --requirepass difyai123456 + ports: + - "6379:6379" + + # The Weaviate vector store. + weaviate: + image: semitechnologies/weaviate:1.18.4 + restart: always + volumes: + # Mount the Weaviate data directory to the container. + - ./volumes/weaviate:/var/lib/weaviate + environment: + # The Weaviate configurations + # You can refer to the [Weaviate](https://weaviate.io/developers/weaviate/config-refs/env-vars) documentation for more information. + QUERY_DEFAULTS_LIMIT: 25 + AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'false' + PERSISTENCE_DATA_PATH: '/var/lib/weaviate' + DEFAULT_VECTORIZER_MODULE: 'none' + CLUSTER_HOSTNAME: 'node1' + AUTHENTICATION_APIKEY_ENABLED: 'true' + AUTHENTICATION_APIKEY_ALLOWED_KEYS: 'WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih' + AUTHENTICATION_APIKEY_USERS: 'hello@dify.ai' + AUTHORIZATION_ADMINLIST_ENABLED: 'true' + AUTHORIZATION_ADMINLIST_USERS: 'hello@dify.ai' + ports: + - "8080:8080" \ No newline at end of file diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml new file mode 100644 index 0000000000..7d1542fbca --- /dev/null +++ b/docker/docker-compose.yaml @@ -0,0 +1,213 @@ +version: '3.1' +services: + # API service + api: + image: langgenius/dify-api:latest + restart: always + environment: + # Startup mode, 'api' starts the API server. + MODE: api + # The log level for the application. Supported values are `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL` + LOG_LEVEL: INFO + # A secret key that is used for securely signing the session cookie and encrypting sensitive information on the database. You can generate a strong key using `openssl rand -base64 42`. + SECRET_KEY: sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U + # The base URL of console application, refers to the Console base URL of WEB service. + CONSOLE_URL: http://localhost + # The URL for Service API endpoints,refers to the base URL of the current API service. + API_URL: http://localhost + # The URL for Web APP, refers to the Web App base URL of WEB service. + APP_URL: http://localhost + # When enabled, migrations will be executed prior to application startup and the application will start after the migrations have completed. + MIGRATION_ENABLED: 'true' + # The configurations of postgres database connection. + # It is consistent with the configuration in the 'db' service below. + DB_USERNAME: postgres + DB_PASSWORD: difyai123456 + DB_HOST: db + DB_PORT: 5432 + DB_DATABASE: dify + # The configurations of redis connection. + # It is consistent with the configuration in the 'redis' service below. + REDIS_HOST: redis + REDIS_PORT: 6379 + REDIS_PASSWORD: difyai123456 + # use redis db 0 for redis cache + REDIS_DB: 0 + # The configurations of session, Supported values are `sqlalchemy`. `redis` + SESSION_TYPE: redis + SESSION_REDIS_HOST: redis + SESSION_REDIS_PORT: 6379 + SESSION_REDIS_PASSWORD: difyai123456 + # use redis db 2 for session store + SESSION_REDIS_DB: 2 + # The configurations of celery broker. + # Use redis as the broker, and redis db 1 for celery broker. 
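+      # Editor's note: the URL shape is redis://:<password>@<host>:<port>/<db>;
+      # together with REDIS_DB 0 (cache) and SESSION_REDIS_DB 2 (sessions) above,
+      # each concern gets its own logical database on the single redis service.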
+ CELERY_BROKER_URL: redis://:difyai123456@redis:6379/1 + # Specifies the allowed origins for cross-origin requests to the Web API + WEB_API_CORS_ALLOW_ORIGINS: http://localhost,* + # Specifies the allowed origins for cross-origin requests to the console API + CONSOLE_CORS_ALLOW_ORIGINS: http://localhost,* + # CSRF Cookie settings + # Controls whether a cookie is sent with cross-site requests, + # providing some protection against cross-site request forgery attacks + COOKIE_HTTPONLY: 'true' + COOKIE_SAMESITE: 'None' + COOKIE_SECURE: 'true' + # The type of storage to use for storing user files. Supported values are `local` and `s3`, Default: `local` + STORAGE_TYPE: local + # The path to the local storage directory, the directory relative the root path of API service codes or absolute path. Default: `storage` or `/home/john/storage`. + # only available when STORAGE_TYPE is `local`. + STORAGE_LOCAL_PATH: storage + # The S3 storage configurations, only available when STORAGE_TYPE is `s3`. + S3_ENDPOINT: 'https://xxx.r2.cloudflarestorage.com' + S3_BUCKET_NAME: 'difyai' + S3_ACCESS_KEY: 'ak-difyai' + S3_SECRET_KEY: 'sk-difyai' + S3_REGION: 'us-east-1' + # The type of vector store to use. Supported values are `weaviate`, `qdrant`. + VECTOR_STORE: weaviate + # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`. + WEAVIATE_ENDPOINT: http://weaviate:8080 + # The Weaviate API key. + WEAVIATE_API_KEY: WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih + # The Qdrant endpoint URL. Only available when VECTOR_STORE is `qdrant`. + QDRANT_URL: 'https://your-qdrant-cluster-url.qdrant.tech/' + # The Qdrant API key. + QDRANT_API_KEY: 'ak-difyai' + # The DSN for Sentry error reporting. If not set, Sentry error reporting will be disabled. + SENTRY_DSN: '' + # The sample rate for Sentry events. Default: `1.0` + SENTRY_TRACES_SAMPLE_RATE: 1.0 + # The sample rate for Sentry profiles. Default: `1.0` + SENTRY_PROFILES_SAMPLE_RATE: 1.0 + depends_on: + - db + - redis + - weaviate + volumes: + # Mount the storage directory to the container, for storing user files. + - ./volumes/app/storage:/app/storage + + # worker service + # The Celery worker for processing the queue. + worker: + image: langgenius/dify-api:latest + restart: always + environment: + # Startup mode, 'worker' starts the Celery worker for processing the queue. + MODE: worker + + # --- All the configurations below are the same as those in the 'api' service. --- + + # The log level for the application. Supported values are `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL` + LOG_LEVEL: INFO + # A secret key that is used for securely signing the session cookie and encrypting sensitive information on the database. You can generate a strong key using `openssl rand -base64 42`. + # same as the API service + SECRET_KEY: sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U + # The base URL of console application, refers to the Console base URL of WEB service. + CONSOLE_URL: http://localhost + # The URL for Service API endpoints,refers to the base URL of the current API service. + API_URL: http://localhost + # The URL for Web APP, refers to the Web App base URL of WEB service. + APP_URL: http://localhost + # The configurations of postgres database connection. + # It is consistent with the configuration in the 'db' service below. + DB_USERNAME: postgres + DB_PASSWORD: difyai123456 + DB_HOST: db + DB_PORT: 5432 + DB_DATABASE: dify + # The configurations of redis cache connection. 
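+      # Editor's note: the worker must see the same database, Redis, storage and
+      # vector-store settings as the api service above; a mismatch would leave
+      # queued indexing tasks reading and writing different state.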
+ REDIS_HOST: redis + REDIS_PORT: 6379 + REDIS_PASSWORD: difyai123456 + REDIS_DB: 0 + # The configurations of celery broker. + CELERY_BROKER_URL: redis://:difyai123456@redis:6379/1 + # The type of storage to use for storing user files. Supported values are `local` and `s3`, Default: `local` + STORAGE_TYPE: local + STORAGE_LOCAL_PATH: storage + # The Vector store configurations. + VECTOR_STORE: weaviate + WEAVIATE_ENDPOINT: http://weaviate:8080 + WEAVIATE_API_KEY: WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih + depends_on: + - db + - redis + - weaviate + volumes: + # Mount the storage directory to the container, for storing user files. + - ./volumes/app/storage:/app/storage + + # Frontend web application. + web: + image: langgenius/dify-web:latest + restart: always + environment: + EDITION: SELF_HOSTED + # The base URL of console application, refers to the Console base URL of WEB service. + CONSOLE_URL: http://localhost + # The URL for Web APP, refers to the Web App base URL of WEB service. + APP_URL: http://localhost + + # The postgres database. + db: + image: postgres:15-alpine + restart: always + environment: + # The password for the default postgres user. + POSTGRES_PASSWORD: difyai123456 + # The name of the default postgres database. + POSTGRES_DB: dify + # postgres data directory + PGDATA: /var/lib/postgresql/data/pgdata + volumes: + - ./volumes/db/data:/var/lib/postgresql/data + - ./volumes/db/scripts:/docker-entrypoint-initdb.d/ + ports: + - "5432:5432" + + # The redis cache. + redis: + image: redis:6-alpine + restart: always + volumes: + # Mount the redis data directory to the container. + - ./volumes/redis/data:/data + # Set the redis password when startup redis server. + command: redis-server --requirepass difyai123456 + + # The Weaviate vector store. + weaviate: + image: semitechnologies/weaviate:1.18.4 + restart: always + volumes: + # Mount the Weaviate data directory to the container. + - ./volumes/weaviate:/var/lib/weaviate + environment: + # The Weaviate configurations + # You can refer to the [Weaviate](https://weaviate.io/developers/weaviate/config-refs/env-vars) documentation for more information. + QUERY_DEFAULTS_LIMIT: 25 + AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'false' + PERSISTENCE_DATA_PATH: '/var/lib/weaviate' + DEFAULT_VECTORIZER_MODULE: 'none' + CLUSTER_HOSTNAME: 'node1' + AUTHENTICATION_APIKEY_ENABLED: 'true' + AUTHENTICATION_APIKEY_ALLOWED_KEYS: 'WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih' + AUTHENTICATION_APIKEY_USERS: 'hello@dify.ai' + AUTHORIZATION_ADMINLIST_ENABLED: 'true' + AUTHORIZATION_ADMINLIST_USERS: 'hello@dify.ai' + + # The nginx reverse proxy. + # used for reverse proxying the API service and Web service. 
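+  # Editor's note: the routing rules live in ./nginx/conf.d/default.conf (shown
+  # further below): /console/api, /api and /v1 are proxied to api:5001, and
+  # everything else to web:3000.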
+ nginx: + image: nginx:latest + volumes: + - ./nginx/nginx.conf:/etc/nginx/nginx.conf + - ./nginx/proxy.conf:/etc/nginx/proxy.conf + - ./nginx/conf.d:/etc/nginx/conf.d + depends_on: + - api + - web + ports: + - "80:80" \ No newline at end of file diff --git a/docker/nginx/conf.d/default.conf b/docker/nginx/conf.d/default.conf new file mode 100644 index 0000000000..3b153f40b0 --- /dev/null +++ b/docker/nginx/conf.d/default.conf @@ -0,0 +1,24 @@ +server { + listen 80; + server_name localhost; + + location /console/api { + proxy_pass http://api:5001; + include proxy.conf; + } + + location /api { + proxy_pass http://api:5001; + include proxy.conf; + } + + location /v1 { + proxy_pass http://api:5001; + include proxy.conf; + } + + location / { + proxy_pass http://web:3000; + include proxy.conf; + } +} \ No newline at end of file diff --git a/docker/nginx/nginx.conf b/docker/nginx/nginx.conf new file mode 100644 index 0000000000..d2b52963e8 --- /dev/null +++ b/docker/nginx/nginx.conf @@ -0,0 +1,32 @@ +user nginx; +worker_processes auto; + +error_log /var/log/nginx/error.log notice; +pid /var/run/nginx.pid; + + +events { + worker_connections 1024; +} + + +http { + include /etc/nginx/mime.types; + default_type application/octet-stream; + + log_format main '$remote_addr - $remote_user [$time_local] "$request" ' + '$status $body_bytes_sent "$http_referer" ' + '"$http_user_agent" "$http_x_forwarded_for"'; + + access_log /var/log/nginx/access.log main; + + sendfile on; + #tcp_nopush on; + + keepalive_timeout 65; + + #gzip on; + client_max_body_size 15M; + + include /etc/nginx/conf.d/*.conf; +} \ No newline at end of file diff --git a/docker/nginx/proxy.conf b/docker/nginx/proxy.conf new file mode 100644 index 0000000000..254f625961 --- /dev/null +++ b/docker/nginx/proxy.conf @@ -0,0 +1,8 @@ +proxy_set_header Host $host; +proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; +proxy_set_header X-Forwarded-Proto $scheme; +proxy_http_version 1.1; +proxy_set_header Connection ""; +proxy_buffering off; +proxy_read_timeout 3600s; +proxy_send_timeout 3600s; \ No newline at end of file diff --git a/docker/volumes/db/scripts/init_extension.sh b/docker/volumes/db/scripts/init_extension.sh new file mode 100644 index 0000000000..abad1e5182 --- /dev/null +++ b/docker/volumes/db/scripts/init_extension.sh @@ -0,0 +1 @@ +psql -U postgres -d dify -c 'CREATE EXTENSION IF NOT EXISTS "uuid-ossp";' \ No newline at end of file diff --git a/images/describe-cn.jpg b/images/describe-cn.jpg new file mode 100644 index 0000000000..10bdc28657 Binary files /dev/null and b/images/describe-cn.jpg differ diff --git a/images/describe-en.png b/images/describe-en.png new file mode 100644 index 0000000000..bf2bd22675 Binary files /dev/null and b/images/describe-en.png differ diff --git a/mock-server/.gitignore b/mock-server/.gitignore new file mode 100644 index 0000000000..02651453d8 --- /dev/null +++ b/mock-server/.gitignore @@ -0,0 +1,117 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +lerna-debug.log* + +# Diagnostic reports (https://nodejs.org/api/report.html) +report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json + +# Runtime data +pids +*.pid +*.seed +*.pid.lock + +# Directory for instrumented libs generated by jscoverage/JSCover +lib-cov + +# Coverage directory used by tools like istanbul +coverage +*.lcov + +# nyc test coverage +.nyc_output + +# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) +.grunt + +# Bower dependency directory (https://bower.io/) +bower_components + +# 
node-waf configuration +.lock-wscript + +# Compiled binary addons (https://nodejs.org/api/addons.html) +build/Release + +# Dependency directories +node_modules/ +jspm_packages/ + +# TypeScript v1 declaration files +typings/ + +# TypeScript cache +*.tsbuildinfo + +# Optional npm cache directory +.npm + +# Optional eslint cache +.eslintcache + +# Microbundle cache +.rpt2_cache/ +.rts2_cache_cjs/ +.rts2_cache_es/ +.rts2_cache_umd/ + +# Optional REPL history +.node_repl_history + +# Output of 'npm pack' +*.tgz + +# Yarn Integrity file +.yarn-integrity + +# dotenv environment variables file +.env +.env.test + +# parcel-bundler cache (https://parceljs.org/) +.cache + +# Next.js build output +.next + +# Nuxt.js build / generate output +.nuxt +dist + +# Gatsby files +.cache/ +# Comment in the public line in if your project uses Gatsby and *not* Next.js +# https://nextjs.org/blog/next-9-1#public-directory-support +# public + +# vuepress build output +.vuepress/dist + +# Serverless directories +.serverless/ + +# FuseBox cache +.fusebox/ + +# DynamoDB Local files +.dynamodb/ + +# TernJS port file +.tern-port + +# npm +package-lock.json + +# yarn +.pnp.cjs +.pnp.loader.mjs +.yarn/ +yarn.lock +.yarnrc.yml + +# pmpm +pnpm-lock.yaml \ No newline at end of file diff --git a/mock-server/README.md b/mock-server/README.md new file mode 100644 index 0000000000..7b0a621e84 --- /dev/null +++ b/mock-server/README.md @@ -0,0 +1 @@ +# Mock Server diff --git a/mock-server/api/apps.js b/mock-server/api/apps.js new file mode 100644 index 0000000000..d704387376 --- /dev/null +++ b/mock-server/api/apps.js @@ -0,0 +1,551 @@ +const chars = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ-_' + +function randomString (length) { + let result = '' + for (let i = length; i > 0; --i) result += chars[Math.floor(Math.random() * chars.length)] + return result +} + +// https://www.notion.so/55773516a0194781ae211792a44a3663?pvs=4 +const VirtualData = new Array(10).fill().map((_, index) => { + const date = new Date(Date.now() - index * 24 * 60 * 60 * 1000) + return { + date: `${date.getFullYear()}-${date.getMonth()}-${date.getDate()}`, + conversation_count: Math.floor(Math.random() * 10) + index, + terminal_count: Math.floor(Math.random() * 10) + index, + token_count: Math.floor(Math.random() * 10) + index, + total_price: Math.floor(Math.random() * 10) + index, + } +}) + +const registerAPI = function (app) { + const apps = [{ + id: '1', + name: 'chat app', + mode: 'chat', + description: 'description01', + enable_site: true, + enable_api: true, + api_rpm: 60, + api_rph: 3600, + is_demo: false, + model_config: { + provider: 'OPENAI', + model_id: 'gpt-3.5-turbo', + configs: { + prompt_template: '你是我的解梦小助手,请参考 {{book}} 回答我有关梦境的问题。在回答前请称呼我为 {{myName}}。', + prompt_variables: [ + { + key: 'book', + name: '书', + value: '《梦境解析》', + type: 'string', + description: '请具体说下书名' + }, + { + key: 'myName', + name: 'your name', + value: 'Book', + type: 'string', + description: 'please tell me your name' + } + ], + completion_params: { + max_token: 16, + temperature: 1, // 0-2 + top_p: 1, + presence_penalty: 1, // -2-2 + frequency_penalty: 1, // -2-2 + } + } + }, + site: { + access_token: '1000', + title: 'site 01', + author: 'John', + default_language: 'zh-Hans-CN', + customize_domain: 'http://customize_domain', + theme: 'theme', + customize_token_strategy: 'must', + prompt_public: true + } + }, + { + id: '2', + name: 'completion app', + mode: 'completion', // genertation text + description: 'description 02', // genertation text + 
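+    // Editor's note: app 2 models a single-turn text-generation ('completion')
+    // app, in contrast to the multi-turn 'chat' app above, so the mock covers
+    // both model_config shapes the frontend presumably consumes.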
enable_site: false, + enable_api: false, + api_rpm: 60, + api_rph: 3600, + is_demo: false, + model_config: { + provider: 'OPENAI', + model_id: 'text-davinci-003', + configs: { + prompt_template: '你是我的翻译小助手,请把以下内容 {{langA}} 翻译成 {{langB}},以下的内容:', + prompt_variables: [ + { + key: 'langA', + name: '原始语音', + value: '中文', + type: 'string', + description: '这是中文格式的原始语音' + }, + { + key: 'langB', + name: '目标语言', + value: '英语', + type: 'string', + description: '这是英语格式的目标语言' + } + ], + completion_params: { + max_token: 16, + temperature: 1, // 0-2 + top_p: 1, + presence_penalty: 1, // -2-2 + frequency_penalty: 1, // -2-2 + } + } + }, + site: { + access_token: '2000', + title: 'site 02', + author: 'Mark', + default_language: 'en-US', + customize_domain: 'http://customize_domain', + theme: 'theme', + customize_token_strategy: 'must', + prompt_public: false + } + }, + ] + + const apikeys = [{ + id: '111121312313132', + token: 'sk-DEFGHJKMNPQRSTWXYZabcdefhijk1234', + last_used_at: '1679212138000', + created_at: '1673316000000' + }, { + id: '43441242131223123', + token: 'sk-EEFGHJKMNPQRSTWXYZabcdefhijk5678', + last_used_at: '1679212721000', + created_at: '1679212731000' + }] + + // create app + app.post('/apps', async (req, res) => { + apps.push({ + id: apps.length + 1 + '', + ...req.body, + + }) + res.send({ + result: 'success' + }) + }) + + // app list + app.get('/apps', async (req, res) => { + res.send({ + data: apps + }) + }) + + // app detail + app.get('/apps/:id', async (req, res) => { + const item = apps.find(item => item.id === req.params.id) || apps[0] + res.send(item) + }) + + // update app name + app.post('/apps/:id/name', async (req, res) => { + const item = apps.find(item => item.id === req.params.id) + item.name = req.body.name + res.send(item || null) + }) + + // update app site-enable status + app.post('/apps/:id/site-enable', async (req, res) => { + const item = apps.find(item => item.id === req.params.id) + console.log(item) + item.enable_site = req.body.enable_site + res.send(item || null) + }) + + // update app api-enable status + app.post('/apps/:id/api-enable', async (req, res) => { + const item = apps.find(item => item.id === req.params.id) + console.log(item) + item.enable_api = req.body.enable_api + res.send(item || null) + }) + + // update app rate-limit + app.post('/apps/:id/rate-limit', async (req, res) => { + const item = apps.find(item => item.id === req.params.id) + console.log(item) + item.api_rpm = req.body.api_rpm + item.api_rph = req.body.api_rph + res.send(item || null) + }) + + // update app url including code + app.post('/apps/:id/site/access-token-reset', async (req, res) => { + const item = apps.find(item => item.id === req.params.id) + console.log(item) + item.site.access_token = randomString(12) + res.send(item || null) + }) + + // update app config + app.post('/apps/:id/site', async (req, res) => { + const item = apps.find(item => item.id === req.params.id) + console.log(item) + item.name = req.body.title + item.description = req.body.description + item.prompt_public = req.body.prompt_public + item.default_language = req.body.default_language + res.send(item || null) + }) + + // get statistics daily-conversations + app.get('/apps/:id/statistics/daily-conversations', async (req, res) => { + const item = apps.find(item => item.id === req.params.id) + if (item) { + res.send({ + data: VirtualData + }) + } else { + res.send({ + data: [] + }) + } + }) + + // get statistics daily-end-users + app.get('/apps/:id/statistics/daily-end-users', async (req, res) => { + const 
item = apps.find(item => item.id === req.params.id) + if (item) { + res.send({ + data: VirtualData + }) + } else { + res.send({ + data: [] + }) + } + }) + + // get statistics token-costs + app.get('/apps/:id/statistics/token-costs', async (req, res) => { + const item = apps.find(item => item.id === req.params.id) + if (item) { + res.send({ + data: VirtualData + }) + } else { + res.send({ + data: [] + }) + } + }) + + // update app model config + app.post('/apps/:id/model-config', async (req, res) => { + const item = apps.find(item => item.id === req.params.id) + console.log(item) + item.model_config = req.body + res.send(item || null) + }) + + + // get api keys list + app.get('/apps/:id/api-keys', async (req, res) => { + res.send({ + data: apikeys + }) + }) + + // del api key + app.delete('/apps/:id/api-keys/:api_key_id', async (req, res) => { + res.send({ + result: 'success' + }) + }) + + // create api key + app.post('/apps/:id/api-keys', async (req, res) => { + res.send({ + id: 'e2424241313131', + token: 'sk-GEFGHJKMNPQRSTWXYZabcdefhijk0124', + created_at: '1679216688962' + }) + }) + + // get completion-conversations + app.get('/apps/:id/completion-conversations', async (req, res) => { + const data = { + data: [{ + id: 1, + from_end_user_id: 'user 1', + summary: 'summary1', + created_at: '2023-10-11', + annotated: true, + message_count: 100, + user_feedback_stats: { + like: 4, dislike: 5 + }, + admin_feedback_stats: { + like: 1, dislike: 2 + }, + message: { + message: 'message1', + query: 'question1', + answer: 'answer1' + } + }, { + id: 12, + from_end_user_id: 'user 2', + summary: 'summary2', + created_at: '2023-10-01', + annotated: false, + message_count: 10, + user_feedback_stats: { + like: 2, dislike: 20 + }, + admin_feedback_stats: { + like: 12, dislike: 21 + }, + message: { + message: 'message2', + query: 'question2', + answer: 'answer2' + } + }, { + id: 13, + from_end_user_id: 'user 3', + summary: 'summary3', + created_at: '2023-10-11', + annotated: false, + message_count: 20, + user_feedback_stats: { + like: 2, dislike: 0 + }, + admin_feedback_stats: { + like: 0, dislike: 21 + }, + message: { + message: 'message3', + query: 'question3', + answer: 'answer3' + } + }], + total: 200 + } + res.send(data) + }) + + // get chat-conversations + app.get('/apps/:id/chat-conversations', async (req, res) => { + const data = { + data: [{ + id: 1, + from_end_user_id: 'user 1', + summary: 'summary1', + created_at: '2023-10-11', + read_at: '2023-10-12', + annotated: true, + message_count: 100, + user_feedback_stats: { + like: 4, dislike: 5 + }, + admin_feedback_stats: { + like: 1, dislike: 2 + }, + message: { + message: 'message1', + query: 'question1', + answer: 'answer1' + } + }, { + id: 12, + from_end_user_id: 'user 2', + summary: 'summary2', + created_at: '2023-10-01', + annotated: false, + message_count: 10, + user_feedback_stats: { + like: 2, dislike: 20 + }, + admin_feedback_stats: { + like: 12, dislike: 21 + }, + message: { + message: 'message2', + query: 'question2', + answer: 'answer2' + } + }, { + id: 13, + from_end_user_id: 'user 3', + summary: 'summary3', + created_at: '2023-10-11', + annotated: false, + message_count: 20, + user_feedback_stats: { + like: 2, dislike: 0 + }, + admin_feedback_stats: { + like: 0, dislike: 21 + }, + message: { + message: 'message3', + query: 'question3', + answer: 'answer3' + } + }], + total: 200 + } + res.send(data) + }) + + // get completion-conversation detail + app.get('/apps/:id/completion-conversations/:cid', async (req, res) => { + const data = + { 
+ id: 1, + from_end_user_id: 'user 1', + summary: 'summary1', + created_at: '2023-10-11', + annotated: true, + message: { + message: 'question1', + // query: 'question1', + answer: 'answer1', + annotation: { + content: '这是一段纠正的内容' + } + }, + model_config: { + provider: 'openai', + model_id: 'model_id', + configs: { + prompt_template: '你是我的翻译小助手,请把以下内容 {{langA}} 翻译成 {{langB}},以下的内容:{{content}}' + } + } + } + res.send(data) + }) + + // get chat-conversation detail + app.get('/apps/:id/chat-conversations/:cid', async (req, res) => { + const data = + { + id: 1, + from_end_user_id: 'user 1', + summary: 'summary1', + created_at: '2023-10-11', + annotated: true, + message: { + message: 'question1', + // query: 'question1', + answer: 'answer1', + created_at: '2023-08-09 13:00', + provider_response_latency: 130, + message_tokens: 230 + }, + model_config: { + provider: 'openai', + model_id: 'model_id', + configs: { + prompt_template: '你是我的翻译小助手,请把以下内容 {{langA}} 翻译成 {{langB}},以下的内容:{{content}}' + } + } + } + res.send(data) + }) + + // get chat-conversation message list + app.get('/apps/:id/chat-messages', async (req, res) => { + const data = { + data: [{ + id: 1, + created_at: '2023-10-11 07:09', + message: '请说说人为什么会做梦?' + req.query.conversation_id, + answer: '梦境通常是个人内心深处的反映,很难确定每个人梦境的确切含义,因为它们可能会受到梦境者的文化背景、生活经验和情感状态等多种因素的影响。', + provider_response_latency: 450, + answer_tokens: 200, + annotation: { + content: 'string', + account: { + id: 'string', + name: 'string', + email: 'string' + } + }, + feedbacks: { + rating: 'like', + content: 'string', + from_source: 'log' + } + }, { + id: 2, + created_at: '2023-10-11 8:23', + message: '夜里经常做梦会影响次日的精神状态吗?', + answer: '总之,这个梦境可能与梦境者的个人经历和情感状态有关,但在一般情况下,它可能表示一种强烈的情感反应,包括愤怒、不满和对于正义和自由的渴望。', + provider_response_latency: 400, + answer_tokens: 250, + annotation: { + content: 'string', + account: { + id: 'string', + name: 'string', + email: 'string' + } + }, + // feedbacks: { + // rating: 'like', + // content: 'string', + // from_source: 'log' + // } + }, { + id: 3, + created_at: '2023-10-11 10:20', + message: '梦见在山上手撕鬼子,大师解解梦', + answer: '但是,一般来说,“手撕鬼子”这个场景可能是梦境者对于过去历史上的战争、侵略以及对于自己国家和族群的保护与维护的情感反应。在梦中,你可能会感到自己充满力量和勇气,去对抗那些看似强大的侵略者。', + provider_response_latency: 288, + answer_tokens: 100, + annotation: { + content: 'string', + account: { + id: 'string', + name: 'string', + email: 'string' + } + }, + feedbacks: { + rating: 'dislike', + content: 'string', + from_source: 'log' + } + }], + limit: 20, + has_more: true + } + res.send(data) + }) + + app.post('/apps/:id/annotations', async (req, res) => { + res.send({ result: 'success' }) + }) + + app.post('/apps/:id/feedbacks', async (req, res) => { + res.send({ result: 'success' }) + }) + +} + +module.exports = registerAPI \ No newline at end of file diff --git a/mock-server/api/common.js b/mock-server/api/common.js new file mode 100644 index 0000000000..3e43ad524a --- /dev/null +++ b/mock-server/api/common.js @@ -0,0 +1,38 @@ + +const registerAPI = function (app) { + app.post('/login', async (req, res) => { + res.send({ + result: 'success' + }) + }) + + // get user info + app.get('/account/profile', async (req, res) => { + res.send({ + id: '11122222', + name: 'Joel', + email: 'iamjoel007@gmail.com' + }) + }) + + // logout + app.get('/logout', async (req, res) => { + res.send({ + result: 'success' + }) + }) + + // Langgenius version + app.get('/version', async (req, res) => { + res.send({ + current_version: 'v1.0.0', + latest_version: 'v1.0.0', + upgradeable: true, + compatible_upgrade: true + }) + }) + +} + 
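+// Editor's note: every module under mock-server/api exports a registerAPI(app)
+// function; app.js (further below) imports each one and invokes it to mount the
+// routes on a single Express instance, so adding a mock is one more such module.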
+module.exports = registerAPI + diff --git a/mock-server/api/datasets.js b/mock-server/api/datasets.js new file mode 100644 index 0000000000..0821b3786b --- /dev/null +++ b/mock-server/api/datasets.js @@ -0,0 +1,249 @@ +const registerAPI = function (app) { + app.get("/datasets/:id/documents", async (req, res) => { + if (req.params.id === "0") res.send({ data: [] }); + else { + res.send({ + data: [ + { + id: 1, + name: "Steve Jobs' life", + words: "70k", + word_count: 100, + updated_at: 1681801029, + indexing_status: "completed", + archived: true, + enabled: false, + data_source_info: { + upload_file: { + // id: string + // name: string + // size: number + // mime_type: string + // created_at: number + // created_by: string + extension: "pdf", + }, + }, + }, + { + id: 2, + name: "Steve Jobs' life", + word_count: "10k", + hit_count: 10, + updated_at: 1681801029, + indexing_status: "waiting", + archived: true, + enabled: false, + data_source_info: { + upload_file: { + extension: "json", + }, + }, + }, + { + id: 3, + name: "Steve Jobs' life xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + word_count: "100k", + hit_count: 0, + updated_at: 1681801029, + indexing_status: "indexing", + archived: false, + enabled: true, + data_source_info: { + upload_file: { + extension: "txt", + }, + }, + }, + { + id: 4, + name: "Steve Jobs' life xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + word_count: "100k", + hit_count: 0, + updated_at: 1681801029, + indexing_status: "splitting", + archived: false, + enabled: true, + data_source_info: { + upload_file: { + extension: "md", + }, + }, + }, + { + id: 5, + name: "Steve Jobs' life", + word_count: "100k", + hit_count: 0, + updated_at: 1681801029, + indexing_status: "error", + archived: false, + enabled: false, + data_source_info: { + upload_file: { + extension: "html", + }, + }, + }, + ], + total: 100, + id: req.params.id, + }); + } + }); + + app.get("/datasets/:id/documents/:did/segments", async (req, res) => { + if (req.params.id === "0") res.send({ data: [] }); + else { + res.send({ + data: new Array(100).fill({ + id: 1234, + content: `他的坚持让我很为难。众所周知他非常注意保护自己的隐私,而我想他应该从来没有看过我写的书。也许将来的某个时候吧,我还是这么说。但是,到了2009年,他的妻子劳伦·鲍威尔(Laurene Powell)直言不讳地对我说:“如果你真的打算写一本关于史蒂夫的书,最好现在就开始。”他当时刚刚第二次因病休假。我向劳伦坦承,当乔布斯第一次提出这个想法时,我并不知道他病了。几乎没有人知道,她说。他是在接受癌症手术之前给我打的电话,直到今天他还将此事作为一个秘密,她这么解释道。\n + 他的坚持让我很为难。众所周知他非常注意保护自己的隐私,而我想他应该从来没有看过我写的书。也许将来的某个时候吧,我还是这么说。但是,到了2009年,他的妻子劳伦·鲍威尔(Laurene Powell)直言不讳地对我说:“如果你真的打算写一本关于史蒂夫的书,最好现在就开始。”他当时刚刚第二次因病休假。我向劳伦坦承,当乔布斯第一次提出这个想法时,我并不知道他病了。几乎没有人知道,她说。他是在接受癌症手术之前给我打的电话,直到今天他还将此事作为一个秘密,她这么解释道。`, + enabled: true, + keyWords: [ + "劳伦·鲍威尔", + "劳伦·鲍威尔", + "手术", + "秘密", + "癌症", + "乔布斯", + "史蒂夫", + "书", + "休假", + "坚持", + "隐私", + ], + word_count: 120, + hit_count: 100, + status: "ok", + index_node_hash: "index_node_hash value", + }), + limit: 100, + has_more: true, + }); + } + }); + + // get doc detail + app.get("/datasets/:id/documents/:did", async (req, res) => { + const fixedParams = { + // originInfo: { + originalFilename: "Original filename", + originalFileSize: "16mb", + uploadDate: "2023-01-01", + lastUpdateDate: "2023-01-05", + source: "Source", + // }, + // technicalParameters: { + segmentSpecification: "909090", + segmentLength: 100, + avgParagraphLength: 130, + }; + const bookData = { + doc_type: "book", + doc_metadata: { + title: "机器学习实战", + language: "zh", + author: "Peter Harrington", + publisher: "人民邮电出版社", + publicationDate: "2013-01-01", + ISBN: "9787115335500", + category: "技术", + }, + }; + const webData = { + doc_type: "webPage", + doc_metadata: { + title: "深度学习入门教程", + 
url: "https://www.example.com/deep-learning-tutorial", + language: "zh", + publishDate: "2020-05-01", + authorPublisher: "张三", + topicsKeywords: "深度学习, 人工智能, 教程", + description: + "这是一篇详细的深度学习入门教程,适用于对人工智能和深度学习感兴趣的初学者。", + }, + }; + const postData = { + doc_type: "socialMediaPost", + doc_metadata: { + platform: "Twitter", + authorUsername: "example_user", + publishDate: "2021-08-15", + postURL: "https://twitter.com/example_user/status/1234567890", + topicsTags: + "AI, DeepLearning, Tutorial, Example, Example2, Example3, AI, DeepLearning, Tutorial, Example, Example2, Example3, AI, DeepLearning, Tutorial, Example, Example2, Example3,", + }, + }; + res.send({ + id: "550e8400-e29b-41d4-a716-446655440000", + position: 1, + dataset_id: "550e8400-e29b-41d4-a716-446655440002", + data_source_type: "upload_file", + data_source_info: { + upload_file: { + extension: "html", + id: "550e8400-e29b-41d4-a716-446655440003", + }, + }, + dataset_process_rule_id: "550e8400-e29b-41d4-a716-446655440004", + batch: "20230410123456123456", + name: "example_document", + created_from: "web", + created_by: "550e8400-e29b-41d4-a716-446655440005", + created_api_request_id: "550e8400-e29b-41d4-a716-446655440006", + created_at: 1671269696, + processing_started_at: 1671269700, + word_count: 11, + parsing_completed_at: 1671269710, + cleaning_completed_at: 1671269720, + splitting_completed_at: 1671269730, + tokens: 10, + indexing_latency: 5.0, + completed_at: 1671269740, + paused_by: null, + paused_at: null, + error: null, + stopped_at: null, + indexing_status: "completed", + enabled: true, + disabled_at: null, + disabled_by: null, + archived: false, + archived_reason: null, + archived_by: null, + archived_at: null, + updated_at: 1671269740, + ...(req.params.did === "book" + ? bookData + : req.params.did === "web" + ? webData + : req.params.did === "post" + ? postData + : {}), + segment_count: 10, + hit_count: 9, + status: "ok", + }); + }); + + // // logout + // app.get("/logout", async (req, res) => { + // res.send({ + // result: "success", + // }); + // }); + + // // Langgenius version + // app.get("/version", async (req, res) => { + // res.send({ + // current_version: "v1.0.0", + // latest_version: "v1.0.0", + // upgradeable: true, + // compatible_upgrade: true, + // }); + // }); +}; + +module.exports = registerAPI; diff --git a/mock-server/api/debug.js b/mock-server/api/debug.js new file mode 100644 index 0000000000..2e6f3ca0a7 --- /dev/null +++ b/mock-server/api/debug.js @@ -0,0 +1,119 @@ +const registerAPI = function (app) { + const coversationList = [ + { + id: '1', + name: '梦的解析', + inputs: { + book: '《梦的解析》', + callMe: '大师', + }, + chats: [] + }, + { + id: '2', + name: '生命的起源', + inputs: { + book: '《x x x》', + } + }, + ] + // site info + app.get('/apps/site/info', async (req, res) => { + // const id = req.params.id + res.send({ + enable_site: true, + appId: '1', + site: { + title: 'Story Bot', + description: '这是一款解梦聊天机器人,你可以选择你喜欢的解梦人进行解梦,这句话是客户端应用说明', + }, + prompt_public: true, //id === '1', + prompt_template: '你是我的解梦小助手,请参考 {{book}} 回答我有关梦境的问题。在回答前请称呼我为 {{myName}}。', + }) + }) + + app.post('/apps/:id/chat-messages', async (req, res) => { + const conversationId = req.body.conversation_id ? 
req.body.conversation_id : Date.now() + '' + res.send({ + id: Date.now() + '', + conversation_id: Date.now() + '', + answer: 'balabababab' + }) + }) + + app.post('/apps/:id/completion-messages', async (req, res) => { + res.send({ + id: Date.now() + '', + answer: `做为一个AI助手,我可以为你提供随机生成的段落,这些段落可以用于测试、占位符、或者其他目的。以下是一个随机生成的段落: + + “随着科技的不断发展,越来越多的人开始意识到人工智能的重要性。人工智能已经成为我们生活中不可或缺的一部分,它可以帮助我们完成很多繁琐的工作,也可以为我们提供更智能、更便捷的服务。虽然人工智能带来了很多好处,但它也面临着很多挑战。例如,人工智能的算法可能会出现偏见,导致对某些人群不公平。此外,人工智能的发展也可能会导致一些工作的失业。因此,我们需要不断地研究人工智能的发展,以确保它能够为人类带来更多的好处。”` + }) + }) + + // share api + // chat list + app.get('/apps/:id/coversations', async (req, res) => { + res.send({ + data: coversationList + }) + }) + + + + app.get('/apps/:id/variables', async (req, res) => { + res.send({ + variables: [ + { + key: 'book', + name: '书', + value: '《梦境解析》', + type: 'string' + }, + { + key: 'myName', + name: '称呼', + value: '', + type: 'string' + } + ], + }) + }) + +} + +module.exports = registerAPI + +// const chatList = [ +// { +// id: 1, +// content: 'AI 开场白', +// isAnswer: true, +// }, +// { +// id: 2, +// content: '梦见在山上手撕鬼子,大师解解梦', +// more: { time: '5.6 秒' }, +// }, +// { +// id: 3, +// content: '梦境通常是个人内心深处的反映,很难确定每个人梦境的确切含义,因为它们可能会受到梦境者的文化背景、生活经验和情感状态等多种因素的影响。', +// isAnswer: true, +// more: { time: '99 秒' }, + +// }, +// { +// id: 4, +// content: '梦见在山上手撕鬼子,大师解解梦', +// more: { time: '5.6 秒' }, +// }, +// { +// id: 5, +// content: '梦见在山上手撕鬼子,大师解解梦', +// more: { time: '5.6 秒' }, +// }, +// { +// id: 6, +// content: '梦见在山上手撕鬼子,大师解解梦', +// more: { time: '5.6 秒' }, +// }, +// ] \ No newline at end of file diff --git a/mock-server/api/demo.js b/mock-server/api/demo.js new file mode 100644 index 0000000000..8f8a35079b --- /dev/null +++ b/mock-server/api/demo.js @@ -0,0 +1,15 @@ +const registerAPI = function (app) { + app.get('/demo', async (req, res) => { + res.send({ + des: 'get res' + }) + }) + + app.post('/demo', async (req, res) => { + res.send({ + des: 'post res' + }) + }) +} + +module.exports = registerAPI \ No newline at end of file diff --git a/mock-server/app.js b/mock-server/app.js new file mode 100644 index 0000000000..96eec0ab2a --- /dev/null +++ b/mock-server/app.js @@ -0,0 +1,42 @@ +const express = require('express') +const app = express() +const bodyParser = require('body-parser') +var cors = require('cors') + +const commonAPI = require('./api/common') +const demoAPI = require('./api/demo') +const appsApi = require('./api/apps') +const debugAPI = require('./api/debug') +const datasetsAPI = require('./api/datasets') + +const port = 3001 + +app.use(bodyParser.json()) // for parsing application/json +app.use(bodyParser.urlencoded({ extended: true })) // for parsing application/x-www-form-urlencoded + +const corsOptions = { + origin: true, + credentials: true, +} +app.use(cors(corsOptions)) // for cross origin +app.options('*', cors(corsOptions)) // include before other routes + + +demoAPI(app) +commonAPI(app) +appsApi(app) +debugAPI(app) +datasetsAPI(app) + + +app.get('/', (req, res) => { + res.send('rootpath') +}) + +app.listen(port, () => { + console.log(`Mock run on port ${port}`) +}) + +const sleep = (ms) => { + return new Promise(resolve => setTimeout(resolve, ms)) +} diff --git a/mock-server/package.json b/mock-server/package.json new file mode 100644 index 0000000000..11a68d61e7 --- /dev/null +++ b/mock-server/package.json @@ -0,0 +1,26 @@ +{ + "name": "server", + "version": "1.0.0", + "description": "", + "main": "index.js", + "scripts": { + "dev": "nodemon node app.js", + "start": "node app.js", + "tcp": "node 
tcp.js" + }, + "keywords": [], + "author": "", + "license": "MIT", + "engines": { + "node": ">=16.0.0" + }, + "dependencies": { + "body-parser": "^1.20.2", + "cors": "^2.8.5", + "express": "4.18.2", + "express-jwt": "8.4.1" + }, + "devDependencies": { + "nodemon": "2.0.21" + } +} diff --git a/sdks/nodejs-client/.gitignore b/sdks/nodejs-client/.gitignore new file mode 100644 index 0000000000..2c01328155 --- /dev/null +++ b/sdks/nodejs-client/.gitignore @@ -0,0 +1,49 @@ +# See https://help.github.com/articles/ignoring-files/ for more about ignoring files. + +# dependencies +/node_modules +/.pnp +.pnp.js + +# testing +/coverage + +# next.js +/.next/ +/out/ + +# production +/build + +# misc +.DS_Store +*.pem + +# debug +npm-debug.log* +yarn-debug.log* +yarn-error.log* +.pnpm-debug.log* + +# local env files +.env*.local + +# vercel +.vercel + +# typescript +*.tsbuildinfo +next-env.d.ts + +# npm +package-lock.json + +# yarn +.pnp.cjs +.pnp.loader.mjs +.yarn/ +yarn.lock +.yarnrc.yml + +# pmpm +pnpm-lock.yaml \ No newline at end of file diff --git a/sdks/nodejs-client/README.md b/sdks/nodejs-client/README.md new file mode 100644 index 0000000000..1bfd5f5e00 --- /dev/null +++ b/sdks/nodejs-client/README.md @@ -0,0 +1,46 @@ +# Dify Node.js SDK +This is the Node.js SDK for the Dify API, which allows you to easily integrate Dify into your Node.js applications. + +## Install +```bash +npm install dify-client +``` + +## Usage +After installing the SDK, you can use it in your project like this: + +```js +import { DifyClient, ChatClient, CompletionClient } from 'dify-client' + +const API_KEY = 'your-api-key-here'; +const user = `random-user-id`: + +// Create a completion client +const completionClient = new CompletionClient(API_KEY) +// Create a completion message +completionClient.createCompletionMessage(inputs, query, responseMode, user) + +// Create a chat client +const chatClient = new ChatClient(API_KEY) +// Create a chat message +chatClient.createChatMessage(inputs, query, responseMode, user, conversationId) +// Fetch conversations +chatClient.getConversations(user) +// Fetch conversation messages +chatClient.getConversationMessages(conversationId, user) +// Rename conversation +chatClient.renameConversation(conversationId, name, user) + + +const client = new DifyClient(API_KEY) +// Fetch application parameters +client.getApplicationParameters(user) +// Provide feedback for a message +client.messageFeedback(messageId, rating, user) + +``` + +Replace 'your-api-key-here' with your actual Dify API key.Replace 'your-app-id-here' with your actual Dify APP ID. + +## License +This SDK is released under the MIT License. 
diff --git a/sdks/nodejs-client/index.js b/sdks/nodejs-client/index.js new file mode 100644 index 0000000000..902ecbd880 --- /dev/null +++ b/sdks/nodejs-client/index.js @@ -0,0 +1,140 @@
+import axios from 'axios'
+
+const BASE_URL = 'https://api.dify.ai/v1'
+
+const routes = {
+  application: {
+    method: 'GET',
+    url: () => `/parameters`
+  },
+  feedback: {
+    method: 'POST',
+    url: (messageId) => `/messages/${messageId}/feedbacks`,
+  },
+  createCompletionMessage: {
+    method: 'POST',
+    url: () => `/completion-messages`,
+  },
+  createChatMessage: {
+    method: 'POST',
+    url: () => `/chat-messages`,
+  },
+  getConversationMessages: {
+    method: 'GET',
+    url: () => '/messages',
+  },
+  getConversations: {
+    method: 'GET',
+    url: () => '/conversations',
+  },
+  renameConversation: {
+    method: 'PATCH',
+    url: (conversationId) => `/conversations/${conversationId}`,
+  }
+}
+
+export class DifyClient {
+  constructor(apiKey, baseUrl = BASE_URL) {
+    this.apiKey = apiKey
+    this.baseUrl = baseUrl
+  }
+
+  updateApiKey(apiKey) {
+    this.apiKey = apiKey
+  }
+
+  async sendRequest(method, endpoint, data = null, params = null, stream = false) {
+    const headers = {
+      'Authorization': `Bearer ${this.apiKey}`,
+      'Content-Type': 'application/json',
+    }
+
+    const url = `${this.baseUrl}${endpoint}`
+    let response
+    if (!stream) {
+      // Regular requests go through axios and resolve to parsed JSON.
+      response = await axios({
+        method,
+        url,
+        data,
+        params,
+        headers,
+        responseType: 'json',
+      })
+    }
+    else {
+      // Streaming requests use fetch so callers can read the body incrementally.
+      response = await fetch(url, {
+        headers,
+        method,
+        body: JSON.stringify(data),
+      })
+    }
+
+    return response
+  }
+
+  messageFeedback(messageId, rating, user) {
+    const data = {
+      rating,
+      user,
+    }
+    return this.sendRequest(routes.feedback.method, routes.feedback.url(messageId), data)
+  }
+
+  getApplicationParameters(user) {
+    const params = { user }
+    return this.sendRequest(routes.application.method, routes.application.url(), null, params)
+  }
+}
+
+export class CompletionClient extends DifyClient {
+  createCompletionMessage(inputs, query, user, responseMode) {
+    const data = {
+      inputs,
+      query,
+      response_mode: responseMode,
+      user,
+    }
+    return this.sendRequest(routes.createCompletionMessage.method, routes.createCompletionMessage.url(), data, null, responseMode === 'streaming')
+  }
+}
+
+export class ChatClient extends DifyClient {
+  createChatMessage(inputs, query, user, responseMode = 'blocking', conversationId = null) {
+    const data = {
+      inputs,
+      query,
+      user,
+      response_mode: responseMode,
+    }
+    if (conversationId)
+      data.conversation_id = conversationId
+
+    return this.sendRequest(routes.createChatMessage.method, routes.createChatMessage.url(), data, null, responseMode === 'streaming')
+  }
+
+  getConversationMessages(user, conversationId = '', firstId = null, limit = null) {
+    const params = { user }
+
+    if (conversationId)
+      params.conversation_id = conversationId
+
+    if (firstId)
+      params.first_id = firstId
+
+    if (limit)
+      params.limit = limit
+
+    return this.sendRequest(routes.getConversationMessages.method, routes.getConversationMessages.url(), null, params)
+  }
+
+  getConversations(user, firstId = null, limit = null, pinned = null) {
+    const params = { user, first_id: firstId, limit, pinned }
+    return this.sendRequest(routes.getConversations.method, routes.getConversations.url(), null, params)
+  }
+
+  renameConversation(conversationId, name, user) {
+    const data = { name, user }
+    return this.sendRequest(routes.renameConversation.method, routes.renameConversation.url(conversationId), data)
+  }
+}
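The `DifyClient` constructor above takes an optional `baseUrl`, which matters for self-hosted deployments. A small TypeScript sketch — the localhost URL is only an assumed example for illustration, not something this commit configures:

```ts
import { DifyClient } from 'dify-client'

// Point the client at a self-hosted Service API instead of the default
// https://api.dify.ai/v1 (the URL below is an assumed example).
const client = new DifyClient('your-api-key-here', 'http://localhost:5001/v1')

// updateApiKey swaps credentials on an existing instance.
client.updateApiKey('another-api-key')
```

Because `ChatClient` and `CompletionClient` extend `DifyClient` and inherit this constructor, the same override works for all three client classes.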
diff --git a/sdks/nodejs-client/package.json b/sdks/nodejs-client/package.json new file mode 100644 index 0000000000..350cfc9732 --- /dev/null +++ b/sdks/nodejs-client/package.json @@ -0,0 +1,20 @@
+{
+  "name": "dify-client",
+  "version": "1.0.2",
+  "description": "This is the Node.js SDK for the Dify.AI API, which allows you to easily integrate Dify.AI into your Node.js applications.",
+  "main": "index.js",
+  "type": "module",
+  "keywords": [
+    "Dify",
+    "Dify.AI",
+    "LLM"
+  ],
+  "author": "Joel",
+  "contributors": [
+    "crazywoola <427733928@qq.com> (https://github.com/crazywoola)"
+  ],
+  "license": "MIT",
+  "dependencies": {
+    "axios": "^1.3.5"
+  }
+}
\ No newline at end of file diff --git a/sdks/php-client/README.md b/sdks/php-client/README.md new file mode 100644 index 0000000000..a44abc0f80 --- /dev/null +++ b/sdks/php-client/README.md @@ -0,0 +1,51 @@
+# Dify PHP SDK
+
+This is the PHP SDK for the Dify API, which allows you to easily integrate Dify into your PHP applications.
+
+## Requirements
+
+- PHP 7.2 or later
+- Guzzle HTTP client library
+
+## Usage
+
+After installing the SDK, you can use it in your project like this:
+
+```
+<?php
+require 'vendor/autoload.php';
+
+$apiKey = 'your-api-key-here';
+$difyClient = new DifyClient($apiKey);
+
+// Create a completion client
+$completionClient = new CompletionClient($apiKey);
+$response = $completionClient->create_completion_message($inputs, $query, $response_mode, $user);
+
+// Create a chat client
+$chatClient = new ChatClient($apiKey);
+$response = $chatClient->create_chat_message($inputs, $query, $user, $response_mode, $conversation_id);
+
+// Fetch application parameters
+$response = $difyClient->get_application_parameters($user);
+
+// Provide feedback for a message
+$response = $difyClient->message_feedback($message_id, $rating, $user);
+
+// Other available methods:
+// - get_conversation_messages()
+// - get_conversations()
+// - rename_conversation()
+```
+
+Replace 'your-api-key-here' with your actual Dify API key.
+
+## License
+
+This SDK is released under the MIT License.
\ No newline at end of file diff --git a/sdks/php-client/dify-client.php b/sdks/php-client/dify-client.php new file mode 100644 index 0000000000..7a5d9b60cf --- /dev/null +++ b/sdks/php-client/dify-client.php @@ -0,0 +1,109 @@
+<?php
+
+require 'vendor/autoload.php';
+
+use GuzzleHttp\Client;
+
+class DifyClient {
+    protected $api_key;
+    protected $base_url;
+    protected $client;
+
+    public function __construct($api_key) {
+        $this->api_key = $api_key;
+        $this->base_url = "https://api.dify.ai/v1";
+        $this->client = new Client([
+            'base_uri' => $this->base_url,
+            'headers' => [
+                'Authorization' => 'Bearer ' . 
$this->api_key, + 'Content-Type' => 'application/json', + ], + ]); + } + + protected function send_request($method, $endpoint, $data = null, $params = null, $stream = false) { + $options = [ + 'json' => $data, + 'query' => $params, + 'stream' => $stream, + ]; + + $response = $this->client->request($method, $endpoint, $options); + return $response; + } + + public function message_feedback($message_id, $rating, $user) { + $data = [ + 'rating' => $rating, + 'user' => $user, + ]; + return $this->send_request('POST', "/messages/{$message_id}/feedbacks", $data); + } + + public function get_application_parameters($user) { + $params = ['user' => $user]; + return $this->send_request('GET', '/parameters', null, $params); + } +} + +class CompletionClient extends DifyClient { + public function create_completion_message($inputs, $query, $response_mode, $user) { + $data = [ + 'inputs' => $inputs, + 'query' => $query, + 'response_mode' => $response_mode, + 'user' => $user, + ]; + return $this->send_request('POST', '/completion-messages', $data, null, $response_mode === 'streaming'); + } +} + +class ChatClient extends DifyClient { + public function create_chat_message($inputs, $query, $user, $response_mode = 'blocking', $conversation_id = null) { + $data = [ + 'inputs' => $inputs, + 'query' => $query, + 'user' => $user, + 'response_mode' => $response_mode, + ]; + if ($conversation_id) { + $data['conversation_id'] = $conversation_id; + } + + return $this->send_request('POST', '/chat-messages', $data, null, $response_mode === 'streaming'); + } + + public function get_conversation_messages($user, $conversation_id = null, $first_id = null, $limit = null) { + $params = ['user' => $user]; + + if ($conversation_id) { + $params['conversation_id'] = $conversation_id; + } + if ($first_id) { + $params['first_id'] = $first_id; + } + if ($limit) { + $params['limit'] = $limit; + } + + return $this->send_request('GET', '/messages', null, $params); + } + + public function get_conversations($user, $first_id = null, $limit = null, $pinned = null) { + $params = [ + 'user' => $user, + 'first_id' => $first_id, + 'limit' => $limit, + 'pinned'=> $pinned, + ]; + return $this->send_request('GET', '/conversations', null, $params); + } + + public function rename_conversation($conversation_id, $name, $user) { + $data = [ + 'name' => $name, + 'user' => $user, + ]; + return $this->send_request('PATCH', "/conversations/{$conversation_id}", $data); + } +} diff --git a/sdks/python-client/LICENSE b/sdks/python-client/LICENSE new file mode 100644 index 0000000000..873e44b4bc --- /dev/null +++ b/sdks/python-client/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 LangGenius + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE. diff --git a/sdks/python-client/MANIFEST.in b/sdks/python-client/MANIFEST.in new file mode 100644 index 0000000000..da331d5e5c --- /dev/null +++ b/sdks/python-client/MANIFEST.in @@ -0,0 +1 @@
+recursive-include dify_client *.py
\ No newline at end of file diff --git a/sdks/python-client/README.md b/sdks/python-client/README.md new file mode 100644 index 0000000000..0997d32632 --- /dev/null +++ b/sdks/python-client/README.md @@ -0,0 +1,100 @@
+# dify-client
+
+A Dify Service API client for building web apps on top of the Dify Service API.
+
+## Usage
+
+First, install the `dify-client` Python SDK package:
+
+```
+pip install dify-client
+```
+
+Write your code with the SDK:
+
+- Completion generation with the `blocking` response_mode:
+
+```
+import json
+from dify_client import CompletionClient
+
+api_key = "your_api_key"
+
+# Initialize CompletionClient
+completion_client = CompletionClient(api_key)
+
+# Create Completion Message using CompletionClient
+completion_response = completion_client.create_completion_message(inputs={}, query="Hello", response_mode="blocking", user="user_id")
+completion_response.raise_for_status()
+
+result = completion_response.text
+result = json.loads(result)
+
+print(result.get('answer'))
+```
+
+- Chat generation with the `streaming` response_mode:
+
+```
+import json
+from dify_client import ChatClient

+api_key = "your_api_key"
+
+# Initialize ChatClient
+chat_client = ChatClient(api_key)
+
+# Create Chat Message using ChatClient
+chat_response = chat_client.create_chat_message(inputs={}, query="Hello", user="user_id", response_mode="streaming")
+chat_response.raise_for_status()
+
+for line in chat_response.iter_lines(decode_unicode=True):
+    line = line.split('data:', 1)[-1]
+    if line.strip():
+        line = json.loads(line.strip())
+        print(line.get('answer'))
+```
+
+- Other client methods:
+
+```
+import json
+from dify_client import ChatClient
+
+api_key = "your_api_key"
+
+# Initialize Client
+client = ChatClient(api_key)
+
+# Get App parameters
+parameters = client.get_application_parameters(user="user_id")
+parameters.raise_for_status()
+parameters = json.loads(parameters.text)
+
+print('[parameters]')
+print(parameters)
+
+# Get Conversation List (only for chat)
+conversations = client.get_conversations(user="user_id")
+conversations.raise_for_status()
+conversations = json.loads(conversations.text)
+
+print('[conversations]')
+print(conversations)
+
+# Get Message List (only for chat)
+messages = client.get_conversation_messages(user="user_id", conversation_id="conversation_id")
+messages.raise_for_status()
+messages = json.loads(messages.text)
+
+print('[messages]')
+print(messages)
+
+# Rename Conversation (only for chat)
+rename_conversation_response = client.rename_conversation(conversation_id="conversation_id", name="new_name", user="user_id")
+rename_conversation_response.raise_for_status()
+rename_conversation_result = json.loads(rename_conversation_response.text)
+
+print('[rename result]')
+print(rename_conversation_result)
+``` diff --git a/sdks/python-client/build.sh b/sdks/python-client/build.sh new file mode 100755 index 0000000000..ca1a762c99 --- /dev/null +++ b/sdks/python-client/build.sh @@ -0,0 +1,9 @@
+#!/bin/bash
+
+set -e
+
+rm -rf build dist *.egg-info
+
+pip install setuptools wheel twine
+python setup.py 
sdist bdist_wheel +twine upload dist/* \ No newline at end of file diff --git a/sdks/python-client/dify_client/__init__.py b/sdks/python-client/dify_client/__init__.py new file mode 100644 index 0000000000..471b8d1990 --- /dev/null +++ b/sdks/python-client/dify_client/__init__.py @@ -0,0 +1 @@ +from dify_client.client import ChatClient, CompletionClient \ No newline at end of file diff --git a/sdks/python-client/dify_client/client.py b/sdks/python-client/dify_client/client.py new file mode 100644 index 0000000000..1a8bc3602e --- /dev/null +++ b/sdks/python-client/dify_client/client.py @@ -0,0 +1,74 @@ +import requests + + +class DifyClient: + def __init__(self, api_key): + self.api_key = api_key + self.base_url = "https://api.dify.ai/v1" + + def _send_request(self, method, endpoint, data=None, params=None, stream=False): + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + url = f"{self.base_url}{endpoint}" + response = requests.request(method, url, json=data, params=params, headers=headers, stream=stream) + + return response + + def message_feedback(self, message_id, rating, user): + data = { + "rating": rating, + "user": user + } + return self._send_request("POST", f"/messages/{message_id}/feedbacks", data) + + def get_application_parameters(self, user): + params = {"user": user} + return self._send_request("GET", "/parameters", params=params) + + +class CompletionClient(DifyClient): + def create_completion_message(self, inputs, query, response_mode, user): + data = { + "inputs": inputs, + "query": query, + "response_mode": response_mode, + "user": user + } + return self._send_request("POST", "/completion-messages", data, stream=True if response_mode == "streaming" else False) + + +class ChatClient(DifyClient): + def create_chat_message(self, inputs, query, user, response_mode="blocking", conversation_id=None): + data = { + "inputs": inputs, + "query": query, + "user": user, + "response_mode": response_mode + } + if conversation_id: + data["conversation_id"] = conversation_id + + return self._send_request("POST", "/chat-messages", data, stream=True if response_mode == "streaming" else False) + + def get_conversation_messages(self, user, conversation_id=None, first_id=None, limit=None): + params = {"user": user} + + if conversation_id: + params["conversation_id"] = conversation_id + if first_id: + params["first_id"] = first_id + if limit: + params["limit"] = limit + + return self._send_request("GET", "/messages", params=params) + + def get_conversations(self, user, first_id=None, limit=None, pinned=None): + params = {"user": user, "first_id": first_id, "limit": limit, "pinned": pinned} + return self._send_request("GET", "/conversations", params=params) + + def rename_conversation(self, conversation_id, name, user): + data = {"name": name, "user": user} + return self._send_request("POST", f"/conversations/{conversation_id}/name", data) diff --git a/sdks/python-client/setup.py b/sdks/python-client/setup.py new file mode 100644 index 0000000000..17ba87fde2 --- /dev/null +++ b/sdks/python-client/setup.py @@ -0,0 +1,28 @@ +from setuptools import setup + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + +setup( + name="dify-client", + version="0.1.7", + author="Dify", + author_email="hello@dify.ai", + description="A package for interacting with the Dify Service-API", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/langgenius/dify", + 
license='MIT',
+    packages=['dify_client'],
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: OSI Approved :: MIT License",
+        "Operating System :: OS Independent",
+    ],
+    python_requires=">=3.6",
+    install_requires=[
+        "requests"
+    ],
+    keywords='dify nlp ai language-processing',
+    include_package_data=True,
+) diff --git a/sdks/python-client/tests/__init__.py b/sdks/python-client/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sdks/python-client/tests/test_client.py b/sdks/python-client/tests/test_client.py new file mode 100644 index 0000000000..f123c1882b --- /dev/null +++ b/sdks/python-client/tests/test_client.py @@ -0,0 +1,49 @@
+import os
+import unittest
+from dify_client.client import ChatClient, CompletionClient, DifyClient
+
+API_KEY = os.environ.get("API_KEY")
+APP_ID = os.environ.get("APP_ID")
+
+
+class TestChatClient(unittest.TestCase):
+    def setUp(self):
+        self.chat_client = ChatClient(API_KEY)
+
+    def test_create_chat_message(self):
+        response = self.chat_client.create_chat_message({}, "Hello, World!", "test_user")
+        # The client returns a requests.Response; a blocking chat reply
+        # is JSON with an "answer" field.
+        self.assertIn("answer", response.json())
+
+    def test_get_conversation_messages(self):
+        response = self.chat_client.get_conversation_messages("test_user")
+        self.assertIsInstance(response.json().get("data"), list)
+
+    def test_get_conversations(self):
+        response = self.chat_client.get_conversations("test_user")
+        self.assertIsInstance(response.json().get("data"), list)
+
+
+class TestCompletionClient(unittest.TestCase):
+    def setUp(self):
+        self.completion_client = CompletionClient(API_KEY)
+
+    def test_create_completion_message(self):
+        response = self.completion_client.create_completion_message({}, "What's the weather like today?", "blocking", "test_user")
+        self.assertIn("answer", response.json())
+
+
+class TestDifyClient(unittest.TestCase):
+    def setUp(self):
+        self.dify_client = DifyClient(API_KEY)
+
+    def test_message_feedback(self):
+        response = self.dify_client.message_feedback("test_message_id", 5, "test_user")
+        self.assertEqual(response.status_code, 200)
+
+    def test_get_application_parameters(self):
+        response = self.dify_client.get_application_parameters("test_user")
+        self.assertIsInstance(response.json(), dict)
+
+
+if __name__ == "__main__":
+    unittest.main() diff --git a/web/.editorconfig b/web/.editorconfig new file mode 100644 index 0000000000..e1d3f0b992 --- /dev/null +++ b/web/.editorconfig @@ -0,0 +1,22 @@
+# EditorConfig is awesome: https://EditorConfig.org
+
+# top-most EditorConfig file
+root = true
+
+# Unix-style newlines with a newline ending every file
+[*]
+end_of_line = lf
+insert_final_newline = true
+
+# Matches multiple files with brace expansion notation
+# Set default charset
+[*.{js,tsx}]
+charset = utf-8
+indent_style = space
+indent_size = 2
+
+
+# Matches the exact files either package.json or .travis.yml
+[{package.json,.travis.yml}]
+indent_style = space
+indent_size = 2 diff --git a/web/.eslintrc.json b/web/.eslintrc.json new file mode 100644 index 0000000000..db813b0b25 --- /dev/null +++ b/web/.eslintrc.json @@ -0,0 +1,28 @@
+{
+  "extends": [
+    "@antfu",
+    "plugin:react-hooks/recommended"
+  ],
+  "rules": {
+    "@typescript-eslint/consistent-type-definitions": [
+      "error",
+      "type"
+    ],
+    "no-console": "off",
+    "indent": "off",
+    "@typescript-eslint/indent": [
+      "error",
+      2,
+      {
+        "SwitchCase": 1,
+        "flatTernaryExpressions": false,
+        "ignoredNodes": [
+          "PropertyDefinition[decorators]",
+          "TSUnionType",
+          "FunctionExpression[params]:has(Identifier[decorators])"
+        ]
+      }
+    ],
+    "react-hooks/exhaustive-deps": "warn"
+  }
+}
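To make those lint settings concrete, here is a small hand-written TypeScript sketch (illustrative only, not a file from this commit) that satisfies the two strictest overrides — `type` aliases rather than `interface` per `@typescript-eslint/consistent-type-definitions`, and two-space indentation with indented `case` labels per the `@typescript-eslint/indent` options:

```ts
// Illustrative only — not part of the repository.

// consistent-type-definitions: ["error", "type"] forbids `interface`,
// so object shapes are declared with `type`:
type AppSummary = {
  id: string
  name: string
}

// indent: ["error", 2, { "SwitchCase": 1 }] indents `case` labels one
// level inside the switch block:
function describeApp(app: AppSummary): string {
  switch (app.name) {
    case 'demo':
      return `${app.id} (demo)`
    default:
      return app.id
  }
}
```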
diff --git a/web/.gitignore b/web/.gitignore new file mode 100644 index 0000000000..2c01328155 --- /dev/null +++ b/web/.gitignore @@ -0,0 +1,49 @@
+# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
+
+# dependencies
+/node_modules
+/.pnp
+.pnp.js
+
+# testing
+/coverage
+
+# next.js
+/.next/
+/out/
+
+# production
+/build
+
+# misc
+.DS_Store
+*.pem
+
+# debug
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+.pnpm-debug.log*
+
+# local env files
+.env*.local
+
+# vercel
+.vercel
+
+# typescript
+*.tsbuildinfo
+next-env.d.ts
+
+# npm
+package-lock.json
+
+# yarn
+.pnp.cjs
+.pnp.loader.mjs
+.yarn/
+yarn.lock
+.yarnrc.yml
+
+# pnpm
+pnpm-lock.yaml
\ No newline at end of file diff --git a/web/Dockerfile b/web/Dockerfile new file mode 100644 index 0000000000..315a13e896 --- /dev/null +++ b/web/Dockerfile @@ -0,0 +1,29 @@
+FROM langgenius/base:1.0.0-bullseye-slim
+
+LABEL maintainer="takatost@gmail.com"
+
+ENV EDITION SELF_HOSTED
+ENV DEPLOY_ENV PRODUCTION
+ENV CONSOLE_URL http://127.0.0.1:5001
+ENV APP_URL http://127.0.0.1:5001
+
+EXPOSE 3000
+
+WORKDIR /app/web
+
+COPY package.json /app/web/package.json
+
+RUN npm install
+
+COPY . /app/web/
+
+RUN npm run build
+
+COPY docker/pm2.json /app/web/pm2.json
+COPY docker/entrypoint.sh /entrypoint.sh
+RUN chmod +x /entrypoint.sh
+
+ARG COMMIT_SHA
+ENV COMMIT_SHA ${COMMIT_SHA}
+
+ENTRYPOINT ["/entrypoint.sh"] diff --git a/web/README.md b/web/README.md new file mode 100644 index 0000000000..710e8062d1 --- /dev/null +++ b/web/README.md @@ -0,0 +1,39 @@
+# Dify Frontend
+This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next-app`](https://github.com/vercel/next.js/tree/canary/packages/create-next-app).
+
+## Getting Started
+
+First, run the development server:
+
+```bash
+npm run dev
+# or
+yarn dev
+# or
+pnpm dev
+```
+
+Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
+
+You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file.
+
+[API routes](https://nextjs.org/docs/api-routes/introduction) can be accessed on [http://localhost:3000/api/hello](http://localhost:3000/api/hello). This endpoint can be edited in `pages/api/hello.ts`.
+
+The `pages/api` directory is mapped to `/api/*`. Files in this directory are treated as [API routes](https://nextjs.org/docs/api-routes/introduction) instead of React pages.
+
+This project uses [`next/font`](https://nextjs.org/docs/basic-features/font-optimization) to automatically optimize and load Inter, a custom Google Font.
+
+## Learn More
+
+To learn more about Next.js, take a look at the following resources:
+
+- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
+- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.
+
+You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js/) - your feedback and contributions are welcome!
+
+## Deploy on Vercel
+
+The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js.
+
+Check out our [Next.js deployment documentation](https://nextjs.org/docs/deployment) for more details.
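The README's note about `pages/api/hello.ts` refers to the standard `create-next-app` scaffold; a sketch of what such a handler looks like (reconstructed from the template, so treat the exact payload as illustrative):

```ts
// pages/api/hello.ts — illustrative sketch of a Next.js API route.
import type { NextApiRequest, NextApiResponse } from 'next'

type Data = {
  name: string
}

export default function handler(
  req: NextApiRequest,
  res: NextApiResponse<Data>,
) {
  // Files under pages/api/ are served at /api/<filename> instead of
  // being rendered as React pages.
  res.status(200).json({ name: 'John Doe' })
}
```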
diff --git a/web/app/(commonLayout)/_layout-client.tsx b/web/app/(commonLayout)/_layout-client.tsx new file mode 100644 index 0000000000..8624091de2 --- /dev/null +++ b/web/app/(commonLayout)/_layout-client.tsx @@ -0,0 +1,85 @@ +'use client' +import type { FC } from 'react' +import React, { useEffect, useState } from 'react' +import { usePathname, useRouter, useSelectedLayoutSegments } from 'next/navigation' +import useSWR, { SWRConfig } from 'swr' +import Header from '../components/header' +import { fetchAppList } from '@/service/apps' +import { fetchDatasets } from '@/service/datasets' +import { fetchLanggeniusVersion, fetchUserProfile, logout } from '@/service/common' +import Loading from '@/app/components/base/loading' +import AppContext from '@/context/app-context' +import DatasetsContext from '@/context/datasets-context' +import type { LangGeniusVersionResponse, UserProfileResponse } from '@/models/common' + +export type ICommonLayoutProps = { + children: React.ReactNode +} + +const CommonLayout: FC = ({ children }) => { + const router = useRouter() + const pathname = usePathname() + const segments = useSelectedLayoutSegments() + const pattern = pathname.replace(/.*\/app\//, '') + const [idOrMethod] = pattern.split('/') + const isNotDetailPage = idOrMethod === 'list' + + const appId = isNotDetailPage ? '' : idOrMethod + + const { data: appList, mutate: mutateApps } = useSWR({ url: '/apps', params: { page: 1 } }, fetchAppList) + const { data: datasetList, mutate: mutateDatasets } = useSWR(segments[0] === 'datasets' ? { url: '/datasets', params: { page: 1 } } : null, fetchDatasets) + const { data: userProfileResponse, mutate: mutateUserProfile } = useSWR({ url: '/account/profile', params: {} }, fetchUserProfile) + + const [userProfile, setUserProfile] = useState() + const [langeniusVersionInfo, setLangeniusVersionInfo] = useState() + const updateUserProfileAndVersion = async () => { + if (userProfileResponse && !userProfileResponse.bodyUsed) { + const result = await userProfileResponse.json() + setUserProfile(result) + const current_version = userProfileResponse.headers.get('x-version') + const current_env = userProfileResponse.headers.get('x-env') + const versionData = await fetchLanggeniusVersion({ url: '/version', params: { current_version } }) + setLangeniusVersionInfo({ ...versionData, current_version, latest_version: versionData.version, current_env }) + } + } + useEffect(() => { + updateUserProfileAndVersion() + }, [userProfileResponse]) + + if (!appList || !userProfile || !langeniusVersionInfo) + return + + const curApp = appList?.data.find(opt => opt.id === appId) + const currentDatasetId = segments[0] === 'datasets' && segments[2] + const currentDataset = datasetList?.data?.find(opt => opt.id === currentDatasetId) + + // if (!isNotDetailPage && !curApp) { + // alert('app not found') // TODO: use toast. Now can not get toast context here. + // // notify({ type: 'error', message: 'App not found' }) + // router.push('/apps') + // } + + const onLogout = async () => { + await logout({ + url: '/logout', + params: {}, + }) + router.push('/signin') + } + + return ( + + + +
+
+ {children} +
+
+
+
+ ) +} +export default React.memo(CommonLayout) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/configuration/page.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/configuration/page.tsx new file mode 100644 index 0000000000..41143b979a --- /dev/null +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/configuration/page.tsx @@ -0,0 +1,10 @@ +import React from 'react' +import Configuration from '@/app/components/app/configuration' + +const IConfiguration = async () => { + return ( + + ) +} + +export default IConfiguration diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/develop/page.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/develop/page.tsx new file mode 100644 index 0000000000..d544059a84 --- /dev/null +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/develop/page.tsx @@ -0,0 +1,18 @@ +import React from 'react' +import { getDictionary } from '@/i18n/server' +import { type Locale } from '@/i18n' +import DevelopMain from '@/app/components/develop' + +export type IDevelopProps = { + params: { locale: Locale; appId: string } +} + +const Develop = async ({ + params: { locale, appId }, +}: IDevelopProps) => { + const dictionary = await getDictionary(locale) + + return +} + +export default Develop diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx new file mode 100644 index 0000000000..049d908dde --- /dev/null +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout.tsx @@ -0,0 +1,57 @@ +'use client' +import type { FC } from 'react' +import React, { useEffect } from 'react' +import cn from 'classnames' +import useSWR from 'swr' +import { useTranslation } from 'react-i18next' +import { + ChartBarSquareIcon, + Cog8ToothIcon, + CommandLineIcon, + DocumentTextIcon, +} from '@heroicons/react/24/outline' +import { + ChartBarSquareIcon as ChartBarSquareSolidIcon, + Cog8ToothIcon as Cog8ToothSolidIcon, + CommandLineIcon as CommandLineSolidIcon, + DocumentTextIcon as DocumentTextSolidIcon, +} from '@heroicons/react/24/solid' +import s from './style.module.css' +import AppSideBar from '@/app/components/app-sidebar' +import { fetchAppDetail } from '@/service/apps' + +export type IAppDetailLayoutProps = { + children: React.ReactNode + params: { appId: string } +} + +const AppDetailLayout: FC = (props) => { + const { + children, + params: { appId }, // get appId in path + } = props + const { t } = useTranslation() + const detailParams = { url: '/apps', id: appId } + const { data: response } = useSWR(detailParams, fetchAppDetail) + + const navigation = [ + { name: t('common.appMenus.overview'), href: `/app/${appId}/overview`, icon: ChartBarSquareIcon, selectedIcon: ChartBarSquareSolidIcon }, + { name: t('common.appMenus.promptEng'), href: `/app/${appId}/configuration`, icon: Cog8ToothIcon, selectedIcon: Cog8ToothSolidIcon }, + { name: t('common.appMenus.apiAccess'), href: `/app/${appId}/develop`, icon: CommandLineIcon, selectedIcon: CommandLineSolidIcon }, + { name: t('common.appMenus.logAndAnn'), href: `/app/${appId}/logs`, icon: DocumentTextIcon, selectedIcon: DocumentTextSolidIcon }, + ] + const appModeName = response?.mode?.toUpperCase() === 'COMPLETION' ? t('common.appModes.completionApp') : t('common.appModes.chatApp') + useEffect(() => { + if (response?.name) + document.title = `${(response.name || 'App')} - Dify` + }, [response]) + if (!response) + return null + return ( +
+ +
{children}
+
+ ) +} +export default React.memo(AppDetailLayout) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/logs/page.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/logs/page.tsx new file mode 100644 index 0000000000..6c17986ce6 --- /dev/null +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/logs/page.tsx @@ -0,0 +1,16 @@ +import React from 'react' +import Main from '@/app/components/app/log' + +export type IProps = { + params: { appId: string } +} + +const Logs = async ({ + params: { appId }, +}: IProps) => { + return ( +
+ ) +} + +export default Logs diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx new file mode 100644 index 0000000000..ff9388e147 --- /dev/null +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView.tsx @@ -0,0 +1,86 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import { useContext } from 'use-context-selector' +import useSWR, { useSWRConfig } from 'swr' +import AppCard from '@/app/components/app/overview/appCard' +import Loading from '@/app/components/base/loading' +import { ToastContext } from '@/app/components/base/toast' +import { fetchAppDetail, updateAppApiStatus, updateAppSiteAccessToken, updateAppSiteConfig, updateAppSiteStatus } from '@/service/apps' +import type { IToastProps } from '@/app/components/base/toast' +import type { App } from '@/types/app' + +export type ICardViewProps = { + appId: string +} + +type IParams = { + url: string + body?: Record +} + +export async function asyncRunSafe(func: (val: IParams) => Promise, params: IParams, callback: (props: IToastProps) => void, dict?: any): Promise<[string?, T?]> { + try { + const res = await func(params) + callback && callback({ type: 'success', message: dict('common.actionMsg.modifiedSuccessfully') }) + return [undefined, res] + } + catch (err) { + callback && callback({ type: 'error', message: dict('common.actionMsg.modificationFailed') }) + return [(err as Error).message, undefined] + } +} + +const CardView: FC = ({ appId }) => { + const detailParams = { url: '/apps', id: appId } + const { data: response } = useSWR(detailParams, fetchAppDetail) + const { mutate } = useSWRConfig() + const { notify } = useContext(ToastContext) + const { t } = useTranslation() + + if (!response) + return + + const onChangeSiteStatus = async (value: boolean) => { + const [err] = await asyncRunSafe(updateAppSiteStatus as any, { url: `/apps/${appId}/site-enable`, body: { enable_site: value } }, notify, t) + if (!err) + mutate(detailParams) + } + + const onChangeApiStatus = async (value: boolean) => { + const [err] = await asyncRunSafe(updateAppApiStatus as any, { url: `/apps/${appId}/api-enable`, body: { enable_api: value } }, notify, t) + if (!err) + mutate(detailParams) + } + + const onSaveSiteConfig = async (params: any) => { + const [err] = await asyncRunSafe(updateAppSiteConfig as any, { url: `/apps/${appId}/site`, body: params }, notify, t) + if (!err) + mutate(detailParams) + } + + const onGenerateCode = async () => { + const [err] = await asyncRunSafe(updateAppSiteAccessToken as any, { url: `/apps/${appId}/site/access-token-reset` }, notify, t) + if (!err) + mutate(detailParams) + } + + return ( +
+ + +
+ ) +} + +export default CardView diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx new file mode 100644 index 0000000000..251c29ad1a --- /dev/null +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx @@ -0,0 +1,52 @@ +'use client' +import React, { useState } from 'react' +import dayjs from 'dayjs' +import quarterOfYear from 'dayjs/plugin/quarterOfYear' +import { useTranslation } from 'react-i18next' +import type { PeriodParams } from '@/app/components/app/overview/appChart' +import { ConversationsChart, CostChart, EndUsersChart } from '@/app/components/app/overview/appChart' +import type { Item } from '@/app/components/base/select' +import { SimpleSelect } from '@/app/components/base/select' +import { TIME_PERIOD_LIST } from '@/app/components/app/log/filter' + +dayjs.extend(quarterOfYear) + +const today = dayjs() + +const queryDateFormat = 'YYYY-MM-DD HH:mm' + +export type IChartViewProps = { + appId: string +} + +export default function ChartView({ appId }: IChartViewProps) { + const { t } = useTranslation() + const [period, setPeriod] = useState({ name: t('appLog.filter.period.last7days'), query: { start: today.subtract(7, 'day').format(queryDateFormat), end: today.format(queryDateFormat) } }) + + const onSelect = (item: Item) => { + setPeriod({ name: item.name, query: { start: today.subtract(item.value as number, 'day').format(queryDateFormat), end: today.format(queryDateFormat) } }) + } + + return ( +
+
+ {t('appOverview.analysis.title')} + ({ value: item.value, name: t(`appLog.filter.period.${item.name}`) }))} + className='mt-0 !w-40' + onSelect={onSelect} + defaultValue={7} + /> +
+
+
+ +
+
+ +
+
+ +
+ ) +} diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/page.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/page.tsx new file mode 100644 index 0000000000..1b80c0a0c7 --- /dev/null +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/page.tsx @@ -0,0 +1,30 @@ +import React from 'react' +import WelcomeBanner, { EditKeyPopover } from './welcome-banner' +import ChartView from './chartView' +import CardView from './cardView' +import { getLocaleOnServer } from '@/i18n/server' +import { useTranslation } from '@/i18n/i18next-serverside-config' + +export type IDevelopProps = { + params: { appId: string } +} + +const Overview = async ({ + params: { appId }, +}: IDevelopProps) => { + const locale = getLocaleOnServer() + const { t } = await useTranslation(locale, 'app-overview') + return ( +
+ +
+ {t('overview.title')} + +
+ + +
+ ) +} + +export default Overview diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/welcome-banner.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/welcome-banner.tsx new file mode 100644 index 0000000000..9b323eaa4b --- /dev/null +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/welcome-banner.tsx @@ -0,0 +1,200 @@ +'use client' +import type { FC } from 'react' +import React, { useState } from 'react' +import { useContext } from 'use-context-selector' +import { useTranslation } from 'react-i18next' +import Link from 'next/link' +import useSWR, { useSWRConfig } from 'swr' +import { ArrowTopRightOnSquareIcon } from '@heroicons/react/24/outline' +import { ExclamationCircleIcon } from '@heroicons/react/24/solid' +import { debounce } from 'lodash-es' +import Popover from '@/app/components/base/popover' +import Button from '@/app/components/base/button' +import Tag from '@/app/components/base/tag' +import { ToastContext } from '@/app/components/base/toast' +import { updateOpenAIKey, validateOpenAIKey } from '@/service/apps' +import { fetchTenantInfo } from '@/service/common' +import I18n from '@/context/i18n' + +type IStatusType = 'normal' | 'verified' | 'error' | 'error-api-key-exceed-bill' + +const STATUS_COLOR_MAP = { + normal: { color: '', bgColor: 'bg-primary-50', borderColor: 'border-primary-100' }, + error: { color: 'text-red-600', bgColor: 'bg-red-50', borderColor: 'border-red-100' }, + verified: { color: '', bgColor: 'bg-green-50', borderColor: 'border-green-100' }, + 'error-api-key-exceed-bill': { color: 'text-red-600', bgColor: 'bg-red-50', borderColor: 'border-red-100' }, +} + +const CheckCircleIcon: FC<{ className?: string }> = ({ className }) => { + return + + + +} + +type IEditKeyDiv = { + className?: string + showInPopover?: boolean + onClose?: () => void + getTenantInfo?: () => void +} + +const EditKeyDiv: FC = ({ className = '', showInPopover = false, onClose, getTenantInfo }) => { + const [inputValue, setInputValue] = useState() + const [editStatus, setEditStatus] = useState('normal') + const [loading, setLoading] = useState(false) + const [validating, setValidating] = useState(false) + const { notify } = useContext(ToastContext) + const { t } = useTranslation() + const { locale } = useContext(I18n) + + // Hide the pop-up window and need to get the latest key again + // If the key is valid, the edit button will be hidden later + const onClosePanel = () => { + getTenantInfo && getTenantInfo() + onClose && onClose() + } + + const onSaveKey = async () => { + if (editStatus === 'verified') { + setLoading(true) + try { + await updateOpenAIKey({ url: '/providers/openai/token', body: { token: inputValue ?? '' } }) + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + onClosePanel() + } + catch (err) { + notify({ type: 'error', message: t('common.actionMsg.modificationFailed') }) + } + finally { + setLoading(false) + } + } + } + + const validateKey = async (value: string) => { + try { + setValidating(true) + const res = await validateOpenAIKey({ url: '/providers/openai/token-validate', body: { token: value ?? '' } }) + setEditStatus(res.result === 'success' ? 
'verified' : 'error') + } + catch (err: any) { + if (err.status === 400) { + err.json().then(({ code }: any) => { + if (code === 'provider_request_failed') { + setEditStatus('error-api-key-exceed-bill') + } + }) + } else { + setEditStatus('error') + } + } + finally { + setValidating(false) + } + } + const renderErrorMessage = () => { + if (validating) { + return ( +
+ {t('common.provider.validating')} +
+ ) + } + if (editStatus === 'error-api-key-exceed-bill') { + return ( +
+ {t('common.provider.apiKeyExceedBill')} + {locale === 'en' ? ' ' : ''} + + {locale === 'en' ? 'this link' : '这篇文档'} + +
+ ) + } + if (editStatus === 'error') { + return ( +
+ {t('common.provider.invalidKey')} +
+ ) + } + return null + } + + return ( +
+ {!showInPopover &&

{t('appOverview.welcome.firstStepTip')}

} +

{t('appOverview.welcome.enterKeyTip')} {showInPopover ? '' : '👇'}

+
+ { + setInputValue(e.target.value) + if (!e.target.value) { + setEditStatus('normal') + return + } + validateKey(e.target.value) + }, 300)} + /> + {editStatus === 'verified' &&
+ +
} + {(editStatus === 'error' || editStatus === 'error-api-key-exceed-bill') &&
+ +
} + {showInPopover ? null : } +
+ {renderErrorMessage()} + + {t('appOverview.welcome.getKeyTip')} +
+ ) +} + +const WelcomeBanner: FC = () => { + const { data: userInfo } = useSWR({ url: '/info' }, fetchTenantInfo) + if (!userInfo) + return null + return userInfo?.providers?.find(({ token_is_set }) => token_is_set) ? null : +} + +export const EditKeyPopover: FC = () => { + const { data: userInfo } = useSWR({ url: '/info' }, fetchTenantInfo) + const { mutate } = useSWRConfig() + if (!userInfo) + return null + + const getTenantInfo = () => { + mutate({ url: '/info' }) + } + // In this case, the edit button is displayed + const targetProvider = userInfo?.providers?.some(({ token_is_set, is_valid }) => token_is_set && is_valid) + return ( + !targetProvider + ?
+ OpenAI API key invalid + } + trigger='click' + position='br' + btnElement='Edit' + btnClassName='text-primary-600 !text-xs px-3 py-1.5' + className='!p-0 !w-[464px] h-[200px]' + /> +
+ : null) +} + +export default WelcomeBanner diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/style.module.css b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/style.module.css new file mode 100644 index 0000000000..2d6b3bd15c --- /dev/null +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/style.module.css @@ -0,0 +1,5 @@ +.app { + height: calc(100vh - 56px); + border-radius: 16px 16px 0px 0px; + box-shadow: 0px 0px 5px rgba(0, 0, 0, 0.05), 0px 0px 2px -1px rgba(0, 0, 0, 0.03); +} \ No newline at end of file diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/layout.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/layout.tsx new file mode 100644 index 0000000000..b70959ea79 --- /dev/null +++ b/web/app/(commonLayout)/app/(appDetailLayout)/layout.tsx @@ -0,0 +1,16 @@ +import type { FC } from 'react' +import React from 'react' + +export type IAppDetail = { + children: React.ReactNode +} + +const AppDetail: FC = ({ children }) => { + return ( + <> + {children} + + ) +} + +export default React.memo(AppDetail) diff --git a/web/app/(commonLayout)/apps/AppCard.tsx b/web/app/(commonLayout)/apps/AppCard.tsx new file mode 100644 index 0000000000..eb62cd8899 --- /dev/null +++ b/web/app/(commonLayout)/apps/AppCard.tsx @@ -0,0 +1,76 @@ +'use client' + +import { useContext, useContextSelector } from 'use-context-selector' +import Link from 'next/link' +import type { MouseEventHandler } from 'react' +import { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import style from '../list.module.css' +import AppModeLabel from './AppModeLabel' +import type { App } from '@/types/app' +import Confirm from '@/app/components/base/confirm' +import { ToastContext } from '@/app/components/base/toast' +import { deleteApp } from '@/service/apps' +import AppIcon from '@/app/components/base/app-icon' +import AppsContext from '@/context/app-context' + +export type AppCardProps = { + app: App +} + +const AppCard = ({ + app, +}: AppCardProps) => { + const { t } = useTranslation() + const { notify } = useContext(ToastContext) + + const mutateApps = useContextSelector(AppsContext, state => state.mutateApps) + + const [showConfirmDelete, setShowConfirmDelete] = useState(false) + const onDeleteClick: MouseEventHandler = useCallback((e) => { + e.preventDefault() + setShowConfirmDelete(true) + }, []) + const onConfirmDelete = useCallback(async () => { + try { + await deleteApp(app.id) + notify({ type: 'success', message: t('app.appDeleted') }) + mutateApps() + } + catch (e: any) { + notify({ type: 'error', message: `${t('app.appDeleteFailed')}${'message' in e ? `: ${e.message}` : ''}` }) + } + setShowConfirmDelete(false) + }, [app.id]) + + return ( + <> + +
+ +
+
{app.name}
+
+ +
+
{app.model_config?.pre_prompt}
+
+ +
+ + {showConfirmDelete && ( + setShowConfirmDelete(false)} + onConfirm={onConfirmDelete} + onCancel={() => setShowConfirmDelete(false)} + /> + )} + + + ) +} + +export default AppCard diff --git a/web/app/(commonLayout)/apps/AppModeLabel.tsx b/web/app/(commonLayout)/apps/AppModeLabel.tsx new file mode 100644 index 0000000000..f223c01981 --- /dev/null +++ b/web/app/(commonLayout)/apps/AppModeLabel.tsx @@ -0,0 +1,26 @@ +'use client' + +import classNames from 'classnames' +import { useTranslation } from 'react-i18next' +import { type AppMode } from '@/types/app' +import style from '../list.module.css' + +export type AppModeLabelProps = { + mode: AppMode + className?: string +} + +const AppModeLabel = ({ + mode, + className, +}: AppModeLabelProps) => { + const { t } = useTranslation() + return ( + + + {t(`app.modes.${mode}`)} + + ) +} + +export default AppModeLabel diff --git a/web/app/(commonLayout)/apps/Apps.tsx b/web/app/(commonLayout)/apps/Apps.tsx new file mode 100644 index 0000000000..b11b0da6a0 --- /dev/null +++ b/web/app/(commonLayout)/apps/Apps.tsx @@ -0,0 +1,23 @@ +'use client' + +import { useEffect } from 'react' +import AppCard from './AppCard' +import NewAppCard from './NewAppCard' +import { useAppContext } from '@/context/app-context' + +const Apps = () => { + const { apps, mutateApps } = useAppContext() + + useEffect(() => { + mutateApps() + }, []) + + return ( + + ) +} + +export default Apps diff --git a/web/app/(commonLayout)/apps/NewAppCard.tsx b/web/app/(commonLayout)/apps/NewAppCard.tsx new file mode 100644 index 0000000000..7fee93534e --- /dev/null +++ b/web/app/(commonLayout)/apps/NewAppCard.tsx @@ -0,0 +1,29 @@ +'use client' + +import { useState } from 'react' +import classNames from 'classnames' +import { useTranslation } from 'react-i18next' +import style from '../list.module.css' +import NewAppDialog from './NewAppDialog' + +const CreateAppCard = () => { + const { t } = useTranslation() + const [showNewAppDialog, setShowNewAppDialog] = useState(false) + + return ( + setShowNewAppDialog(true)}> +
+ + + +
+ {t('app.createApp')} +
+
+ {/*
{t('app.createFromConfigFile')}
*/} + setShowNewAppDialog(false)} /> +
+ ) +} + +export default CreateAppCard diff --git a/web/app/(commonLayout)/apps/NewAppDialog.tsx b/web/app/(commonLayout)/apps/NewAppDialog.tsx new file mode 100644 index 0000000000..3b434fa3b2 --- /dev/null +++ b/web/app/(commonLayout)/apps/NewAppDialog.tsx @@ -0,0 +1,193 @@ +'use client' + +import type { MouseEventHandler } from 'react' +import { useCallback, useEffect, useRef, useState } from 'react' +import useSWR from 'swr' +import classNames from 'classnames' +import { useRouter } from 'next/navigation' +import { useContext, useContextSelector } from 'use-context-selector' +import { useTranslation } from 'react-i18next' +import style from '../list.module.css' +import AppModeLabel from './AppModeLabel' +import Button from '@/app/components/base/button' +import Dialog from '@/app/components/base/dialog' +import type { AppMode } from '@/types/app' +import { ToastContext } from '@/app/components/base/toast' +import { createApp, fetchAppTemplates } from '@/service/apps' +import AppIcon from '@/app/components/base/app-icon' +import AppsContext from '@/context/app-context' + +type NewAppDialogProps = { + show: boolean + onClose?: () => void +} + +const NewAppDialog = ({ show, onClose }: NewAppDialogProps) => { + const router = useRouter() + const { notify } = useContext(ToastContext) + const { t } = useTranslation() + + const nameInputRef = useRef(null) + const [newAppMode, setNewAppMode] = useState() + const [isWithTemplate, setIsWithTemplate] = useState(false) + const [selectedTemplateIndex, setSelectedTemplateIndex] = useState(-1) + const mutateApps = useContextSelector(AppsContext, state => state.mutateApps) + + const { data: templates, mutate } = useSWR({ url: '/app-templates' }, fetchAppTemplates) + const mutateTemplates = useCallback( + () => mutate(), + [], + ) + + useEffect(() => { + if (show) { + mutateTemplates() + setIsWithTemplate(false) + } + }, [show]) + + const isCreatingRef = useRef(false) + const onCreate: MouseEventHandler = useCallback(async () => { + const name = nameInputRef.current?.value + if (!name) { + notify({ type: 'error', message: t('app.newApp.nameNotEmpty') }) + return + } + if (!templates || (isWithTemplate && !(selectedTemplateIndex > -1))) { + notify({ type: 'error', message: t('app.newApp.appTemplateNotSelected') }) + return + } + if (!isWithTemplate && !newAppMode) { + notify({ type: 'error', message: t('app.newApp.appTypeRequired') }) + return + } + if (isCreatingRef.current) + return + isCreatingRef.current = true + try { + const app = await createApp({ + name, + mode: isWithTemplate ? templates.data[selectedTemplateIndex].mode : newAppMode!, + config: isWithTemplate ? templates.data[selectedTemplateIndex].model_config : undefined, + }) + if (onClose) + onClose() + notify({ type: 'success', message: t('app.newApp.appCreated') }) + mutateApps() + router.push(`/app/${app.id}/overview`) + } + catch (e) { + notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) + } + isCreatingRef.current = false + }, [isWithTemplate, newAppMode, notify, router, templates, selectedTemplateIndex]) + + return ( + + + + + } + > +

{t('app.newApp.captionName')}

+ +
+ + +
+ +
+
+

{t('app.newApp.captionAppType')}

+ {isWithTemplate && ( + <> + + setIsWithTemplate(false)} + > + {t('app.newApp.hideTemplates')} + + + )} +
+ {isWithTemplate + ? ( +
    + {templates?.data?.map((template, index) => ( +
  • setSelectedTemplateIndex(index)} + > +
    + +
    +
    {template.name}
    +
    +
    +
    {template.model_config?.pre_prompt}
    + + {/* */} +
  • + ))} +
+ ) + : ( + <> +
    +
  • setNewAppMode('chat')} + > +
    + + + +
    +
    {t('app.newApp.chatApp')}
    +
    +
    +
    {t('app.newApp.chatAppIntro')}
    + +
  • +
  • setNewAppMode('completion')} + > +
    + + + +
    +
    {t('app.newApp.completeApp')}
    +
    +
    +
    {t('app.newApp.completeAppIntro')}
    + +
  • +
+
+ setIsWithTemplate(true)} + > + {t('app.newApp.showTemplates')} + +
+ + )} +
+
+ ) +} + +export default NewAppDialog diff --git a/web/app/(commonLayout)/apps/assets/add.svg b/web/app/(commonLayout)/apps/assets/add.svg new file mode 100644 index 0000000000..9958e855aa --- /dev/null +++ b/web/app/(commonLayout)/apps/assets/add.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/(commonLayout)/apps/assets/chat-solid.svg b/web/app/(commonLayout)/apps/assets/chat-solid.svg new file mode 100644 index 0000000000..a793e982c0 --- /dev/null +++ b/web/app/(commonLayout)/apps/assets/chat-solid.svg @@ -0,0 +1,4 @@ + + + + diff --git a/web/app/(commonLayout)/apps/assets/chat.svg b/web/app/(commonLayout)/apps/assets/chat.svg new file mode 100644 index 0000000000..0971349a53 --- /dev/null +++ b/web/app/(commonLayout)/apps/assets/chat.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/(commonLayout)/apps/assets/completion-solid.svg b/web/app/(commonLayout)/apps/assets/completion-solid.svg new file mode 100644 index 0000000000..a9dc7e3dc1 --- /dev/null +++ b/web/app/(commonLayout)/apps/assets/completion-solid.svg @@ -0,0 +1,4 @@ + + + + diff --git a/web/app/(commonLayout)/apps/assets/completion.svg b/web/app/(commonLayout)/apps/assets/completion.svg new file mode 100644 index 0000000000..34af4417fe --- /dev/null +++ b/web/app/(commonLayout)/apps/assets/completion.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/(commonLayout)/apps/assets/delete.svg b/web/app/(commonLayout)/apps/assets/delete.svg new file mode 100644 index 0000000000..fcd60cf7dd --- /dev/null +++ b/web/app/(commonLayout)/apps/assets/delete.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/(commonLayout)/apps/assets/discord.svg b/web/app/(commonLayout)/apps/assets/discord.svg new file mode 100644 index 0000000000..9f22a1ab59 --- /dev/null +++ b/web/app/(commonLayout)/apps/assets/discord.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/(commonLayout)/apps/assets/github.svg b/web/app/(commonLayout)/apps/assets/github.svg new file mode 100644 index 0000000000..f03798b5e1 --- /dev/null +++ b/web/app/(commonLayout)/apps/assets/github.svg @@ -0,0 +1,17 @@ + + + + + + + + + + + + + + + + + diff --git a/web/app/(commonLayout)/apps/assets/link-gray.svg b/web/app/(commonLayout)/apps/assets/link-gray.svg new file mode 100644 index 0000000000..a293cfcf53 --- /dev/null +++ b/web/app/(commonLayout)/apps/assets/link-gray.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/(commonLayout)/apps/assets/link.svg b/web/app/(commonLayout)/apps/assets/link.svg new file mode 100644 index 0000000000..2926c28b16 --- /dev/null +++ b/web/app/(commonLayout)/apps/assets/link.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/(commonLayout)/apps/assets/right-arrow.svg b/web/app/(commonLayout)/apps/assets/right-arrow.svg new file mode 100644 index 0000000000..a2c1cedf95 --- /dev/null +++ b/web/app/(commonLayout)/apps/assets/right-arrow.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/(commonLayout)/apps/page.tsx b/web/app/(commonLayout)/apps/page.tsx new file mode 100644 index 0000000000..bc5a0b583f --- /dev/null +++ b/web/app/(commonLayout)/apps/page.tsx @@ -0,0 +1,36 @@ +import classNames from 'classnames' +import style from '../list.module.css' +import Apps from './Apps' +import { getLocaleOnServer } from '@/i18n/server' +import { useTranslation } from '@/i18n/i18next-serverside-config' + +const AppList = async () => { + const locale = getLocaleOnServer() + const { t } = await useTranslation(locale, 'app') + + return ( +
+ + +
+ ) +} + +export const metadata = { + title: 'Apps - Dify', +} + +export default AppList diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/api/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/api/page.tsx new file mode 100644 index 0000000000..f1b20dd8d8 --- /dev/null +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/api/page.tsx @@ -0,0 +1,11 @@ +import React from 'react' + +type Props = {} + +const page = (props: Props) => { + return ( +
dataset detail api
+ ) +} + +export default page diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/documents/[documentId]/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/documents/[documentId]/page.tsx new file mode 100644 index 0000000000..2bd2a356cc --- /dev/null +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/documents/[documentId]/page.tsx @@ -0,0 +1,16 @@ +import React from 'react' +import MainDetail from '@/app/components/datasets/documents/detail' + +export type IDocumentDetailProps = { + params: { datasetId: string; documentId: string } +} + +const DocumentDetail = async ({ + params: { datasetId, documentId }, +}: IDocumentDetailProps) => { + return ( + + ) +} + +export default DocumentDetail diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/documents/create/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/documents/create/page.tsx new file mode 100644 index 0000000000..e249632bfa --- /dev/null +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/documents/create/page.tsx @@ -0,0 +1,16 @@ +import React from 'react' +import DatasetUpdateForm from '@/app/components/datasets/create' + +export type IProps = { + params: { datasetId: string } +} + +const Create = async ({ + params: { datasetId }, +}: IProps) => { + return ( + + ) +} + +export default Create diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/documents/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/documents/page.tsx new file mode 100644 index 0000000000..545e9ed378 --- /dev/null +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/documents/page.tsx @@ -0,0 +1,16 @@ +import React from 'react' +import Main from '@/app/components/datasets/documents' + +export type IProps = { + params: { datasetId: string } +} + +const Documents = async ({ + params: { datasetId }, +}: IProps) => { + return ( +
+ ) +} + +export default Documents diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/documents/style.module.css b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/documents/style.module.css new file mode 100644 index 0000000000..67a9fe3bf5 --- /dev/null +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/documents/style.module.css @@ -0,0 +1,9 @@ +.logTable td { + padding: 7px 8px; + box-sizing: border-box; + max-width: 200px; +} + +.pagination li { + list-style: none; +} diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/hitTesting/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/hitTesting/page.tsx new file mode 100644 index 0000000000..bec07e41b9 --- /dev/null +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/hitTesting/page.tsx @@ -0,0 +1,16 @@ +import React from 'react' +import Main from '@/app/components/datasets/hit-testing' + +type Props = { + params: { datasetId: string } +} + +const HitTesting = ({ + params: { datasetId }, +}: Props) => { + return ( +
+ ) +} + +export default HitTesting diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx new file mode 100644 index 0000000000..1dc6578977 --- /dev/null +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout.tsx @@ -0,0 +1,169 @@ +'use client' +import type { FC } from 'react' +import React, { useEffect } from 'react' +import { usePathname, useSelectedLayoutSegments } from 'next/navigation' +import useSWR from 'swr' +import { useTranslation } from 'react-i18next' +import { getLocaleOnClient } from '@/i18n/client' +import { + Cog8ToothIcon, + // CommandLineIcon, + Squares2X2Icon, + PuzzlePieceIcon, + DocumentTextIcon, +} from '@heroicons/react/24/outline' +import { + Cog8ToothIcon as Cog8ToothSolidIcon, + // CommandLineIcon as CommandLineSolidIcon, + DocumentTextIcon as DocumentTextSolidIcon, +} from '@heroicons/react/24/solid' +import Link from 'next/link' +import { fetchDataDetail, fetchDatasetRelatedApps } from '@/service/datasets' +import type { RelatedApp } from '@/models/datasets' +import s from './style.module.css' +import AppSideBar from '@/app/components/app-sidebar' +import Divider from '@/app/components/base/divider' +import Indicator from '@/app/components/header/indicator' +import AppIcon from '@/app/components/base/app-icon' +import Loading from '@/app/components/base/loading' +import DatasetDetailContext from '@/context/dataset-detail' + +// import { fetchDatasetDetail } from '@/service/datasets' + +export type IAppDetailLayoutProps = { + children: React.ReactNode + params: { datasetId: string } +} + +const LikedItem: FC<{ type?: 'plugin' | 'app'; appStatus?: boolean; detail: RelatedApp }> = ({ + type = 'app', + appStatus = true, + detail +}) => { + return ( + +
+ + {type === 'app' && ( +
+ +
+ )} +
+
{detail?.name || '--'}
+ + ) +} + +const TargetIcon: FC<{ className?: string }> = ({ className }) => { + return + + + + + + + + + +} + +const TargetSolidIcon: FC<{ className?: string }> = ({ className }) => { + return + + + + +} + +const BookOpenIcon: FC<{ className?: string }> = ({ className }) => { + return + + + +} + +const DatasetDetailLayout: FC = (props) => { + const { + children, + params: { datasetId }, + } = props + const pathname = usePathname() + const hideSideBar = /documents\/create$/.test(pathname) + const { t } = useTranslation() + const { data: datasetRes, error } = useSWR({ + action: 'fetchDataDetail', + datasetId, + }, apiParams => fetchDataDetail(apiParams.datasetId)) + + const { data: relatedApps } = useSWR({ + action: 'fetchDatasetRelatedApps', + datasetId, + }, apiParams => fetchDatasetRelatedApps(apiParams.datasetId)) + + const navigation = [ + { name: t('common.datasetMenus.documents'), href: `/datasets/${datasetId}/documents`, icon: DocumentTextIcon, selectedIcon: DocumentTextSolidIcon }, + { name: t('common.datasetMenus.hitTesting'), href: `/datasets/${datasetId}/hitTesting`, icon: TargetIcon, selectedIcon: TargetSolidIcon }, + // { name: 'api & webhook', href: `/datasets/${datasetId}/api`, icon: CommandLineIcon, selectedIcon: CommandLineSolidIcon }, + { name: t('common.datasetMenus.settings'), href: `/datasets/${datasetId}/settings`, icon: Cog8ToothIcon, selectedIcon: Cog8ToothSolidIcon }, + ] + + useEffect(() => { + if (datasetRes) { + document.title = `${datasetRes.name || 'Dataset'} - Dify` + } + }, [datasetRes]) + + const ExtraInfo: FC = () => { + const locale = getLocaleOnClient() + + return
+ + {relatedApps?.data?.length ? ( + <> +
{relatedApps?.total || '--'} {t('common.datasetMenus.relatedApp')}
+ {relatedApps?.data?.map((item) => ())} + + ) : ( +
+
+
+ +
+
+ +
+
+
{t('common.datasetMenus.emptyTip')}
+ + + {t('common.datasetMenus.viewDoc')} + +
+ )} +
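Editor's note: the two useSWR calls in this layout key their requests with object literals rather than URL strings. A condensed sketch of that pattern, reusing the fetchDataDetail service import above; the wrapper hook name is invented:

```ts
import useSWR from 'swr'
import { fetchDataDetail } from '@/service/datasets'

// SWR hashes the object literal into a stable cache key, so every component
// asking for the same { action, datasetId } pair shares one request and one
// cache slot.
function useDatasetDetail(datasetId: string) {
  return useSWR(
    { action: 'fetchDataDetail', datasetId },
    apiParams => fetchDataDetail(apiParams.datasetId),
  )
}
```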
+ } + + if (!datasetRes && !error) + return + + return ( +
+ {!hideSideBar && } + iconType='dataset' + />} + +
{children}
+
+
+ ) +} +export default React.memo(DatasetDetailLayout) diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx new file mode 100644 index 0000000000..23863881c3 --- /dev/null +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/settings/page.tsx @@ -0,0 +1,23 @@ +import React from 'react' +import { getLocaleOnServer } from '@/i18n/server' +import { useTranslation } from '@/i18n/i18next-serverside-config' +import Form from '@/app/components/datasets/settings/form' + +const Settings = async () => { + const locale = getLocaleOnServer() + const { t } = await useTranslation(locale, 'dataset-settings') + + return ( +
+
+
{t('title')}
+
{t('desc')}
+
+
+
+
+
+ ) +} + +export default Settings diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/style.module.css b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/style.module.css new file mode 100644 index 0000000000..786726703f --- /dev/null +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/style.module.css @@ -0,0 +1,18 @@ +.itemWrapper { + @apply flex items-center w-full h-10 px-3 rounded-lg hover:bg-gray-50 cursor-pointer; +} +.appInfo { + @apply truncate text-gray-700 text-sm font-normal; +} +.iconWrapper { + @apply relative w-6 h-6 mr-2 bg-[#D5F5F6] rounded-md; +} +.statusPoint { + @apply flex justify-center items-center absolute -right-0.5 -bottom-0.5 w-2.5 h-2.5 bg-white rounded; +} +.subTitle { + @apply uppercase text-xs text-gray-500 font-medium px-3 pb-2 pt-4; +} +.emptyIconDiv { + @apply h-7 w-7 bg-gray-50 border border-[#EAECF5] inline-flex justify-center items-center rounded-lg; +} diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/layout.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/layout.tsx new file mode 100644 index 0000000000..ccbc58f5e5 --- /dev/null +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/layout.tsx @@ -0,0 +1,16 @@ +import type { FC } from 'react' +import React from 'react' + +export type IDatasetDetail = { + children: React.ReactNode +} + +const AppDetail: FC = ({ children }) => { + return ( + <> + {children} + + ) +} + +export default React.memo(AppDetail) diff --git a/web/app/(commonLayout)/datasets/DatasetCard.tsx b/web/app/(commonLayout)/datasets/DatasetCard.tsx new file mode 100644 index 0000000000..b6786d0519 --- /dev/null +++ b/web/app/(commonLayout)/datasets/DatasetCard.tsx @@ -0,0 +1,89 @@ +'use client' + +import { useContext, useContextSelector } from 'use-context-selector' +import Link from 'next/link' +import useSWR from 'swr' +import type { MouseEventHandler } from 'react' +import { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import style from '../list.module.css' +import type { App } from '@/types/app' +import Confirm from '@/app/components/base/confirm' +import { ToastContext } from '@/app/components/base/toast' +import { deleteDataset, fetchDatasets } from '@/service/datasets' +import AppIcon from '@/app/components/base/app-icon' +import AppsContext from '@/context/app-context' +import { DataSet } from '@/models/datasets' +import classNames from 'classnames' + +export type DatasetCardProps = { + dataset: DataSet +} + +const DatasetCard = ({ + dataset, +}: DatasetCardProps) => { + const { t } = useTranslation() + const { notify } = useContext(ToastContext) + + const { mutate: mutateDatasets } = useSWR({ url: '/datasets', params: { page: 1 } }, fetchDatasets) + + const [showConfirmDelete, setShowConfirmDelete] = useState(false) + const onDeleteClick: MouseEventHandler = useCallback((e) => { + e.preventDefault() + setShowConfirmDelete(true) + }, []) + const onConfirmDelete = useCallback(async () => { + try { + await deleteDataset(dataset.id) + notify({ type: 'success', message: t('dataset.datasetDeleted') }) + mutateDatasets() + } + catch (e: any) { + notify({ type: 'error', message: `${t('dataset.datasetDeleteFailed')}${'message' in e ? `: ${e.message}` : ''}` }) + } + setShowConfirmDelete(false) + }, [dataset.id]) + + return ( + <> + +
+ +
+
{dataset.name}
+
+ +
+
{dataset.description}
+
+ + + {dataset.document_count}{t('dataset.documentCount')} + + + + {Math.round(dataset.word_count / 1000)}{t('dataset.wordCount')} + + + + {dataset.app_count}{t('dataset.appCount')} + +
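Editor's note: the card's delete flow follows a delete-then-revalidate shape. A simplified sketch under the same service calls; the hook name and toast type are assumptions:

```ts
import useSWR from 'swr'
import { deleteDataset, fetchDatasets } from '@/service/datasets'

type Notify = (opts: { type: 'success' | 'error'; message: string }) => void

// Delete on the server first, surface the result, then revalidate the cached
// dataset list so every subscriber re-renders with fresh data.
export function useDeleteDataset(notify: Notify) {
  const { mutate } = useSWR({ url: '/datasets', params: { page: 1 } }, fetchDatasets)
  return async (id: string) => {
    try {
      await deleteDataset(id)
      notify({ type: 'success', message: 'Dataset deleted' })
      mutate()
    }
    catch (e: any) {
      notify({ type: 'error', message: e?.message ?? 'Failed to delete dataset' })
    }
  }
}
```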
+ + {showConfirmDelete && ( + setShowConfirmDelete(false)} + onConfirm={onConfirmDelete} + onCancel={() => setShowConfirmDelete(false)} + /> + )} + + + ) +} + +export default DatasetCard diff --git a/web/app/(commonLayout)/datasets/DatasetFooter.tsx b/web/app/(commonLayout)/datasets/DatasetFooter.tsx new file mode 100644 index 0000000000..020bfceffe --- /dev/null +++ b/web/app/(commonLayout)/datasets/DatasetFooter.tsx @@ -0,0 +1,19 @@ +'use client' + +import { useTranslation } from "react-i18next" + +const DatasetFooter = () => { + const { t } = useTranslation() + + return ( + + ) +} + +export default DatasetFooter diff --git a/web/app/(commonLayout)/datasets/Datasets.tsx b/web/app/(commonLayout)/datasets/Datasets.tsx new file mode 100644 index 0000000000..b044547748 --- /dev/null +++ b/web/app/(commonLayout)/datasets/Datasets.tsx @@ -0,0 +1,27 @@ +'use client' + +import { useEffect } from 'react' +import useSWR from 'swr' +import { DataSet } from '@/models/datasets'; +import NewDatasetCard from './NewDatasetCard' +import DatasetCard from './DatasetCard'; +import { fetchDatasets } from '@/service/datasets'; + +const Datasets = () => { + // const { datasets, mutateDatasets } = useAppContext() + const { data: datasetList, mutate: mutateDatasets } = useSWR({ url: '/datasets', params: { page: 1 } }, fetchDatasets) + + useEffect(() => { + mutateDatasets() + }, []) + + return ( + + ) +} + +export default Datasets + diff --git a/web/app/(commonLayout)/datasets/NewDatasetCard.tsx b/web/app/(commonLayout)/datasets/NewDatasetCard.tsx new file mode 100644 index 0000000000..a3f6282c97 --- /dev/null +++ b/web/app/(commonLayout)/datasets/NewDatasetCard.tsx @@ -0,0 +1,28 @@ +'use client' + +import { useState } from 'react' +import classNames from 'classnames' +import { useTranslation } from 'react-i18next' +import style from '../list.module.css' + +const CreateAppCard = () => { + const { t } = useTranslation() + const [showNewAppDialog, setShowNewAppDialog] = useState(false) + + return ( + +
+ + + +
+ {t('dataset.createDataset')} +
+
+
{t('dataset.createDatasetIntro')}
+ {/*
{t('app.createFromConfigFile')}
*/} +
+ ) +} + +export default CreateAppCard diff --git a/web/app/(commonLayout)/datasets/assets/application.svg b/web/app/(commonLayout)/datasets/assets/application.svg new file mode 100644 index 0000000000..0384961f82 --- /dev/null +++ b/web/app/(commonLayout)/datasets/assets/application.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/web/app/(commonLayout)/datasets/assets/doc.svg b/web/app/(commonLayout)/datasets/assets/doc.svg new file mode 100644 index 0000000000..6bb150cfd6 --- /dev/null +++ b/web/app/(commonLayout)/datasets/assets/doc.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/(commonLayout)/datasets/assets/text.svg b/web/app/(commonLayout)/datasets/assets/text.svg new file mode 100644 index 0000000000..6bb150cfd6 --- /dev/null +++ b/web/app/(commonLayout)/datasets/assets/text.svg @@ -0,0 +1,3 @@ + + + diff --git a/web/app/(commonLayout)/datasets/create/page.tsx b/web/app/(commonLayout)/datasets/create/page.tsx new file mode 100644 index 0000000000..663a830665 --- /dev/null +++ b/web/app/(commonLayout)/datasets/create/page.tsx @@ -0,0 +1,12 @@ +import React from 'react' +import DatasetUpdateForm from '@/app/components/datasets/create' + +type Props = {} + +const DatasetCreation = async (props: Props) => { + return ( + + ) +} + +export default DatasetCreation diff --git a/web/app/(commonLayout)/datasets/page.tsx b/web/app/(commonLayout)/datasets/page.tsx new file mode 100644 index 0000000000..909c46a435 --- /dev/null +++ b/web/app/(commonLayout)/datasets/page.tsx @@ -0,0 +1,23 @@ +import classNames from 'classnames' +import { getLocaleOnServer } from '@/i18n/server' +import { useTranslation } from '@/i18n/i18next-serverside-config' +import Datasets from './Datasets' +import DatasetFooter from './DatasetFooter' + +const AppList = async () => { + const locale = getLocaleOnServer() + const { t } = await useTranslation(locale, 'dataset') + + return ( +
+ + +
+ ) +} + +export const metadata = { + title: 'Datasets - Dify', +} + +export default AppList diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx new file mode 100644 index 0000000000..901fd95eae --- /dev/null +++ b/web/app/(commonLayout)/layout.tsx @@ -0,0 +1,19 @@ +import React from "react"; +import type { FC } from 'react' +import LayoutClient, { ICommonLayoutProps } from "./_layout-client"; +import GA, { GaType } from '@/app/components/base/ga' + +const Layout: FC = ({ children }) => { + return ( + <> + + + + ) +} + +export const metadata = { + title: 'Dify', +} + +export default Layout \ No newline at end of file diff --git a/web/app/(commonLayout)/list.module.css b/web/app/(commonLayout)/list.module.css new file mode 100644 index 0000000000..14351d40eb --- /dev/null +++ b/web/app/(commonLayout)/list.module.css @@ -0,0 +1,183 @@ +.listItem { + @apply col-span-1 bg-white border-2 border-solid border-transparent rounded-lg shadow-sm min-h-[160px] flex flex-col transition-all duration-200 ease-in-out cursor-pointer hover:shadow-lg; +} + +.listItem.newItemCard { + @apply outline outline-1 outline-gray-200 -outline-offset-1 hover:shadow-sm hover:bg-white; + background-color: rgba(229, 231, 235, 0.5); +} + +.listItem.selectable { + @apply relative bg-gray-50 outline outline-1 outline-gray-200 -outline-offset-1 shadow-none hover:bg-none hover:shadow-none hover:outline-primary-200 transition-colors; +} +.listItem.selectable * { + @apply relative; +} +.listItem.selectable::before { + content: ''; + @apply absolute top-0 left-0 block w-full h-full rounded-lg pointer-events-none opacity-0 transition-opacity duration-200 ease-in-out hover:opacity-100; + background: linear-gradient(0deg, rgba(235, 245, 255, 0.5), rgba(235, 245, 255, 0.5)), #FFFFFF; +} +.listItem.selectable:hover::before { + @apply opacity-100; +} + +.listItem.selected { + @apply border-primary-600 hover:border-primary-600 border-2; +} +.listItem.selected::before { + @apply opacity-100; +} + +.appIcon { + @apply flex items-center justify-center w-8 h-8 bg-pink-100 rounded-lg grow-0 shrink-0; +} +.appIcon.medium { + @apply w-9 h-9; +} +.appIcon.large { + @apply w-10 h-10; +} + +.newItemIcon { + @apply flex items-center justify-center w-8 h-8 transition-colors duration-200 ease-in-out border border-gray-200 rounded-lg hover:bg-white grow-0 shrink-0; +} +.listItem:hover .newItemIcon { + @apply bg-gray-50 border-primary-100; +} +.newItemCard .newItemIcon { + @apply bg-gray-100; +} +.newItemCard:hover .newItemIcon { + @apply bg-white; +} +.selectable .newItemIcon { + @apply bg-gray-50; +} +.selectable:hover .newItemIcon { + @apply bg-primary-50; +} +.newItemIconImage { + @apply grow-0 shrink-0 block w-4 h-4 bg-center bg-contain transition-colors duration-200 ease-in-out; + color: #1f2a37; +} +.listItem:hover .newIconImage { + @apply text-primary-600; +} +.newItemIconAdd { + background-image: url('./apps/assets/add.svg'); +} +.newItemIconChat { + background-image: url('./apps/assets/chat.svg'); +} +.newItemIconComplete { + background-image: url('./apps/assets/completion.svg'); +} + +.listItemTitle { + @apply flex pt-[14px] px-[14px] pb-3 h-[66px] items-center gap-3 grow-0 shrink-0; +} + +.listItemHeading { + @apply relative h-8 text-sm font-medium leading-8 grow; +} + +.listItemHeadingContent { + @apply absolute top-0 left-0 w-full h-full overflow-hidden text-ellipsis whitespace-nowrap; +} + +.deleteAppIcon { + @apply hidden grow-0 shrink-0 basis-8 w-8 h-8 rounded-lg transition-colors duration-200 
ease-in-out bg-white border border-gray-200 hover:bg-gray-100 bg-center bg-no-repeat; + background-size: 16px; + background-image: url('./apps/assets/delete.svg'); +} +.listItem:hover .deleteAppIcon { + @apply block; +} + +.listItemDescription { + @apply mb-3 px-[14px] h-9 text-xs leading-normal text-gray-500 line-clamp-2; +} + +.listItemFooter { + @apply flex items-center flex-wrap min-h-[42px] px-[14px] pt-2 pb-[10px]; +} +.listItemFooter.datasetCardFooter { + @apply flex items-center gap-4 text-xs text-gray-500; +} + +.listItemStats { + @apply flex items-center gap-1; +} + +.listItemFooterIcon { + @apply block w-3 h-3 bg-center bg-contain; +} +.solidChatIcon { + background-image: url('./apps/assets/chat-solid.svg'); +} +.solidCompletionIcon { + background-image: url('./apps/assets/completion-solid.svg'); +} +.docIcon { + background-image: url('./datasets/assets/doc.svg'); +} +.textIcon { + background-image: url('./datasets/assets/text.svg'); +} +.applicationIcon { + background-image: url('./datasets/assets/application.svg'); +} + +.newItemCardHeading { + @apply transition-colors duration-200 ease-in-out; +} +.listItem:hover .newItemCardHeading { + @apply text-primary-600; +} + +.listItemLink { + @apply inline-flex items-center gap-1 text-xs text-gray-400 transition-colors duration-200 ease-in-out; +} +.listItem:hover .listItemLink { + @apply text-primary-600 +} + +.linkIcon { + @apply block w-[13px] h-[13px] bg-center bg-contain; + background-image: url('./apps/assets/link.svg'); +} + +.linkIcon.grayLinkIcon { + background-image: url('./apps/assets/link-gray.svg'); +} +.listItem:hover .grayLinkIcon { + background-image: url('./apps/assets/link.svg'); +} + +.rightIcon { + @apply block w-[13px] h-[13px] bg-center bg-contain; + background-image: url('./apps/assets/right-arrow.svg'); +} + +.socialMediaLink { + @apply flex items-center justify-center w-8 h-8 cursor-pointer hover:opacity-80 transition-opacity duration-200 ease-in-out; +} + +.socialMediaIcon { + @apply block w-6 h-6 bg-center bg-contain; +} + +.githubIcon { + background-image: url('./apps/assets/github.svg'); +} + +.discordIcon { + background-image: url('./apps/assets/discord.svg'); +} + +/* #region new app dialog */ +.newItemCaption { + @apply inline-flex items-center mb-2 text-sm font-medium; +} +/* #endregion new app dialog */ diff --git a/web/app/(commonLayout)/plugins-coming-soon/assets/coming-soon.png b/web/app/(commonLayout)/plugins-coming-soon/assets/coming-soon.png new file mode 100644 index 0000000000..a1c48b508d Binary files /dev/null and b/web/app/(commonLayout)/plugins-coming-soon/assets/coming-soon.png differ diff --git a/web/app/(commonLayout)/plugins-coming-soon/assets/plugins-bg.png b/web/app/(commonLayout)/plugins-coming-soon/assets/plugins-bg.png new file mode 100644 index 0000000000..9be76acc52 Binary files /dev/null and b/web/app/(commonLayout)/plugins-coming-soon/assets/plugins-bg.png differ diff --git a/web/app/(commonLayout)/plugins-coming-soon/page.module.css b/web/app/(commonLayout)/plugins-coming-soon/page.module.css new file mode 100644 index 0000000000..73aab949c9 --- /dev/null +++ b/web/app/(commonLayout)/plugins-coming-soon/page.module.css @@ -0,0 +1,32 @@ +.bg { + position: relative; + width: 750px; + height: 450px; + background: #fff url(./assets/plugins-bg.png) center center no-repeat; + background-size: contain; + box-shadow: 0px 12px 16px -4px rgba(16, 24, 40, 0.08), 0px 4px 6px -2px rgba(16, 24, 40, 0.03); + border-radius: 16px; +} + +.text { + position: absolute; + top: 40px; + left: 48px; 
+ width: 526px; + background: linear-gradient(91.92deg, #104AE1 -1.74%, #0098EE 75.74%); + background-clip: text; + color: transparent; + font-size: 24px; + font-weight: 700; + line-height: 32px; +} + +.tag { + position: absolute; + width: 116.74px; + height: 69.3px; + left: -18.37px; + top: -11.48px; + background: url(./assets/coming-soon.png) center center no-repeat; + background-size: contain; +} \ No newline at end of file diff --git a/web/app/(commonLayout)/plugins-coming-soon/page.tsx b/web/app/(commonLayout)/plugins-coming-soon/page.tsx new file mode 100644 index 0000000000..285b0189a3 --- /dev/null +++ b/web/app/(commonLayout)/plugins-coming-soon/page.tsx @@ -0,0 +1,19 @@ +import s from './page.module.css' +import { getLocaleOnServer } from '@/i18n/server' +import { useTranslation } from '@/i18n/i18next-serverside-config' + +const PluginsComingSoon = async () => { + const locale = getLocaleOnServer() + const { t } = await useTranslation(locale, 'common') + + return ( +
+
+
+
{t('menus.pluginsTips')}
+
+
+ ) +} + +export default PluginsComingSoon diff --git a/web/app/(shareLayout)/chat/[token]/page.tsx b/web/app/(shareLayout)/chat/[token]/page.tsx new file mode 100644 index 0000000000..472bc36091 --- /dev/null +++ b/web/app/(shareLayout)/chat/[token]/page.tsx @@ -0,0 +1,16 @@ +import type { FC } from 'react' +import React from 'react' + +import type { IMainProps } from '@/app/components/share/chat' +import Main from '@/app/components/share/chat' + +const Chat: FC = ({ + params, +}: any) => { + + return ( +
+ ) +} + +export default React.memo(Chat) diff --git a/web/app/(shareLayout)/completion/[token]/page.tsx b/web/app/(shareLayout)/completion/[token]/page.tsx new file mode 100644 index 0000000000..0c3992dc72 --- /dev/null +++ b/web/app/(shareLayout)/completion/[token]/page.tsx @@ -0,0 +1,13 @@ +import type { FC } from 'react' +import React from 'react' + +import type { IMainProps } from '@/app/components/share/chat' +import Main from '@/app/components/share/text-generation' + +const TextGeneration: FC = () => { + return ( +
+ ) +} + +export default React.memo(TextGeneration) \ No newline at end of file diff --git a/web/app/(shareLayout)/layout.tsx b/web/app/(shareLayout)/layout.tsx new file mode 100644 index 0000000000..f629943f86 --- /dev/null +++ b/web/app/(shareLayout)/layout.tsx @@ -0,0 +1,18 @@ +import React from "react"; +import type { FC } from 'react' +import GA, { GaType } from '@/app/components/base/ga' + +const Layout: FC<{ + children: React.ReactNode +}> = ({ children }) => { + return ( +
+
+ + {children} +
+
+ ) +} + +export default Layout \ No newline at end of file diff --git a/web/app/api/hello/route.ts b/web/app/api/hello/route.ts new file mode 100644 index 0000000000..d3a7036df1 --- /dev/null +++ b/web/app/api/hello/route.ts @@ -0,0 +1,3 @@ +export async function GET(_request: Request) { + return new Response('Hello, Next.js!') +} diff --git a/web/app/components/app-sidebar/basic.tsx b/web/app/components/app-sidebar/basic.tsx new file mode 100644 index 0000000000..4cefafa0c1 --- /dev/null +++ b/web/app/components/app-sidebar/basic.tsx @@ -0,0 +1,65 @@ +import React from 'react' +import { + InformationCircleIcon, +} from '@heroicons/react/24/outline' +import Tooltip from '../base/tooltip' +import AppIcon from '../base/app-icon' + +const chars = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ-_' + +export function randomString(length: number) { + let result = '' + for (let i = length; i > 0; --i) result += chars[Math.floor(Math.random() * chars.length)] + return result +} + +export type IAppBasicProps = { + iconType?: 'app' | 'api' | 'dataset' + iconUrl?: string + name: string + type: string | React.ReactNode + hoverTip?: string + textStyle?: { main?: string; extra?: string } +} + +const AlgorithmSvg = + + + + + + + + +const DatasetSvg = + + + +const ICON_MAP = { + 'app': , + 'api': , + 'dataset': +} + +export default function AppBasic({ iconUrl, name, type, hoverTip, textStyle, iconType = 'app' }: IAppBasicProps) { + return ( +
+ {iconUrl && ( +
+ {/* {name} */} + {ICON_MAP[iconType]} +
+ )} +
+
+ {name} + {hoverTip + && + + } +
+
{type}
+
+
+ ) +} diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx new file mode 100644 index 0000000000..cdafcce78e --- /dev/null +++ b/web/app/components/app-sidebar/index.tsx @@ -0,0 +1,39 @@ +import React from 'react' +import type { FC } from 'react' +import NavLink from './navLink' +import AppBasic from './basic' + +export type IAppDetailNavProps = { + iconType?: 'app' | 'dataset' + title: string + desc: string + navigation: Array<{ + name: string + href: string + icon: any + selectedIcon: any + }> + extraInfo?: React.ReactNode +} + +const sampleAppIconUrl = 'https://images.unsplash.com/photo-1472099645785-5658abf4ff4e?ixlib=rb-1.2.1&ixid=eyJhcHBfaWQiOjEyMDd9&auto=format&fit=facearea&facepad=2&w=256&h=256&q=80' + +const AppDetailNav: FC = ({ title, desc, navigation, extraInfo, iconType = 'app' }) => { + return ( +
+
+ +
+ +
+ ) +} + +export default React.memo(AppDetailNav) diff --git a/web/app/components/app-sidebar/navLink.tsx b/web/app/components/app-sidebar/navLink.tsx new file mode 100644 index 0000000000..20aa2241bd --- /dev/null +++ b/web/app/components/app-sidebar/navLink.tsx @@ -0,0 +1,39 @@ +'use client' +import { useSelectedLayoutSegment } from 'next/navigation' +import classNames from 'classnames' +import Link from 'next/link' + +export default function NavLink({ + name, + href, + iconMap, +}: { + name: string + href: string + iconMap: { selected: any; normal: any } +}) { + const segment = useSelectedLayoutSegment() + const isActive = href.toLowerCase().split('/')?.pop() === segment?.toLowerCase() + const NavIcon = isActive ? iconMap.selected : iconMap.normal + + return ( + +
+ ) +} +export default React.memo(Chat) diff --git a/web/app/components/app/chat/loading-anim/index.tsx b/web/app/components/app/chat/loading-anim/index.tsx new file mode 100644 index 0000000000..0cd4111b39 --- /dev/null +++ b/web/app/components/app/chat/loading-anim/index.tsx @@ -0,0 +1,16 @@ +'use client' +import React, { FC } from 'react' +import s from './style.module.css' + +export interface ILoaidingAnimProps { + type: 'text' | 'avatar' +} + +const LoaidingAnim: FC = ({ + type +}) => { + return ( +
+ ) +} +export default React.memo(LoaidingAnim) diff --git a/web/app/components/app/chat/loading-anim/style.module.css b/web/app/components/app/chat/loading-anim/style.module.css new file mode 100644 index 0000000000..5a764db13c --- /dev/null +++ b/web/app/components/app/chat/loading-anim/style.module.css @@ -0,0 +1,82 @@ +.dot-flashing { + position: relative; + animation: 1s infinite linear alternate; + animation-delay: 0.5s; +} + +.dot-flashing::before, +.dot-flashing::after { + content: ""; + display: inline-block; + position: absolute; + top: 0; + animation: 1s infinite linear alternate; +} + +.dot-flashing::before { + animation-delay: 0s; +} + +.dot-flashing::after { + animation-delay: 1s; +} + +@keyframes dot-flashing { + 0% { + background-color: #667085; + } + + 50%, + 100% { + background-color: rgba(102, 112, 133, 0.3); + } +} + +@keyframes dot-flashing-avatar { + 0% { + background-color: #155EEF; + } + + 50%, + 100% { + background-color: rgba(21, 94, 239, 0.3); + } +} + +.text, +.text::before, +.text::after { + width: 4px; + height: 4px; + border-radius: 50%; + background-color: #667085; + color: #667085; + animation-name: dot-flashing; +} + +.text::before { + left: -7px; +} + +.text::after { + left: 7px; +} + +.avatar, +.avatar::before, +.avatar::after { + width: 2px; + height: 2px; + border-radius: 50%; + background-color: #155EEF; + color: #155EEF; + animation-name: dot-flashing-avatar; +} + +.avatar::before { + left: -5px; +} + +.avatar::after { + left: 5px; +} \ No newline at end of file diff --git a/web/app/components/app/chat/style.module.css b/web/app/components/app/chat/style.module.css new file mode 100644 index 0000000000..0744311c6d --- /dev/null +++ b/web/app/components/app/chat/style.module.css @@ -0,0 +1,91 @@ +.answerIcon { + position: relative; + background: url(./icons/robot.svg); +} + +.typeingIcon { + position: absolute; + top: 0px; + left: 0px; + display: flex; + justify-content: center; + align-items: center; + width: 16px; + height: 16px; + background: #FFFFFF; + box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); + border-radius: 16px; +} + + +.questionIcon { + background: url(./icons/default-avatar.jpg); + background-size: contain; + border-radius: 50%; +} + +.answer::before, +.question::before { + content: ''; + position: absolute; + top: 0; + width: 8px; + height: 12px; +} + +.answer::before { + left: 0; + background: url(./icons/answer.svg) no-repeat; +} + +.answerWrap .itemOperation { + display: none; +} + +.answerWrap:hover .itemOperation { + display: flex; +} + +.question::before { + right: 0; + background: url(./icons/question.svg) no-repeat; +} + +.textArea { + padding-top: 13px; + padding-bottom: 13px; + padding-right: 90px; + border-radius: 12px; + line-height: 20px; + background-color: #fff; +} + +.textArea:hover { + background-color: #fff; +} + +/* .textArea:focus { + box-shadow: 0px 3px 15px -3px rgba(0, 0, 0, 0.1), 0px 4px 6px rgba(0, 0, 0, 0.05); +} */ + +.count { + /* display: none; */ + padding: 0 2px; +} + +.sendBtn { + background: url(./icons/send.svg) center center no-repeat; +} + +.sendBtn:hover { + background-image: url(./icons/send-active.svg); + background-color: #EBF5FF; +} + +.textArea:focus+div .count { + display: block; +} + +.textArea:focus+div .sendBtn { + background-image: url(./icons/send-active.svg); +} \ No newline at end of file diff --git a/web/app/components/app/configuration/base/feature-panel/index.tsx b/web/app/components/app/configuration/base/feature-panel/index.tsx new file mode 100644 index 0000000000..0bb2f88ba1 --- 
/dev/null +++ b/web/app/components/app/configuration/base/feature-panel/index.tsx @@ -0,0 +1,54 @@ +'use client' +import React, { FC, ReactNode } from 'react' +import cn from 'classnames' + +export interface IFeaturePanelProps { + className?: string + headerIcon: ReactNode + title: ReactNode + headerRight: ReactNode + hasHeaderBottomBorder?: boolean + isFocus?: boolean + noBodySpacing?: boolean + children?: ReactNode +} + +const FeaturePanel: FC = ({ + className, + headerIcon, + title, + headerRight, + hasHeaderBottomBorder, + isFocus, + noBodySpacing, + children, +}) => { + return ( +
+ {/* Header */} +
+
+
+
{headerIcon}
+
{title}
+
+
+ {headerRight} +
+
+
+ {/* Body */} + {children && ( +
+ {children} +
+ )} +
+ ) +} +export default React.memo(FeaturePanel) diff --git a/web/app/components/app/configuration/base/group-name/index.tsx b/web/app/components/app/configuration/base/group-name/index.tsx new file mode 100644 index 0000000000..68c87d6d92 --- /dev/null +++ b/web/app/components/app/configuration/base/group-name/index.tsx @@ -0,0 +1,23 @@ +'use client' +import React, { FC } from 'react' + +export interface IGroupNameProps { + name: string +} + +const GroupName: FC = ({ + name +}) => { + return ( +
+
{name}
+
+
+ ) +} +export default React.memo(GroupName) diff --git a/web/app/components/app/configuration/base/icons/more-like-this-icon.tsx b/web/app/components/app/configuration/base/icons/more-like-this-icon.tsx new file mode 100644 index 0000000000..8aa79fb400 --- /dev/null +++ b/web/app/components/app/configuration/base/icons/more-like-this-icon.tsx @@ -0,0 +1,13 @@ +'use client' +import React, { FC } from 'react' + +const MoreLikeThisIcon: FC = ({ }) => { + return ( + + + + + + ) +} +export default React.memo(MoreLikeThisIcon) diff --git a/web/app/components/app/configuration/base/icons/remove-icon/index.tsx b/web/app/components/app/configuration/base/icons/remove-icon/index.tsx new file mode 100644 index 0000000000..497d5b68c2 --- /dev/null +++ b/web/app/components/app/configuration/base/icons/remove-icon/index.tsx @@ -0,0 +1,31 @@ +'use client' +import React, { FC, useState } from 'react' +import cn from 'classnames' + +export interface IRemoveIconProps { + className?: string + isHoverStatus?: boolean + onClick: () => void +} + +const RemoveIcon: FC = ({ + className, + isHoverStatus, + onClick +}) => { + const [isHovered, setIsHovered] = useState(false) + const computedIsHovered = isHoverStatus || isHovered + return ( +
setIsHovered(true)} + onMouseLeave={() => setIsHovered(false)} + onClick={onClick} + > + + + +
+ ) +} +export default React.memo(RemoveIcon) diff --git a/web/app/components/app/configuration/base/icons/remove-icon/style.module.css b/web/app/components/app/configuration/base/icons/remove-icon/style.module.css new file mode 100644 index 0000000000..e69de29bb2 diff --git a/web/app/components/app/configuration/base/icons/suggested-questions-after-answer-icon.tsx b/web/app/components/app/configuration/base/icons/suggested-questions-after-answer-icon.tsx new file mode 100644 index 0000000000..178d27609f --- /dev/null +++ b/web/app/components/app/configuration/base/icons/suggested-questions-after-answer-icon.tsx @@ -0,0 +1,11 @@ +'use client' +import React, { FC } from 'react' + +const SuggestedQuestionsAfterAnswerIcon: FC = () => { + return ( + + + + ) +} +export default React.memo(SuggestedQuestionsAfterAnswerIcon) diff --git a/web/app/components/app/configuration/base/icons/var-icon.tsx b/web/app/components/app/configuration/base/icons/var-icon.tsx new file mode 100644 index 0000000000..e991e5bb38 --- /dev/null +++ b/web/app/components/app/configuration/base/icons/var-icon.tsx @@ -0,0 +1,11 @@ +'use client' +import React, { FC } from 'react' + +const VarIcon: FC = () => { + return ( + + + + ) +} +export default React.memo(VarIcon) diff --git a/web/app/components/app/configuration/base/operation-btn/index.tsx b/web/app/components/app/configuration/base/operation-btn/index.tsx new file mode 100644 index 0000000000..1756464f7e --- /dev/null +++ b/web/app/components/app/configuration/base/operation-btn/index.tsx @@ -0,0 +1,39 @@ +'use client' +import React, { FC } from 'react' +import { useTranslation } from 'react-i18next' +import { PlusIcon } from '@heroicons/react/20/solid' + +export interface IOperationBtnProps { + type: 'add' | 'edit' + actionName?: string + onClick: () => void +} + +const iconMap = { + add: , + edit: ( + + + ) +} + +const OperationBtn: FC = ({ + type, + actionName, + onClick +}) => { + const { t } = useTranslation() + return ( +
+
+ {iconMap[type]} +
+
+ {actionName || t(`common.operation.${type}`)} +
+
+ ) +} +export default React.memo(OperationBtn) diff --git a/web/app/components/app/configuration/base/var-highlight/index.tsx b/web/app/components/app/configuration/base/var-highlight/index.tsx new file mode 100644 index 0000000000..8d7e09573f --- /dev/null +++ b/web/app/components/app/configuration/base/var-highlight/index.tsx @@ -0,0 +1,36 @@ +'use client' +import React, { FC } from 'react' + +import s from './style.module.css' + +export interface IVarHighlightProps { + name: string +} + +const VarHighlight: FC = ({ + name, +}) => { + return ( +
+ {'{{'} + {name} + {'}}'} +
+ ) +} + +export const varHighlightHTML = ({ name }: IVarHighlightProps) => { + const html = `
+ {{ + ${name} + }} +
` + return html +} + + + +export default React.memo(VarHighlight) diff --git a/web/app/components/app/configuration/base/var-highlight/style.module.css b/web/app/components/app/configuration/base/var-highlight/style.module.css new file mode 100644 index 0000000000..cd5c8f8d77 --- /dev/null +++ b/web/app/components/app/configuration/base/var-highlight/style.module.css @@ -0,0 +1,3 @@ +.item { + background-color: rgba(21, 94, 239, 0.05); +} \ No newline at end of file diff --git a/web/app/components/app/configuration/base/warning-mask/formatting-changed.tsx b/web/app/components/app/configuration/base/warning-mask/formatting-changed.tsx new file mode 100644 index 0000000000..61787a4d0f --- /dev/null +++ b/web/app/components/app/configuration/base/warning-mask/formatting-changed.tsx @@ -0,0 +1,40 @@ +'use client' +import React, { FC } from 'react' +import { useTranslation } from 'react-i18next' +import WarningMask from '.' +import Button from '@/app/components/base/button' + +export interface IFormattingChangedProps { + onConfirm: () => void + onCancel: () => void +} + +const icon = ( + + + +) + +const FormattingChanged: FC = ({ + onConfirm, + onCancel +}) => { + const { t } = useTranslation() + + return ( + + + + + } + /> + ) +} +export default React.memo(FormattingChanged) diff --git a/web/app/components/app/configuration/base/warning-mask/has-not-set-api.tsx b/web/app/components/app/configuration/base/warning-mask/has-not-set-api.tsx new file mode 100644 index 0000000000..af39b70038 --- /dev/null +++ b/web/app/components/app/configuration/base/warning-mask/has-not-set-api.tsx @@ -0,0 +1,37 @@ +'use client' +import React, { FC } from 'react' +import { useTranslation } from 'react-i18next' +import WarningMask from '.' +import Button from '@/app/components/base/button' + +export interface IHasNotSetAPIProps { + isTrailFinished: boolean + onSetting: () => void +} + +const icon = ( + + + + +) + +const HasNotSetAPI: FC = ({ + isTrailFinished, + onSetting +}) => { + const { t } = useTranslation() + + return ( + + {t('appDebug.notSetAPIKey.settingBtn')} + {icon} + } + /> + ) +} +export default React.memo(HasNotSetAPI) diff --git a/web/app/components/app/configuration/base/warning-mask/index.tsx b/web/app/components/app/configuration/base/warning-mask/index.tsx new file mode 100644 index 0000000000..d7e45e360f --- /dev/null +++ b/web/app/components/app/configuration/base/warning-mask/index.tsx @@ -0,0 +1,42 @@ +'use client' +import React, { FC } from 'react' + +import s from './style.module.css' + +export interface IWarningMaskProps { + title: string + description: string + footer: React.ReactNode +} + +const warningIcon = ( + + + +) + +const WarningMask: FC = ({ + title, + description, + footer, +}) => { + return ( +
+
+
{warningIcon}
+
+ {title} +
+
+ {description} +
+
+ {footer} +
+
+ +
+ ) +} +export default React.memo(WarningMask) diff --git a/web/app/components/app/configuration/base/warning-mask/style.module.css b/web/app/components/app/configuration/base/warning-mask/style.module.css new file mode 100644 index 0000000000..e1d6f10de9 --- /dev/null +++ b/web/app/components/app/configuration/base/warning-mask/style.module.css @@ -0,0 +1,8 @@ +.mask { + background-color: rgba(239, 244, 255, 0.9); + backdrop-filter: blur(2px); +} + +.icon { + box-shadow: 0px 12px 16px -4px rgba(16, 24, 40, 0.08), 0px 4px 6px -2px rgba(16, 24, 40, 0.03); +} \ No newline at end of file diff --git a/web/app/components/app/configuration/config-model/index.tsx b/web/app/components/app/configuration/config-model/index.tsx new file mode 100644 index 0000000000..999e989682 --- /dev/null +++ b/web/app/components/app/configuration/config-model/index.tsx @@ -0,0 +1,243 @@ +'use client' +import type { FC } from 'react' +import React, { useEffect, useState } from 'react' +import cn from 'classnames' +import { useTranslation } from 'react-i18next' +import { useBoolean, useClickAway } from 'ahooks' +import ParamItem from './param-item' +import Radio from '@/app/components/base/radio' +import Panel from '@/app/components/base/panel' +import type { CompletionParams } from '@/models/debug' +import { Cog8ToothIcon, InformationCircleIcon, ChevronDownIcon } from '@heroicons/react/24/outline' +import { AppType } from '@/types/app' +import { TONE_LIST } from '@/config' + +export type IConifgModelProps = { + mode: string + modelId: string + setModelId: (id: string) => void + completionParams: CompletionParams + onCompletionParamsChange: (newParams: CompletionParams) => void + disabled: boolean + canUseGPT4: boolean + onShowUseGPT4Confirm: () => void +} + +const options = [ + { id: 'gpt-3.5-turbo', name: 'gpt-3.5-turbo', type: AppType.chat }, + { id: 'gpt-4', name: 'gpt-4', type: AppType.chat }, // 8k version + { id: 'gpt-3.5-turbo', name: 'gpt-3.5-turbo', type: AppType.completion }, + { id: 'text-davinci-003', name: 'text-davinci-003', type: AppType.completion }, +] + +const ModelIcon = ({ className }: { className?: string }) => ( + + + + +) + +const ConifgModel: FC = ({ + mode, + modelId, + setModelId, + completionParams, + onCompletionParamsChange, + disabled, + canUseGPT4, + onShowUseGPT4Confirm, +}) => { + const { t } = useTranslation() + const isChatApp = mode === AppType.chat + const availableModels = options.filter((item) => item.type === mode) + const [isShowConfig, { setFalse: hideConfig, toggle: toogleShowConfig }] = useBoolean(false) + const configContentRef = React.useRef(null) + useClickAway(() => { + hideConfig() + }, configContentRef) + + const params = [ + { + id: 1, + name: t('common.model.params.temperature'), + key: 'temperature', + tip: t('common.model.params.temperatureTip'), + max: 2, + }, + { + id: 2, + name: t('common.model.params.topP'), + key: 'top_p', + tip: t('common.model.params.topPTip'), + max: 1, + }, + { + id: 3, + name: t('common.model.params.presencePenalty'), + key: 'presence_penalty', + tip: t('common.model.params.presencePenaltyTip'), + min: -2, + max: 2, + }, + { + id: 4, + name: t('common.model.params.frequencyPenalty'), + key: 'frequency_penalty', + tip: t('common.model.params.frequencyPenaltyTip'), + min: -2, + max: 2, + }, + { + id: 5, + name: t('common.model.params.maxToken'), + key: 'max_tokens', + tip: t('common.model.params.maxTokenTip'), + step: 100, + max: 4000, + }, + ] + + const selectModelDisabled = false // chat gpt-3.5-turbo, gpt-4; text generation 
text-davinci-003, gpt-3.5-turbo + + const selectedModel = { name: modelId } // options.find(option => option.id === modelId) + const [isShowOption, { setFalse: hideOption, toggle: toogleOption }] = useBoolean(false) + const triggerRef = React.useRef(null) + useClickAway(() => { + hideOption() + }, triggerRef) + + const handleSelectModel = (id: string) => { + return () => { + if (id === 'gpt-4' && !canUseGPT4) { + hideConfig() + hideOption() + onShowUseGPT4Confirm() + return + } + setModelId(id) + } + } + + function matchToneId(completionParams: CompletionParams): number { + const remvoedCustomeTone = TONE_LIST.slice(0, -1) + const CUSTOM_TONE_ID = 4 + const tone = remvoedCustomeTone.find((tone) => { + return tone.config?.temperature === completionParams.temperature + && tone.config?.top_p === completionParams.top_p + && tone.config?.presence_penalty === completionParams.presence_penalty + && tone.config?.frequency_penalty === completionParams.frequency_penalty + }) + return tone ? tone.id : CUSTOM_TONE_ID + } + + // tone is a preset of completionParams. + const [toneId, setToneId] = React.useState(matchToneId(completionParams)) // default is Balanced + // set completionParams by toneId + const handleToneChange = (id: number) => { + if (id === 4) + return // custom tone + const tone = TONE_LIST.find(tone => tone.id === id) + if (tone) { + setToneId(id) + onCompletionParamsChange({ + ...tone.config, + max_tokens: completionParams.max_tokens + } as CompletionParams) + } + } + + useEffect(() => { + setToneId(matchToneId(completionParams)) + }, [completionParams]) + + const handleParamChange = (id: number, value: number) => { + const key = params.find(item => item.id === id)?.key + + if (key) { + onCompletionParamsChange({ + ...completionParams, + [key]: value, + }) + } + } + const ableStyle = 'bg-indigo-25 border-[#2A87F5] cursor-pointer' + const diabledStyle = 'bg-[#FFFCF5] border-[#F79009]' + + return ( +
+
!disabled && toogleShowConfig()} + > + +
{selectedModel.name}
+ {disabled ? : } +
+ {isShowConfig && ( + + + + + + + + + + } + title={t('appDebug.modelConfig.title')} + > +
+
+
{t('appDebug.modelConfig.model')}
+ {/* model selector */} +
+
!selectModelDisabled && toogleOption()} className={cn(selectModelDisabled ? 'cursor-not-allowed' : 'cursor-pointer', "flex items-center h-9 px-3 space-x-2 rounded-lg bg-gray-50 ")}> + +
{selectedModel?.name}
+ {!selectModelDisabled && } +
+ {isShowOption && ( +
+ {availableModels.map(item => ( +
+ +
{item.name}
+
+ ))} +
+ )} +
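Editor's note: the model selector above builds its dropdown from ahooks' useBoolean plus useClickAway on a trigger ref. A self-contained sketch of that pattern; the component and props are illustrative:

```tsx
'use client'
import React, { useRef } from 'react'
import { useBoolean, useClickAway } from 'ahooks'

const Dropdown = ({ options, onPick }: { options: string[]; onPick: (id: string) => void }) => {
  const [isOpen, { setFalse: hide, toggle }] = useBoolean(false)
  const triggerRef = useRef<HTMLDivElement>(null)
  // Any click that lands outside triggerRef closes the panel.
  useClickAway(() => hide(), triggerRef)
  return (
    <div ref={triggerRef}>
      <button onClick={toggle}>Choose model</button>
      {isOpen && options.map(id => (
        <div key={id} onClick={() => { onPick(id); hide() }}>{id}</div>
      ))}
    </div>
  )
}

export default Dropdown
```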
+
+
+ + {/* Response type */} +
+
{t('appDebug.modelConfig.setTone')}
+ + <> + {TONE_LIST.slice(0, 3).map(tone => ( + {t(`common.model.tone.${tone.name}`) as string} + ))} + +
+ {t(`common.model.tone.${TONE_LIST[3].name}`) as string} +
+
+ + {/* Params */} +
+ {params.map(({ key, ...param }) => ())} +
+
+
+ )} +
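Editor's note: matchToneId above treats each tone as a preset over four sampling parameters. The lookup reduces to this sketch; the preset list shape mirrors TONE_LIST, and the Custom id follows the constant in the component:

```ts
type ToneConfig = {
  temperature: number
  top_p: number
  presence_penalty: number
  frequency_penalty: number
}

const CUSTOM_TONE_ID = 4

// A preset matches only when all four sampling parameters are equal;
// anything else falls back to the Custom tone.
function matchTone(params: ToneConfig, presets: { id: number; config: ToneConfig }[]): number {
  const hit = presets.find(p =>
    p.config.temperature === params.temperature
    && p.config.top_p === params.top_p
    && p.config.presence_penalty === params.presence_penalty
    && p.config.frequency_penalty === params.frequency_penalty)
  return hit ? hit.id : CUSTOM_TONE_ID
}
```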
+ + ) +} + +export default React.memo(ConifgModel) diff --git a/web/app/components/app/configuration/config-model/param-item.tsx b/web/app/components/app/configuration/config-model/param-item.tsx new file mode 100644 index 0000000000..3ac1959434 --- /dev/null +++ b/web/app/components/app/configuration/config-model/param-item.tsx @@ -0,0 +1,45 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import Tooltip from '@/app/components/base/tooltip' +import Slider from '@/app/components/base/slider' + +export type IParamItemProps = { + id: number + name: string + tip: string + value: number + step?: number + min?: number + max: number + onChange: (id: number, value: number) => void +} + +const ParamItem: FC<IParamItemProps> = ({ id, name, tip, step = 0.1, min = 0, max, value, onChange }) => { + return (
+
+ {name} + {/* Give each tooltip a distinct tip to avoid the hide bug */} + {tip}
} position='top' selector={`param-name-tooltip-${id}`}> + + + + +
+
+
+ onChange(id, value / (max < 5 ? 10 : 1))} /> +
+ { + const value = parseFloat(e.target.value) + if (Number.isNaN(value) || value < min || value > max) + return + + onChange(id, value) + }} />
+ + ) +} +export default React.memo(ParamItem) diff --git a/web/app/components/app/configuration/config-prompt/confirm-add-var/index.tsx b/web/app/components/app/configuration/config-prompt/confirm-add-var/index.tsx new file mode 100644 index 0000000000..972f097be2 --- /dev/null +++ b/web/app/components/app/configuration/config-prompt/confirm-add-var/index.tsx @@ -0,0 +1,72 @@ +'use client' +import React, { FC, useRef } from 'react' +import { useTranslation } from 'react-i18next' +import Button from '@/app/components/base/button' +import { useClickAway } from 'ahooks' +import VarHighlight from '../../base/var-highlight' + +export interface IConfirmAddVarProps { + varNameArr: string[] + onConfirm: () => void + onCancel: () => void + onHide: () => void +} + +const VarIcon = ( + + + + + +) + +const ConfirmAddVar: FC<IConfirmAddVarProps> = ({ + varNameArr, + onConfirm, + onCancel, + onHide, +}) => { + const { t } = useTranslation() + const mainContentRef = useRef(null) + useClickAway(() => { + onHide() + }, mainContentRef) + return (
+
+
+
{VarIcon}
+
+
{t('appDebug.autoAddVar')}
+
+ {varNameArr.map((name) => ( + + ))} +
+
+
+
+ + +
+
+ +
+ ) +} +export default React.memo(ConfirmAddVar) diff --git a/web/app/components/app/configuration/config-prompt/index.tsx b/web/app/components/app/configuration/config-prompt/index.tsx new file mode 100644 index 0000000000..b18d148c4c --- /dev/null +++ b/web/app/components/app/configuration/config-prompt/index.tsx @@ -0,0 +1,95 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import BlockInput from '@/app/components/base/block-input' +import type { PromptVariable } from '@/models/debug' +import Tooltip from '@/app/components/base/tooltip' +import { AppType } from '@/types/app' +import { getNewVar } from '@/utils/var' +import { useTranslation } from 'react-i18next' +import { useBoolean } from 'ahooks' +import ConfirmAddVar from './confirm-add-var' + +export type IPromptProps = { + mode: AppType + promptTemplate: string + promptVariables: PromptVariable[] + onChange: (promp: string, promptVariables: PromptVariable[]) => void +} + +const Prompt: FC = ({ + mode, + promptTemplate, + promptVariables, + onChange, +}) => { + const { t } = useTranslation() + const promptVariablesObj = (() => { + const obj: Record = {} + promptVariables.forEach((item) => { + obj[item.key] = true + }) + return obj + })() + + const [newPromptVariables, setNewPromptVariables] = React.useState(promptVariables) + const [newTemplates, setNewTemplates] = React.useState('') + const [isShowConfirmAddVar, { setTrue: showConfirmAddVar, setFalse: hideConfirmAddVar }] = useBoolean(false) + + const handleChange = (newTemplates: string, keys: string[]) => { + // const hasRemovedKeysInput = promptVariables.filter(input => keys.includes(input.key)) + const newPromptVariables = keys.filter(key => !(key in promptVariablesObj)).map(key => getNewVar(key)) + if (newPromptVariables.length > 0) { + setNewPromptVariables(newPromptVariables) + setNewTemplates(newTemplates) + showConfirmAddVar() + return + } + onChange(newTemplates, []) + } + + const handleAutoAdd = (isAdd: boolean) => { + return () => { + onChange(newTemplates, isAdd ? newPromptVariables : []) + hideConfirmAddVar() + } + } + + return ( +
+
+ + + +
{mode === AppType.chat ? t('appDebug.chatSubTitle') : t('appDebug.completionSubTitle')}
+ + {t('appDebug.promptTip')} +
} + selector='config-prompt-tooltip'> + + + + +
+ + { + handleChange(value, vars) + }} + /> + + {isShowConfirmAddVar && ( + v.name)} + onConfirm={handleAutoAdd(true)} + onCancel={handleAutoAdd(false)} + onHide={hideConfirmAddVar} + /> + )} + + ) +} + +export default React.memo(Prompt) diff --git a/web/app/components/app/configuration/config-var/config-model/index.tsx b/web/app/components/app/configuration/config-var/config-model/index.tsx new file mode 100644 index 0000000000..8d70148992 --- /dev/null +++ b/web/app/components/app/configuration/config-var/config-model/index.tsx @@ -0,0 +1,115 @@ +'use client' +import React, { FC, useState, useEffect } from 'react' +import { useTranslation } from 'react-i18next' +import Modal from '@/app/components/base/modal' +import ModalFoot from '../modal-foot' +import ConfigSelect, { Options } from '../config-select' +import ConfigString from '../config-string' +import Toast from '@/app/components/base/toast' +import type { PromptVariable } from '@/models/debug' +import SelectTypeItem from '../select-type-item' +import { getNewVar } from '@/utils/var' + +import s from './style.module.css' + +export interface IConfigModalProps { + payload: PromptVariable + type?: string + isShow: boolean + onClose: () => void + onConfirm: (newValue: { type: string, value: any }) => void +} + +const ConfigModal: FC<IConfigModalProps> = ({ + payload, + isShow, + onClose, + onConfirm +}) => { + const { t } = useTranslation() + const { type, name, key, options, max_length } = payload || getNewVar('') + + const [tempType, setTempType] = useState(type) + useEffect(() => { + setTempType(type) + }, [type]) + const handleTypeChange = (type: string) => { + return () => { + setTempType(type) + } + } + + const isStringInput = tempType === 'string' + const title = isStringInput ? t('appDebug.variableConig.maxLength') : t('appDebug.variableConig.options') + + // string type + const [tempMaxLength, setTempMaxLength] = useState(max_length) + useEffect(() => { + setTempMaxLength(max_length) + }, [max_length]) + + // select type + const [tempOptions, setTempOptions] = useState(options || []) + useEffect(() => { + setTempOptions(options || []) + }, [options]) + + const handleConfirm = () => { + if (isStringInput) { + onConfirm({ type: tempType, value: tempMaxLength }) + } else { + if (tempOptions.length === 0) { + Toast.notify({ type: 'error', message: 'At least one option is required' }) + return + } + const obj: Record<string, boolean> = {} + let hasRepeatedItem = false + tempOptions.forEach(o => { + if (obj[o]) { + hasRepeatedItem = true + return + } + obj[o] = true + }) + if (hasRepeatedItem) { + Toast.notify({ type: 'error', message: 'Has repeated items' }) + return + } + onConfirm({ type: tempType, value: tempOptions }) + } + } + + return (
+
{t('appDebug.variableConig.description', { varName: `{{${name || key}}}` })}
+
+
{t('appDebug.variableConig.fieldType')}
+
+ + +
+
+ +
+
{title}
+ {isStringInput ? ( + + ) : ( + + )} +
+ +
+ +
+ ) +} +export default React.memo(ConfigModal) diff --git a/web/app/components/app/configuration/config-var/config-select/index.tsx b/web/app/components/app/configuration/config-var/config-select/index.tsx new file mode 100644 index 0000000000..37d3812e11 --- /dev/null +++ b/web/app/components/app/configuration/config-var/config-select/index.tsx @@ -0,0 +1,64 @@ +'use client' +import React, { FC } from 'react' +import { useTranslation } from 'react-i18next' +import { PlusIcon } from '@heroicons/react/24/outline' +import RemoveIcon from '../../base/icons/remove-icon' + +import s from './style.module.css' + +export type Options = string[] +export interface IConfigSelectProps { + options: Options + onChange: (options: Options) => void +} + + +const ConfigSelect: FC<IConfigSelectProps> = ({ + options, + onChange +}) => { + const { t } = useTranslation() + + return (
+ {options.length > 0 && ( +
+ {options.map((o, index) => ( +
+ { + let value = e.target.value + onChange(options.map((item, i) => { + if (index === i) { + return value + } + return item + })) + }} + className={`${s.input} w-full px-3 text-sm leading-9 text-gray-900 border-0 grow h-9 bg-transparent focus:outline-none cursor-pointer`} + /> + { + onChange(options.filter((_, i) => index !== i)) + }} + /> +
+ ))} +
+ )} + +
{ onChange([...options, '']) }} + className='flex items-center h-9 px-3 gap-2 rounded-lg cursor-pointer text-gray-400 bg-gray-100'> + +
{t('appDebug.variableConig.addOption')}
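Editor's note: ConfigSelect performs all three of its option edits immutably. Pulled out as standalone helpers, the logic is (names are illustrative):

```ts
// Replace one option by index, remove one by index, append an empty draft.
const replaceAt = (options: string[], index: number, value: string) =>
  options.map((item, i) => (i === index ? value : item))

const removeAt = (options: string[], index: number) =>
  options.filter((_, i) => i !== index)

const appendDraft = (options: string[]) => [...options, '']
```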
+
+
+ ) +} + +export default React.memo(ConfigSelect) diff --git a/web/app/components/app/configuration/config-var/config-select/style.module.css b/web/app/components/app/configuration/config-var/config-select/style.module.css new file mode 100644 index 0000000000..5d558d89f2 --- /dev/null +++ b/web/app/components/app/configuration/config-var/config-select/style.module.css @@ -0,0 +1,19 @@ +.inputWrap { + border-radius: 8px; + border: 1px solid #EAECF0; + cursor: pointer; +} + +.deleteBtn { + display: none; +} + +.inputWrap:hover { + border-color: #FEE4E2; + background-color: #FFFBFA; + box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); +} + +.inputWrap:hover .deleteBtn { + display: flex; +} \ No newline at end of file diff --git a/web/app/components/app/configuration/config-var/config-string/index.tsx b/web/app/components/app/configuration/config-var/config-string/index.tsx new file mode 100644 index 0000000000..41ebf65153 --- /dev/null +++ b/web/app/components/app/configuration/config-var/config-string/index.tsx @@ -0,0 +1,38 @@ +'use client' +import React, { FC } from 'react' + +export interface IConfigStringProps { + value: number | undefined + onChange: (value: number | undefined) => void +} + +const MAX_LENGTH = 64 + +const ConfigString: FC<IConfigStringProps> = ({ + value, + onChange, +}) => { + + return (
+ { + let value = parseInt(e.target.value, 10) + if (Number.isNaN(value) || value < 1) { + value = 1 + } else if (value > MAX_LENGTH) { + value = MAX_LENGTH + } + onChange(value) + }} + className="w-full px-3 text-sm leading-9 text-gray-900 border-0 rounded-lg grow h-9 bg-gray-50 focus:outline-none focus:ring-1 focus:ring-inset focus:ring-gray-200" + /> +
+ ) +} + +export default React.memo(ConfigString) diff --git a/web/app/components/app/configuration/config-var/index.tsx b/web/app/components/app/configuration/config-var/index.tsx new file mode 100644 index 0000000000..45411cc8a5 --- /dev/null +++ b/web/app/components/app/configuration/config-var/index.tsx @@ -0,0 +1,228 @@ +'use client' +import type { FC } from 'react' +import React, { useState } from 'react' +import { useTranslation } from 'react-i18next' +import Panel from '../base/feature-panel' +import Tooltip from '@/app/components/base/tooltip' +import type { PromptVariable } from '@/models/debug' +import { Cog8ToothIcon, TrashIcon } from '@heroicons/react/24/outline' +import { useBoolean } from 'ahooks' +import EditModel from './config-model' +import { DEFAULT_VALUE_MAX_LEN, getMaxVarNameLength } from '@/config' +import { getNewVar } from '@/utils/var' +import OperationBtn from '../base/operation-btn' +import Switch from '@/app/components/base/switch' +import IconTypeIcon from './input-type-icon' +import { checkKeys } from '@/utils/var' +import Toast from '@/app/components/base/toast' + +import s from './style.module.css' +import VarIcon from '../base/icons/var-icon' + +export type IConfigVarProps = { + promptVariables: PromptVariable[] + onPromptVariablesChange: (promptVariables: PromptVariable[]) => void +} + +const ConfigVar: FC = ({ promptVariables, onPromptVariablesChange }) => { + const { t } = useTranslation() + const hasVar = promptVariables.length > 0 + const promptVariableObj = (() => { + const obj: Record = {} + promptVariables.forEach((item) => { + obj[item.key] = true + }) + return obj + })() + + const updatePromptVariable = (key: string, updateKey: string, newValue: any) => { + if (!(key in promptVariableObj)) + return + const newPromptVariables = promptVariables.map((item) => { + if (item.key === key) + return { + ...item, + [updateKey]: newValue + } + + return item + }) + + onPromptVariablesChange(newPromptVariables) + } + + const batchUpdatePromptVariable = (key: string, updateKeys: string[], newValues: any[]) => { + if (!(key in promptVariableObj)) + return + const newPromptVariables = promptVariables.map((item) => { + if (item.key === key) { + const newItem: any = { ...item } + updateKeys.forEach((updateKey, i) => { + newItem[updateKey] = newValues[i] + }) + return newItem + } + + return item + }) + + onPromptVariablesChange(newPromptVariables) + } + + + const updatePromptKey = (index: number, newKey: string) => { + const { isValid, errorKey, errorMessageKey } = checkKeys([newKey], true) + if (!isValid) { + Toast.notify({ + type: 'error', + message: t(`appDebug.varKeyError.${errorMessageKey}`, { key: errorKey }) + }) + return + } + const newPromptVariables = promptVariables.map((item, i) => { + if (i === index) + return { + ...item, + key: newKey, + } + + return item + }) + + onPromptVariablesChange(newPromptVariables) + } + + const updatePromptNameIfNameEmpty = (index: number, newKey: string) => { + if (!newKey) return + const newPromptVariables = promptVariables.map((item, i) => { + if (i === index && !item.name) + return { + ...item, + name: newKey, + } + return item + }) + + onPromptVariablesChange(newPromptVariables) + } + + const handleAddVar = () => { + const newVar = getNewVar('') + onPromptVariablesChange([...promptVariables, newVar]) + } + + const handleRemoveVar = (index: number) => { + onPromptVariablesChange(promptVariables.filter((_, i) => i !== index)) + } + + const [currKey, setCurrKey] = useState(null) + const currItem = currKey ? 
promptVariables.find(item => item.key === currKey) : null + const [isShowEditModal, { setTrue: showEditModal, setFalse: hideEditModal }] = useBoolean(false) + const handleConfig = (key: string) => { + setCurrKey(key) + showEditModal() + } + + return ( + + } + title={ +
+
{t('appDebug.variableTitle')}
+ + {t('appDebug.variableTip')} +
} selector='config-var-tooltip'> + + + + + + } + headerRight={} + > + {!hasVar && ( +
{t('appDebug.notSetVar')}
+ )} + {hasVar && ( +
+ + + + + + + + + + + {promptVariables.map(({ key, name, type, required }, index) => ( + + + + + + + ))} + +
{t('appDebug.variableTable.key')}{t('appDebug.variableTable.name')}{t('appDebug.variableTable.optional')}{t('appDebug.variableTable.action')}
+
+ + updatePromptKey(index, e.target.value)} + onBlur={e => updatePromptNameIfNameEmpty(index, e.target.value)} + maxLength={getMaxVarNameLength(name)} + className="h-6 leading-6 block w-full rounded-md border-0 py-1.5 text-gray-900 placeholder:text-gray-400 focus:outline-none focus:ring-1 focus:ring-inset focus:ring-gray-200" + /> +
+
+ updatePromptVariable(key, 'name', e.target.value)} + maxLength={getMaxVarNameLength(name)} + className="h-6 leading-6 block w-full rounded-md border-0 py-1.5 text-gray-900 placeholder:text-gray-400 focus:outline-none focus:ring-1 focus:ring-inset focus:ring-gray-200" + /> + +
+ updatePromptVariable(key, 'required', !value)} /> +
+
+
+
handleConfig(key)}> + +
+
handleRemoveVar(index)} > + +
+
+
+
+ )} + + {isShowEditModal && ( + { + if (type === 'string') { + batchUpdatePromptVariable(currKey as string, ['type', 'max_length'], [type, value || DEFAULT_VALUE_MAX_LEN]) + } else { + batchUpdatePromptVariable(currKey as string, ['type', 'options'], [type, value || []]) + } + hideEditModal() + }} + /> + )} + +
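`updatePromptVariable` and `batchUpdatePromptVariable` in this file rebuild the variable list with `map` instead of mutating entries in place, so the state setter always receives fresh references. A self-contained sketch of that copy-on-write pattern, with the `PromptVariable` shape trimmed to the fields exercised here:

```ts
// map() returns a new array and a new object for the matching entry,
// so React sees changed references and re-renders.
type PromptVariable = { key: string; name: string; required?: boolean }

function updateVariable(
  vars: PromptVariable[],
  key: string,
  patch: Partial<PromptVariable>,
): PromptVariable[] {
  return vars.map(item => (item.key === key ? { ...item, ...patch } : item))
}

const prev: PromptVariable[] = [{ key: 'topic', name: 'Topic' }]
const next = updateVariable(prev, 'topic', { required: true })
console.log(prev[0].required, next[0].required) // undefined true
```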
+ ) +} +export default React.memo(ConfigVar) diff --git a/web/app/components/app/configuration/config-var/input-type-icon.tsx b/web/app/components/app/configuration/config-var/input-type-icon.tsx new file mode 100644 index 0000000000..ff59a48bda --- /dev/null +++ b/web/app/components/app/configuration/config-var/input-type-icon.tsx @@ -0,0 +1,34 @@ +'use client' +import React, { FC, ReactNode } from 'react' +import { ReactElement } from 'react-markdown/lib/react-markdown' + +export interface IInputTypeIconProps { + type: string +} + +const IconMap = (type: string) => { + const icons: Record = { + 'string': ( + + + + ), + 'select': ( + + + + + ), + } + + return icons[type] as any +} + +const InputTypeIcon: FC = ({ + type +}) => { + const Icon = IconMap(type) + return Icon +} + +export default React.memo(InputTypeIcon) diff --git a/web/app/components/app/configuration/config-var/modal-foot.tsx b/web/app/components/app/configuration/config-var/modal-foot.tsx new file mode 100644 index 0000000000..acd8d9b51d --- /dev/null +++ b/web/app/components/app/configuration/config-var/modal-foot.tsx @@ -0,0 +1,23 @@ +'use client' +import React, { FC } from 'react' +import { useTranslation } from 'react-i18next' +import Button from '@/app/components/base/button' + +export interface IModalFootProps { + onConfirm: () => void + onCancel: () => void +} + +const ModalFoot: FC = ({ + onConfirm, + onCancel +}) => { + const { t } = useTranslation() + return ( +
+ + +
+ ) +} +export default React.memo(ModalFoot) diff --git a/web/app/components/app/configuration/config-var/select-type-item/index.tsx b/web/app/components/app/configuration/config-var/select-type-item/index.tsx new file mode 100644 index 0000000000..632a3301aa --- /dev/null +++ b/web/app/components/app/configuration/config-var/select-type-item/index.tsx @@ -0,0 +1,61 @@ +'use client' +import React, { FC } from 'react' +import { useTranslation } from 'react-i18next' +import cn from 'classnames' + +import s from './style.module.css' + +export interface ISelectTypeItemProps { + type: string + selected: boolean + onClick: () => void +} + +const Icon = ({ type, selected }: Partial) => { + switch (type) { + case 'select': + return selected ? ( + + + + + + ) : ( + + + + + + ) + case 'string': + default: + return selected ? ( + + + + + ) : ( + + ) + } +} + +const SelectTypeItem: FC = ({ + type, + selected, + onClick +}) => { + const { t } = useTranslation() + const typeName = t(`appDebug.variableConig.${type}`) + + return ( +
+ + {typeName} +
+ ) +} +export default React.memo(SelectTypeItem) diff --git a/web/app/components/app/configuration/config-var/select-type-item/style.module.css b/web/app/components/app/configuration/config-var/select-type-item/style.module.css new file mode 100644 index 0000000000..9f3dc278d2 --- /dev/null +++ b/web/app/components/app/configuration/config-var/select-type-item/style.module.css @@ -0,0 +1,38 @@ +.item { + display: flex; + align-items: center; + height: 32px; + width: 133px; + padding-left: 12px; + border-radius: 8px; + border: 1px solid #EAECF0; + box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); + background-color: #fff; + cursor: pointer; +} + +.item:not(.selected):hover { + border-color: #B2CCFF; + background-color: #F5F8FF; + box-shadow: 0px 4px 8px -2px rgba(16, 24, 40, 0.1), 0px 2px 4px -2px rgba(16, 24, 40, 0.06); +} + +.item.selected { + border-color: #528BFF; + background-color: #F5F8FF; + box-shadow: 0px 1px 3px rgba(16, 24, 40, 0.1), 0px 1px 2px rgba(16, 24, 40, 0.06); +} + +.text { + font-size: 13px; + color: #667085; + font-weight: 500; +} + +.item.selected.text { + color: #155EEF; +} + +.item:not(.selected):hover { + color: #344054; +} \ No newline at end of file diff --git a/web/app/components/app/configuration/config-var/style.module.css b/web/app/components/app/configuration/config-var/style.module.css new file mode 100644 index 0000000000..733755d0c8 --- /dev/null +++ b/web/app/components/app/configuration/config-var/style.module.css @@ -0,0 +1,12 @@ +.table td { + padding-left: 12px; +} + +.table thead td { + height: 33px; + line-height: 33px; +} + +.table tbody tr:last-child td { + border-bottom: none; +} \ No newline at end of file diff --git a/web/app/components/app/configuration/config/feature/add-feature-btn/index.tsx b/web/app/components/app/configuration/config/feature/add-feature-btn/index.tsx new file mode 100644 index 0000000000..90c9fd9153 --- /dev/null +++ b/web/app/components/app/configuration/config/feature/add-feature-btn/index.tsx @@ -0,0 +1,31 @@ +'use client' +import React, { FC } from 'react' +import { useTranslation } from 'react-i18next' +import { PlusIcon } from '@heroicons/react/24/solid' + +export interface IAddFeatureBtnProps { + onClick: () => void +} + +const AddFeatureBtn: FC = ({ + onClick +}) => { + const { t } = useTranslation() + return ( +
+ +
{t('appDebug.operation.addFeature')}
+
+ ) +} +export default React.memo(AddFeatureBtn) diff --git a/web/app/components/app/configuration/config/feature/choose-feature/feature-item/index.tsx b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/index.tsx new file mode 100644 index 0000000000..cfb3159e1a --- /dev/null +++ b/web/app/components/app/configuration/config/feature/choose-feature/feature-item/index.tsx @@ -0,0 +1,42 @@ +'use client' +import React, { FC } from 'react' +import Switch from '@/app/components/base/switch' + +export interface IFeatureItemProps { + icon: React.ReactNode + title: string + description: string + value: boolean + onChange: (value: boolean) => void +} + +const FeatureItem: FC = ({ + icon, + title, + description, + value, + onChange +}) => { + return ( +
+
+ {/* icon */} +
+ {icon} +
+
+
{title}
+
{description}
+
+
+ + +
+ ) +} +export default React.memo(FeatureItem) diff --git a/web/app/components/app/configuration/config/feature/choose-feature/index.tsx b/web/app/components/app/configuration/config/feature/choose-feature/index.tsx new file mode 100644 index 0000000000..174104663d --- /dev/null +++ b/web/app/components/app/configuration/config/feature/choose-feature/index.tsx @@ -0,0 +1,92 @@ +'use client' +import React, { FC } from 'react' +import { useTranslation } from 'react-i18next' +import Modal from '@/app/components/base/modal' +import FeatureItem from './feature-item' +import FeatureGroup from '../feature-group' +import MoreLikeThisIcon from '../../../base/icons/more-like-this-icon' +import SuggestedQuestionsAfterAnswerIcon from '@/app/components/app/configuration/base/icons/suggested-questions-after-answer-icon' + +interface IConfig { + openingStatement: boolean + moreLikeThis: boolean + suggestedQuestionsAfterAnswer: boolean +} + +export interface IChooseFeatureProps { + isShow: boolean + onClose: () => void + config: IConfig + isChatApp: boolean + onChange: (key: string, value: boolean) => void +} + +const OpeningStatementIcon = ( + + + +) + +const ChooseFeature: FC = ({ + isShow, + onClose, + isChatApp, + config, + onChange +}) => { + const { t } = useTranslation() + + return ( + +
+ {/* Chat Feature */} + {isChatApp && ( + + <> + onChange('openingStatement', value)} + /> + } + title={t('appDebug.feature.suggestedQuestionsAfterAnswer.title')} + description={t('appDebug.feature.suggestedQuestionsAfterAnswer.description')} + value={config.suggestedQuestionsAfterAnswer} + onChange={(value) => onChange('suggestedQuestionsAfterAnswer', value)} + /> + + + )} + + {/* Text Generation Feature */} + {!isChatApp && ( + + <> + } + title={t('appDebug.feature.moreLikeThis.title')} + description={t('appDebug.feature.moreLikeThis.description')} + value={config.moreLikeThis} + onChange={(value) => onChange('moreLikeThis', value)} + /> + + + )} +
+ +
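The modal reports each toggle to its parent as a `(key, value)` pair. A small sketch of how such updates can stay type-safe with a key union mirroring the `IConfig` shape above (names here are illustrative):

```ts
// Keys are constrained to the known feature flags, so a typo fails to compile.
type FeatureKey = 'openingStatement' | 'moreLikeThis' | 'suggestedQuestionsAfterAnswer'
type FeatureConfig = Record<FeatureKey, boolean>

function applyFeatureChange(config: FeatureConfig, key: FeatureKey, value: boolean): FeatureConfig {
  return { ...config, [key]: value } // shallow copy keeps state updates immutable
}

const before: FeatureConfig = {
  openingStatement: false,
  moreLikeThis: false,
  suggestedQuestionsAfterAnswer: true,
}
const after = applyFeatureChange(before, 'moreLikeThis', true)
console.log(after.moreLikeThis, before.moreLikeThis) // true false
```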
+ ) +} +export default React.memo(ChooseFeature) diff --git a/web/app/components/app/configuration/config/feature/feature-group/index.tsx b/web/app/components/app/configuration/config/feature/feature-group/index.tsx new file mode 100644 index 0000000000..6bbe851134 --- /dev/null +++ b/web/app/components/app/configuration/config/feature/feature-group/index.tsx @@ -0,0 +1,30 @@ +'use client' +import React, { FC } from 'react' +import GroupName from '@/app/components/app/configuration/base/group-name' + +export interface IFeatureGroupProps { + title: string + description?: string + children: React.ReactNode +} + +const FeatureGroup: FC = ({ + title, + description, + children +}) => { + return ( +
+
+ + {description && ( +
{description}
+ )} +
+
+ {children} +
+
+ ) +} +export default React.memo(FeatureGroup) diff --git a/web/app/components/app/configuration/config/feature/use-feature.tsx b/web/app/components/app/configuration/config/feature/use-feature.tsx new file mode 100644 index 0000000000..c88f7c8fe5 --- /dev/null +++ b/web/app/components/app/configuration/config/feature/use-feature.tsx @@ -0,0 +1,58 @@ +import React, { useEffect } from 'react' + +function useFeature({ + introduction, + setIntroduction, + moreLikeThis, + setMoreLikeThis, + suggestedQuestionsAfterAnswer, + setSuggestedQuestionsAfterAnswer, +}: { + introduction: string + setIntroduction: (introduction: string) => void + moreLikeThis: boolean + setMoreLikeThis: (moreLikeThis: boolean) => void + suggestedQuestionsAfterAnswer: boolean + setSuggestedQuestionsAfterAnswer: (suggestedQuestionsAfterAnswer: boolean) => void +}) { + const [tempshowOpeningStatement, setTempShowOpeningStatement] = React.useState(!!introduction) + useEffect(() => { + // wait to api data back + if (!!introduction) { + setTempShowOpeningStatement(true) + } + }, [introduction]) + + // const [tempMoreLikeThis, setTempMoreLikeThis] = React.useState(moreLikeThis) + // useEffect(() => { + // setTempMoreLikeThis(moreLikeThis) + // }, [moreLikeThis]) + + const featureConfig = { + openingStatement: tempshowOpeningStatement, + moreLikeThis: moreLikeThis, + suggestedQuestionsAfterAnswer: suggestedQuestionsAfterAnswer + } + const handleFeatureChange = (key: string, value: boolean) => { + switch (key) { + case 'openingStatement': + if (!value) { + setIntroduction('') + } + setTempShowOpeningStatement(value) + break + case 'moreLikeThis': + setMoreLikeThis(value) + break + case 'suggestedQuestionsAfterAnswer': + setSuggestedQuestionsAfterAnswer(value) + break + } + } + return { + featureConfig, + handleFeatureChange + } +} + +export default useFeature \ No newline at end of file diff --git a/web/app/components/app/configuration/config/index.tsx b/web/app/components/app/configuration/config/index.tsx new file mode 100644 index 0000000000..4f9e3bcd4c --- /dev/null +++ b/web/app/components/app/configuration/config/index.tsx @@ -0,0 +1,153 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useContext } from 'use-context-selector' +import produce from 'immer' +import AddFeatureBtn from './feature/add-feature-btn' +import ChooseFeature from './feature/choose-feature' +import useFeature from './feature/use-feature' +import ConfigContext from '@/context/debug-configuration' +import DatasetConfig from '../dataset-config' +import ChatGroup from '../features/chat-group' +import ExperienceEnchanceGroup from '../features/experience-enchance-group' +import Toolbox from '../toolbox' +import ConfigPrompt from '@/app/components/app/configuration/config-prompt' +import ConfigVar from '@/app/components/app/configuration/config-var' +import type { PromptVariable } from '@/models/debug' +import { AppType } from '@/types/app' +import { useBoolean } from 'ahooks' + +const Config: FC = () => { + const { + mode, + introduction, + setIntroduction, + modelConfig, + setModelConfig, + setPrevPromptConfig, + setFormattingChanged, + moreLikeThisConifg, + setMoreLikeThisConifg, + suggestedQuestionsAfterAnswerConfig, + setSuggestedQuestionsAfterAnswerConfig + } = useContext(ConfigContext) + const isChatApp = mode === AppType.chat + + const promptTemplate = modelConfig.configs.prompt_template + const promptVariables = modelConfig.configs.prompt_variables + const handlePromptChange = (newTemplate: string, newVariables: 
PromptVariable[]) => { + const newModelConfig = produce(modelConfig, (draft) => { + draft.configs.prompt_template = newTemplate + draft.configs.prompt_variables = [...draft.configs.prompt_variables, ...newVariables] + }) + + if (modelConfig.configs.prompt_template !== newTemplate) { + setFormattingChanged(true) + } + + setPrevPromptConfig(modelConfig.configs) + setModelConfig(newModelConfig) + } + + const handlePromptVariablesNameChange = (newVariables: PromptVariable[]) => { + setPrevPromptConfig(modelConfig.configs) + const newModelConfig = produce(modelConfig, (draft) => { + draft.configs.prompt_variables = newVariables + }) + setModelConfig(newModelConfig) + } + + const [showChooseFeature, { + setTrue: showChooseFeatureTrue, + setFalse: showChooseFeatureFalse + }] = useBoolean(false) + const { featureConfig, handleFeatureChange } = useFeature({ + introduction, + setIntroduction, + moreLikeThis: moreLikeThisConifg.enabled, + setMoreLikeThis: (value) => { + setMoreLikeThisConifg(produce(moreLikeThisConifg, (draft) => { + draft.enabled = value + })) + }, + suggestedQuestionsAfterAnswer: suggestedQuestionsAfterAnswerConfig.enabled, + setSuggestedQuestionsAfterAnswer: (value) => { + setSuggestedQuestionsAfterAnswerConfig(produce(suggestedQuestionsAfterAnswerConfig, (draft) => { + draft.enabled = value + })) + }, + }) + + const hasChatConfig = isChatApp && (featureConfig.openingStatement || featureConfig.suggestedQuestionsAfterAnswer) + const hasToolbox = false + + return ( + <> +
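`handlePromptChange` above derives the next model config with immer's `produce`: mutations are recorded on a draft and materialized as a new object, leaving the previous state untouched. A minimal sketch of that mechanic with a trimmed-down config type:

```ts
import produce from 'immer'

// Only the fields touched by handlePromptChange are modeled here.
type ModelConfig = {
  configs: { prompt_template: string; prompt_variables: { key: string; name: string }[] }
}

const current: ModelConfig = {
  configs: { prompt_template: 'Hello {{name}}', prompt_variables: [] },
}

// produce() applies the draft mutations to a structural copy; `current` is unchanged.
const next = produce(current, (draft) => {
  draft.configs.prompt_template = 'Hi {{name}}'
  draft.configs.prompt_variables.push({ key: 'name', name: 'Name' })
})

console.log(current.configs.prompt_template) // 'Hello {{name}}'
console.log(next.configs.prompt_template) // 'Hi {{name}}'
```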
+
+ +
+ {/* Automatic */} +
+
+ + {showChooseFeature && ( + + )} + {/* Template */} + + + {/* Variables */} + + + {/* Dataset */} + + + {/* ChatConifig */} + { + hasChatConfig && ( + + ) + } + + {/* TextnGeneration config */} + {moreLikeThisConifg.enabled && ( + + )} + + + {/* Toolbox */} + { + hasToolbox && ( + + ) + } +
+ + ) +} +export default React.memo(Config) diff --git a/web/app/components/app/configuration/ctrl-btn-group/index.tsx b/web/app/components/app/configuration/ctrl-btn-group/index.tsx new file mode 100644 index 0000000000..d7b5d49b0e --- /dev/null +++ b/web/app/components/app/configuration/ctrl-btn-group/index.tsx @@ -0,0 +1,24 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import s from './style.module.css' +import Button from '@/app/components/base/button' + +export type IContrlBtnGroupProps = { + onSave: () => void + onReset: () => void +} + +const ContrlBtnGroup: FC = ({ onSave, onReset }) => { + const { t } = useTranslation() + return ( +
+
+ + +
+
+ ) +} +export default React.memo(ContrlBtnGroup) diff --git a/web/app/components/app/configuration/ctrl-btn-group/style.module.css b/web/app/components/app/configuration/ctrl-btn-group/style.module.css new file mode 100644 index 0000000000..c7250b8f96 --- /dev/null +++ b/web/app/components/app/configuration/ctrl-btn-group/style.module.css @@ -0,0 +1,6 @@ +.ctrlBtn { + left: -16px; + right: -16px; + bottom: -16px; + border-top: 1px solid #F3F4F6; +} \ No newline at end of file diff --git a/web/app/components/app/configuration/dataset-config/card-item/index.tsx b/web/app/components/app/configuration/dataset-config/card-item/index.tsx new file mode 100644 index 0000000000..88d3716ce6 --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/card-item/index.tsx @@ -0,0 +1,51 @@ +'use client' +import React, { FC } from 'react' +import cn from 'classnames' +import TypeIcon from '../type-icon' +import { useTranslation } from 'react-i18next' +import { formatNumber } from '@/utils/format' +import RemoveIcon from '../../base/icons/remove-icon' +import s from './style.module.css' + +export interface ICardItemProps { + className?: string + config: any + onRemove: (id: string) => void +} + + + +// const RemoveIcon = ({ className, onClick }: { className: string, onClick: () => void }) => ( +// +// +// +// ) + +const CardItem: FC = ({ + className, + config, + onRemove +}) => { + const { t } = useTranslation() + + return ( +
+
+ +
+
{config.name}
+
+ {formatNumber(config.word_count)} {t('appDebug.feature.dataSet.words')} · {formatNumber(config.document_count)} {t('appDebug.feature.dataSet.textBlocks')} +
+
+
+ + onRemove(config.id)} /> +
+ ) +} +export default React.memo(CardItem) diff --git a/web/app/components/app/configuration/dataset-config/card-item/style.module.css b/web/app/components/app/configuration/dataset-config/card-item/style.module.css new file mode 100644 index 0000000000..def8bf8d66 --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/card-item/style.module.css @@ -0,0 +1,16 @@ +.card { + box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); + width: calc(50% - 4px); +} + +.card:hover { + box-shadow: 0px 4px 8px -2px rgba(16, 24, 40, 0.1), 0px 2px 4px -2px rgba(16, 24, 40, 0.06); +} + +.deleteBtn { + display: none; +} + +.card:hover .deleteBtn { + display: block; +} \ No newline at end of file diff --git a/web/app/components/app/configuration/dataset-config/index.tsx b/web/app/components/app/configuration/dataset-config/index.tsx new file mode 100644 index 0000000000..46207ca396 --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/index.tsx @@ -0,0 +1,79 @@ +'use client' +import React, { FC } from 'react' +import { useTranslation } from 'react-i18next' +import { useContext } from 'use-context-selector' +import ConfigContext from '@/context/debug-configuration' +import FeaturePanel from '../base/feature-panel' +import OperationBtn from '../base/operation-btn' +import CardItem from './card-item' +import { useBoolean } from 'ahooks' +import SelectDataSet from './select-dataset' +import { DataSet } from '@/models/datasets' +import { isEqual } from 'lodash-es' + +const Icon = ( + + + + +) + +const DatasetConfig: FC = () => { + const { t } = useTranslation() + const { + dataSets: dataSet, + setDataSets: setDataSet, + setFormattingChanged + } = useContext(ConfigContext) + const selectedIds = dataSet.map((item) => item.id) + + const hasData = dataSet.length > 0 + const [isShowSelectDataSet, { setTrue: showSelectDataSet, setFalse: hideSelectDataSet }] = useBoolean(false) + const handleSelect = (data: DataSet[]) => { + if (isEqual(data, dataSet)) { + hideSelectDataSet() + } + setFormattingChanged(true) + setDataSet(data) + hideSelectDataSet() + } + const onRemove = (id: string) => { + setDataSet(dataSet.filter((item) => item.id !== id)) + } + + + return ( + } + hasHeaderBottomBorder={!hasData} + > + {hasData ? ( +
+ {dataSet.map((item) => ( + + ))} +
+ ) : ( +
{t('appDebug.feature.dataSet.noData')}
+ )} + + {isShowSelectDataSet && ( + + )} +
+ ) +} +export default React.memo(DatasetConfig) diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx new file mode 100644 index 0000000000..78c219bfc5 --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx @@ -0,0 +1,130 @@ +'use client' +import React, { FC, useEffect } from 'react' +import cn from 'classnames' +import { useTranslation } from 'react-i18next' +import Modal from '@/app/components/base/modal' +import { DataSet } from '@/models/datasets' +import TypeIcon from '../type-icon' +import Button from '@/app/components/base/button' +import { fetchDatasets } from '@/service/datasets' +import Loading from '@/app/components/base/loading' +import { formatNumber } from '@/utils/format' +import Link from 'next/link' + +import s from './style.module.css' +import Toast from '@/app/components/base/toast' + +export interface ISelectDataSetProps { + isShow: boolean + onClose: () => void + selectedIds: string[] + onSelect: (dataSet: DataSet[]) => void +} + +const SelectDataSet: FC = ({ + isShow, + onClose, + selectedIds, + onSelect, +}) => { + const { t } = useTranslation() + const [selected, setSelected] = React.useState([]) + const [loaded, setLoaded] = React.useState(false) + const [datasets, setDataSets] = React.useState(null) + const hasNoData = !datasets || datasets?.length === 0 + // Only one dataset can be selected. Historical data retains data and supports multiple selections, but when saving, only one can be selected. This is based on considerations of performance and accuracy. + const canSelectMulti = selectedIds.length > 1 + useEffect(() => { + (async () => { + const { data } = await fetchDatasets({ url: '/datasets', params: { page: 1 } }) + setDataSets(data) + setLoaded(true) + setSelected(data.filter((item) => selectedIds.includes(item.id))) + })() + }, []) + const toggleSelect = (dataSet: DataSet) => { + const isSelected = selected.some((item) => item.id === dataSet.id) + if (isSelected) { + setSelected(selected.filter((item) => item.id !== dataSet.id)) + } + else { + if (canSelectMulti) { + setSelected([...selected, dataSet]) + } else { + setSelected([dataSet]) + } + } + } + + const handleSelect = () => { + if (selected.length > 1) { + Toast.notify({ + type: 'error', + message: t('appDebug.feature.dataSet.notSupportSelectMulti') + }) + return + } + onSelect(selected) + } + return ( + + {!loaded && ( +
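`toggleSelect` and `handleSelect` above implement single-select with a legacy multi-select escape hatch, and compare against the current selection with `isEqual` before flagging a formatting change. A compact sketch of that flow; note the explicit early return on an unchanged selection, which the handler above appears to fall through without:

```ts
import { isEqual } from 'lodash-es'

type DataSet = { id: string; name: string }

// Multi-select only remains for legacy data; otherwise a new pick replaces the choice.
function toggleSelect(selected: DataSet[], item: DataSet, canSelectMulti: boolean): DataSet[] {
  if (selected.some(s => s.id === item.id))
    return selected.filter(s => s.id !== item.id) // clicking again deselects
  return canSelectMulti ? [...selected, item] : [item]
}

// On confirm, skip the formatting-changed side effect when nothing actually changed.
function confirmSelect(prev: DataSet[], next: DataSet[], markChanged: () => void): DataSet[] {
  if (isEqual(prev, next))
    return prev
  markChanged()
  return next
}
```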
+ +
+ )} + + {(loaded && hasNoData) && ( +
+ {t('appDebug.feature.dataSet.noDataSet')} + {t('appDebug.feature.dataSet.toCreate')} +
+ )} + + {datasets && datasets?.length > 0 && ( + <> +
+ {datasets.map((item) => ( +
i.id === item.id) && s.selected, 'flex justify-between items-center h-10 px-2 rounded-lg bg-white border border-gray-200 cursor-pointer')} + onClick={() => toggleSelect(item)} + > +
+ +
{item.name}
+
+ +
+ {formatNumber(item.word_count)} {t('appDebug.feature.dataSet.words')} · {formatNumber(item.document_count)} {t('appDebug.feature.dataSet.textBlocks')} +
+
+ ))} +
+ + )} + {loaded && ( +
+
+ {selected.length > 0 && `${selected.length} ${t('appDebug.feature.dataSet.selected')}`} +
+
+ + +
+
+ )} +
+ ) +} +export default React.memo(SelectDataSet) diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/style.module.css b/web/app/components/app/configuration/dataset-config/select-dataset/style.module.css new file mode 100644 index 0000000000..9c73b88298 --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/select-dataset/style.module.css @@ -0,0 +1,9 @@ +.item { + box-shadow: 0px 1px 2px rgba(16, 24, 40, 0.05); +} + +.item:hover, +.item.selected { + background: #F5F8FF; + border-color: #528BFF; +} \ No newline at end of file diff --git a/web/app/components/app/configuration/dataset-config/type-icon/index.tsx b/web/app/components/app/configuration/dataset-config/type-icon/index.tsx new file mode 100644 index 0000000000..9bebea6127 --- /dev/null +++ b/web/app/components/app/configuration/dataset-config/type-icon/index.tsx @@ -0,0 +1,33 @@ +'use client' +import React, { FC } from 'react' + +export interface ITypeIconProps { + type: 'upload_file' + size?: 'md' | 'lg' +} + +// data_source_type: current only support upload_file +const Icon = ({ type, size = "lg" }: ITypeIconProps) => { + const len = size === "lg" ? 32 : 24 + const iconMap = { + upload_file: ( + + + + + + ) + } + return iconMap[type] + +} + +const TypeIcon: FC = ({ + type, + size = 'lg', +}) => { + return ( + + ) +} +export default React.memo(TypeIcon) diff --git a/web/app/components/app/configuration/debug/index.tsx b/web/app/components/app/configuration/debug/index.tsx new file mode 100644 index 0000000000..c7ac273886 --- /dev/null +++ b/web/app/components/app/configuration/debug/index.tsx @@ -0,0 +1,417 @@ +'use client' +import type { FC } from 'react' +import { useTranslation } from 'react-i18next' +import React, { useEffect, useState, useRef } from 'react' +import cn from 'classnames' +import produce from 'immer' +import { useBoolean, useGetState } from 'ahooks' +import { useContext } from 'use-context-selector' +import { AppType } from '@/types/app' +import PromptValuePanel, { replaceStringWithValues } from '@/app/components/app/configuration/prompt-value-panel' +import type { IChatItem } from '@/app/components/app/chat' +import Chat from '@/app/components/app/chat' +import ConfigContext from '@/context/debug-configuration' +import { ToastContext } from '@/app/components/base/toast' +import { sendChatMessage, sendCompletionMessage, fetchSuggestedQuestions, fetchConvesationMessages } from '@/service/debug' +import Button from '@/app/components/base/button' +import type { ModelConfig as BackendModelConfig } from '@/types/app' +import { promptVariablesToUserInputsForm } from '@/utils/model-config' +import HasNotSetAPIKEY from '../base/warning-mask/has-not-set-api' +import FormattingChanged from '../base/warning-mask/formatting-changed' +import TextGeneration from '@/app/components/app/text-generate/item' +import GroupName from '../base/group-name' +import dayjs from 'dayjs' +import { IS_CE_EDITION } from '@/config' + +interface IDebug { + hasSetAPIKEY: boolean + onSetting: () => void +} + +const Debug: FC = ({ + hasSetAPIKEY = true, + onSetting +}) => { + const { t } = useTranslation() + const { + appId, + mode, + introduction, + suggestedQuestionsAfterAnswerConfig, + moreLikeThisConifg, + inputs, + // setInputs, + formattingChanged, + setFormattingChanged, + conversationId, + setConversationId, + controlClearChatMessage, + dataSets, + modelConfig, + completionParams, + } = useContext(ConfigContext) + + + const [chatList, setChatList, getChatList] = useGetState([]) + const 
chatListDomRef = useRef(null) + useEffect(() => { + // scroll to bottom + if (chatListDomRef.current) { + chatListDomRef.current.scrollTop = chatListDomRef.current.scrollHeight + } + }, [chatList]) + + const getIntroduction = () => replaceStringWithValues(introduction, modelConfig.configs.prompt_variables, inputs) + useEffect(() => { + if (introduction && !chatList.some(item => !item.isAnswer)) { + setChatList([{ + id: `${Date.now()}`, + content: getIntroduction(), + isAnswer: true, + isOpeningStatement: true + }]) + } + }, [introduction, modelConfig.configs.prompt_variables, inputs]) + + const [isResponsing, { setTrue: setResponsingTrue, setFalse: setResponsingFalse }] = useBoolean(false) + const [abortController, setAbortController] = useState(null) + const [isShowFormattingChangeConfirm, setIsShowFormattingChangeConfirm] = useState(false) + + useEffect(() => { + if (formattingChanged && chatList.some(item => !item.isAnswer)) { + setIsShowFormattingChangeConfirm(true) + } + setFormattingChanged(false) + }, [formattingChanged]) + + const clearConversation = () => { + setConversationId(null) + abortController?.abort() + setResponsingFalse() + setChatList(introduction ? [{ + id: `${Date.now()}`, + content: getIntroduction(), + isAnswer: true, + isOpeningStatement: true + }] : []) + setIsShowSuggestion(false) + } + + const handleConfirm = () => { + clearConversation() + setIsShowFormattingChangeConfirm(false) + } + + const handleCancel = () => { + setIsShowFormattingChangeConfirm(false) + } + + const { notify } = useContext(ToastContext) + const logError = (message: string) => { + notify({ type: 'error', message }) + } + + const checkCanSend = () => { + let hasEmptyInput = false + const requiredVars = modelConfig.configs.prompt_variables.filter(({ key, name, required }) => { + const res = (!key || !key.trim()) || (!name || !name.trim()) || (required || required === undefined || required === null) + return res + }) // compatible with old version + // debugger + requiredVars.forEach(({ key }) => { + if (hasEmptyInput) { + return + } + if (!inputs[key]) { + hasEmptyInput = true + } + }) + + if (hasEmptyInput) { + logError(t('appDebug.errorMessage.valueOfVarRequired')) + return false + } + return !hasEmptyInput + } + + const [isShowSuggestion, setIsShowSuggestion] = useState(false) + const doShowSuggestion = isShowSuggestion && !isResponsing + const [suggestQuestions, setSuggestQuestions] = useState([]) + const onSend = async (message: string) => { + if (isResponsing) { + notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') }) + return false + } + + const postDatasets = dataSets.map(({ id }) => ({ + dataset: { + enabled: true, + id, + } + })) + + const postModelConfig: BackendModelConfig = { + pre_prompt: modelConfig.configs.prompt_template, + user_input_form: promptVariablesToUserInputsForm(modelConfig.configs.prompt_variables), + opening_statement: introduction, + more_like_this: { + enabled: false + }, + suggested_questions_after_answer: suggestedQuestionsAfterAnswerConfig, + agent_mode: { + enabled: true, + tools: [...postDatasets] + }, + model: { + provider: modelConfig.provider, + name: modelConfig.model_id, + completion_params: completionParams as any + }, + } + + const data = { + conversation_id: conversationId, + inputs, + query: message, + model_config: postModelConfig, + } + + // qustion + const questionId = `question-${Date.now()}` + const questionItem = { + id: questionId, + content: message, + isAnswer: false, + } + + const placeholderAnswerId = 
`answer-placeholder-${Date.now()}` + const placeholderAnswerItem = { + id: placeholderAnswerId, + content: '', + isAnswer: true, + } + + const newList = [...getChatList(), questionItem, placeholderAnswerItem] + setChatList(newList) + + // answer + const responseItem = { + id: `${Date.now()}`, + content: '', + isAnswer: true, + } + + let _newConversationId: null | string = null + + setResponsingTrue() + setIsShowSuggestion(false) + sendChatMessage(appId, data, { + getAbortController: (abortController) => { + setAbortController(abortController) + }, + onData: (message: string, isFirstMessage: boolean, { conversationId: newConversationId, messageId }: any) => { + responseItem.content = responseItem.content + message + if (isFirstMessage && newConversationId) { + setConversationId(newConversationId) + _newConversationId = newConversationId + } + if (messageId) { + responseItem.id = messageId + } + // closesure new list is outdated. + const newListWithAnswer = produce( + getChatList().filter(item => item.id !== responseItem.id && item.id !== placeholderAnswerId), + (draft) => { + if (!draft.find(item => item.id === questionId)) { + draft.push({ ...questionItem }) + } + draft.push({ ...responseItem }) + }) + setChatList(newListWithAnswer) + }, + async onCompleted(hasError?: boolean) { + setResponsingFalse() + if (hasError) { + return + } + if (_newConversationId) { + const { data }: any = await fetchConvesationMessages(appId, _newConversationId as string) + const newResponseItem = data.find((item: any) => item.id === responseItem.id) + if (!newResponseItem) { + return + } + setChatList(produce(getChatList(), draft => { + const index = draft.findIndex(item => item.id === responseItem.id) + if (index !== -1) { + draft[index] = { + ...draft[index], + more: { + time: dayjs.unix(newResponseItem.created_at).format('hh:mm A'), + tokens: newResponseItem.answer_tokens, + latency: (newResponseItem.provider_response_latency / 1000).toFixed(2), + } + } + } + })) + } + if (suggestedQuestionsAfterAnswerConfig.enabled) { + const { data }: any = await fetchSuggestedQuestions(appId, responseItem.id) + setSuggestQuestions(data) + setIsShowSuggestion(true) + } + }, + onError() { + setResponsingFalse() + // role back placeholder answer + setChatList(produce(getChatList(), draft => { + draft.splice(draft.findIndex(item => item.id === placeholderAnswerId), 1) + })) + } + }) + return true + } + + useEffect(() => { + if (controlClearChatMessage) + setChatList([]) + }, [controlClearChatMessage]) + + const [completionQuery, setCompletionQuery] = useState('') + const [completionRes, setCompletionRes] = useState(``) + + const sendTextCompletion = async () => { + if (isResponsing) { + notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') }) + return false + } + + if (!checkCanSend()) + return + + if (!completionQuery) { + logError(t('appDebug.errorMessage.queryRequired')) + return false + } + + const postDatasets = dataSets.map(({ id }) => ({ + dataset: { + enabled: true, + id, + } + })) + + const postModelConfig: BackendModelConfig = { + pre_prompt: modelConfig.configs.prompt_template, + user_input_form: promptVariablesToUserInputsForm(modelConfig.configs.prompt_variables), + opening_statement: introduction, + suggested_questions_after_answer: suggestedQuestionsAfterAnswerConfig, + more_like_this: moreLikeThisConifg, + agent_mode: { + enabled: true, + tools: [...postDatasets] + }, + model: { + provider: modelConfig.provider, + name: modelConfig.model_id, + completion_params: completionParams as any + 
}, + } + + + const data = { + inputs, + query: completionQuery, + model_config: postModelConfig, + } + + setCompletionRes('') + const res: string[] = [] + + setResponsingTrue() + sendCompletionMessage(appId, data, { + onData: (data: string) => { + res.push(data) + setCompletionRes(res.join('')) + }, + onCompleted() { + setResponsingFalse() + }, + onError() { + setResponsingFalse() + } + }) + } + + + return ( + <> +
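`sendTextCompletion` above streams its result: each `onData` chunk is pushed into an array and the joined string is written back to state. A small sketch of that accumulator, with stand-ins for the service callback and the React setter:

```ts
// Returns an onData-style callback that re-renders with everything received so far.
function makeAccumulator(setState: (text: string) => void) {
  const chunks: string[] = []
  return (chunk: string) => {
    chunks.push(chunk)
    setState(chunks.join('')) // join once per chunk; fine for debug-sized outputs
  }
}

let rendered = ''
const onData = makeAccumulator((text) => { rendered = text })
onData('Hello, ')
onData('world')
console.log(rendered) // 'Hello, world'
```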
+
+
{t('appDebug.inputs.title')}
+ {mode === 'chat' && ( + + )} +
+ +
+
+ {/* Chat */} + {mode === AppType.chat && ( +
+
+
+ {/* {JSON.stringify(chatList)} */} + { + abortController?.abort() + setResponsingFalse() + }} + isShowSuggestion={doShowSuggestion} + suggestionList={suggestQuestions} + /> +
+
+
+ )} + {/* Text Generation */} + {mode === AppType.completion && ( +
+ + {(completionRes || isResponsing) && ( + + )} +
+ )} + {isShowFormattingChangeConfirm && ( + + )} +
+ + {!hasSetAPIKEY && ()} + + ) +} +export default React.memo(Debug) diff --git a/web/app/components/app/configuration/features/chat-group/index.tsx b/web/app/components/app/configuration/features/chat-group/index.tsx new file mode 100644 index 0000000000..774137edb1 --- /dev/null +++ b/web/app/components/app/configuration/features/chat-group/index.tsx @@ -0,0 +1,40 @@ +'use client' +import React, { FC } from 'react' +import GroupName from '../../base/group-name' +import OpeningStatement, { IOpeningStatementProps } from './opening-statement' +import SuggestedQuestionsAfterAnswer from './suggested-questions-after-answer' +import { useTranslation } from 'react-i18next' + +/* +* Include +* 1. Conversation Opener +* 2. Opening Suggestion +* 3. Next question suggestion +*/ +interface ChatGroupProps { + isShowOpeningStatement: boolean + openingStatementConfig: IOpeningStatementProps + isShowSuggestedQuestionsAfterAnswer: boolean +} +const ChatGroup: FC = ({ + isShowOpeningStatement, + openingStatementConfig, + isShowSuggestedQuestionsAfterAnswer +}) => { + const { t } = useTranslation() + + return ( +
+ +
+ {isShowOpeningStatement && ( + + )} + {isShowSuggestedQuestionsAfterAnswer && ( + + )} +
+
+ ) +} +export default React.memo(ChatGroup) diff --git a/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx b/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx new file mode 100644 index 0000000000..3feacd22ab --- /dev/null +++ b/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx @@ -0,0 +1,177 @@ +'use client' +import React, { FC, useEffect, useRef, useState } from 'react' +import cn from 'classnames' +import { useContext } from 'use-context-selector' +import ConfigContext from '@/context/debug-configuration' +import produce from 'immer' +import { useTranslation } from 'react-i18next' +import { useBoolean } from 'ahooks' +import Panel from '@/app/components/app/configuration/base/feature-panel' +import Button from '@/app/components/base/button' +import OperationBtn from '@/app/components/app/configuration/base/operation-btn' +import { getInputKeys } from '@/app/components/base/block-input' +import ConfirmAddVar from '@/app/components/app/configuration/config-prompt/confirm-add-var' +import { getNewVar } from '@/utils/var' +import { varHighlightHTML } from '@/app/components/app/configuration/base/var-highlight' + +export interface IOpeningStatementProps { + promptTemplate: string + value: string + onChange: (value: string) => void +} + +// regex to match the {{}} and replace it with a span +const regex = /\{\{([^}]+)\}\}/g + +const OpeningStatement: FC = ({ + value = '', + onChange +}) => { + const { t } = useTranslation() + const { + modelConfig, + setModelConfig, + } = useContext(ConfigContext) + const promptVariables = modelConfig.configs.prompt_variables + const [notIncludeKeys, setNotIncludeKeys] = useState([]) + + const hasValue = !!(value || '').trim() + const inputRef = useRef(null) + + const [isFocus, { setTrue: didSetFocus, setFalse: setBlur }] = useBoolean(false) + const setFocus = () => { + didSetFocus() + setTimeout(() => { + const input = inputRef.current + if (input) { + input.focus() + input.setSelectionRange(input.value.length, input.value.length) + } + }, 0) + } + + const [tempValue, setTempValue] = useState(value) + useEffect(() => { + setTempValue(value || '') + }, [value]) + + const coloredContent = (tempValue || '') + .replace(regex, varHighlightHTML({ name: '$1' })) // `{{$1}}` + .replace(/\n/g, '
<br />
') + + + const handleEdit = () => { + setFocus() + } + + const [isShowConfirmAddVar, { setTrue: showConfirmAddVar, setFalse: hideConfirmAddVar }] = useBoolean(false) + + const handleCancel = () => { + setBlur() + setTempValue(value) + } + + const handleConfirm = () => { + const keys = getInputKeys(tempValue) + const promptKeys = promptVariables.map((item) => item.key) + let notIncludeKeys: string[] = [] + + if (promptKeys.length === 0) { + if (keys.length > 0) { + notIncludeKeys = keys + } + } else { + notIncludeKeys = keys.filter((key) => !promptKeys.includes(key)) + } + + if (notIncludeKeys.length > 0) { + setNotIncludeKeys(notIncludeKeys) + showConfirmAddVar() + return + } + setBlur() + onChange(tempValue) + } + + const cancelAutoAddVar = () => { + onChange(tempValue) + hideConfirmAddVar() + setBlur() + } + + const autoAddVar = () => { + const newModelConfig = produce(modelConfig, (draft) => { + draft.configs.prompt_variables = [...draft.configs.prompt_variables, ...notIncludeKeys.map((key) => getNewVar(key))] + }) + onChange(tempValue) + setModelConfig(newModelConfig) + hideConfirmAddVar() + setBlur() + } + + const headerRight = ( + + ) + + return ( + + + + } + headerRight={headerRight} + hasHeaderBottomBorder={!hasValue} + isFocus={isFocus} + > +
+ {(hasValue || (!hasValue && isFocus)) ? ( + <> + {isFocus ? ( + + ) : ( +
+ )} + + {/* Operation Bar */} + {isFocus && ( +
+
{t('appDebug.openingStatement.varTip')}
+ +
+ + +
+
+ )} + + ) : ( +
{t('appDebug.openingStatement.noDataPlaceHolder')}
+ )} + + {isShowConfirmAddVar && ( + + )} + +
+
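`handleConfirm` above scans the drafted opening statement for `{{key}}` placeholders and offers to auto-add any that are not yet declared as prompt variables. A sketch of that diff; `getInputKeys` is assumed to be a regex scan like the one below:

```ts
// Pull every {{key}} out of the text (same pattern as the `regex` constant above).
function getInputKeys(text: string): string[] {
  return Array.from(text.matchAll(/\{\{([^}]+)\}\}/g), m => m[1])
}

// Keys present in the text but missing from the declared prompt variables.
function findUndeclaredKeys(text: string, declaredKeys: string[]): string[] {
  return getInputKeys(text).filter(key => !declaredKeys.includes(key))
}

console.log(findUndeclaredKeys('Hi {{name}}, today we cover {{topic}}.', ['name']))
// => ['topic']; these are the keys the confirm dialog offers to auto-add
```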
+ ) +} +export default React.memo(OpeningStatement) diff --git a/web/app/components/app/configuration/features/chat-group/suggested-questions-after-answer/index.tsx b/web/app/components/app/configuration/features/chat-group/suggested-questions-after-answer/index.tsx new file mode 100644 index 0000000000..0c5b8cf4db --- /dev/null +++ b/web/app/components/app/configuration/features/chat-group/suggested-questions-after-answer/index.tsx @@ -0,0 +1,33 @@ +'use client' +import React, { FC } from 'react' +import { useTranslation } from 'react-i18next' +import Panel from '@/app/components/app/configuration/base/feature-panel' +import SuggestedQuestionsAfterAnswerIcon from '@/app/components/app/configuration/base/icons/suggested-questions-after-answer-icon' +import Tooltip from '@/app/components/base/tooltip' + +const SuggestedQuestionsAfterAnswer: FC = () => { + const { t } = useTranslation() + + return ( + +
{t('appDebug.feature.suggestedQuestionsAfterAnswer.title')}
+ + {t('appDebug.feature.suggestedQuestionsAfterAnswer.description')} + } selector='suggestion-question-tooltip'> + + + + + + } + headerIcon={} + headerRight={ +
{t('appDebug.feature.suggestedQuestionsAfterAnswer.resDes')}
+ } + noBodySpacing + /> + ) +} +export default React.memo(SuggestedQuestionsAfterAnswer) diff --git a/web/app/components/app/configuration/features/experience-enchance-group/index.tsx b/web/app/components/app/configuration/features/experience-enchance-group/index.tsx new file mode 100644 index 0000000000..0c898cbee4 --- /dev/null +++ b/web/app/components/app/configuration/features/experience-enchance-group/index.tsx @@ -0,0 +1,21 @@ +'use client' +import React, { FC } from 'react' +import { useTranslation } from 'react-i18next' +import GroupName from '../../base/group-name' +import MoreLikeThis from './more-like-this' + +/* +* Include +* 1. More like this +*/ +const ExperienceEnchanceGroup: FC = () => { + const { t } = useTranslation() + + return ( +
+ + +
+ ) +} +export default React.memo(ExperienceEnchanceGroup) diff --git a/web/app/components/app/configuration/features/experience-enchance-group/more-like-this/index.tsx b/web/app/components/app/configuration/features/experience-enchance-group/more-like-this/index.tsx new file mode 100644 index 0000000000..91d77fee14 --- /dev/null +++ b/web/app/components/app/configuration/features/experience-enchance-group/more-like-this/index.tsx @@ -0,0 +1,50 @@ +'use client' +import React, { FC } from 'react' +import Panel from '@/app/components/app/configuration/base/feature-panel' +import { useTranslation } from 'react-i18next' +import MoreLikeThisIcon from '../../../base/icons/more-like-this-icon' +import { XMarkIcon } from '@heroicons/react/24/outline' +import { useLocalStorageState } from 'ahooks' + +const GENERATE_NUM = 1 + +const warningIcon = ( + + + + +) +const MoreLikeThis: FC = () => { + const { t } = useTranslation() + + const [isHideTip, setIsHideTip] = useLocalStorageState('isHideMoreLikeThisTip', { + defaultValue: false, + }) + + const headerRight = ( +
{t('appDebug.feature.moreLikeThis.generateNumTip')} {GENERATE_NUM}
+ ) + return ( + } + headerRight={headerRight} + noBodySpacing + > + {!isHideTip && ( +
+
+
{warningIcon}
+
{t('appDebug.feature.moreLikeThis.tip')}
+
+
setIsHideTip(true)}> + +
+
+ )} + +
+ ) +} +export default React.memo(MoreLikeThis) diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx new file mode 100644 index 0000000000..03bea3f316 --- /dev/null +++ b/web/app/components/app/configuration/index.tsx @@ -0,0 +1,317 @@ +'use client' +import type { FC } from 'react' +import React, { useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useContext } from 'use-context-selector' +import { usePathname } from 'next/navigation' +import produce from 'immer' +import type { CompletionParams, Inputs, ModelConfig, PromptConfig, PromptVariable, MoreLikeThisConfig } from '@/models/debug' +import type { DataSet } from '@/models/datasets' +import type { ModelConfig as BackendModelConfig } from '@/types/app' +import ConfigContext from '@/context/debug-configuration' +import ConfigModel from '@/app/components/app/configuration/config-model' +import Config from '@/app/components/app/configuration/config' +import Debug from '@/app/components/app/configuration/debug' +import Confirm from '@/app/components/base/confirm' +import type { AppDetailResponse } from '@/models/app' +import { ToastContext } from '@/app/components/base/toast' +import { fetchTenantInfo } from '@/service/common' +import { fetchAppDetail, updateAppModelConfig } from '@/service/apps' +import { userInputsFormToPromptVariables, promptVariablesToUserInputsForm } from '@/utils/model-config' +import { fetchDatasets } from '@/service/datasets' +import AccountSetting from '@/app/components/header/account-setting' +import { useBoolean } from 'ahooks' +import Button from '../../base/button' +import Loading from '../../base/loading' + +const Configuration: FC = () => { + const { t } = useTranslation() + const { notify } = useContext(ToastContext) + + const [hasFetchedDetail, setHasFetchedDetail] = useState(false) + const [hasFetchedKey, setHasFetchedKey] = useState(false) + const isLoading = !hasFetchedDetail || !hasFetchedKey + const pathname = usePathname() + const matched = pathname.match(/\/app\/([^/]+)/) + const appId = (matched?.length && matched[1]) ? 
matched[1] : '' + const [mode, setMode] = useState('') + const [pusblisedConfig, setPusblisedConfig] = useState<{ + modelConfig: ModelConfig, + completionParams: CompletionParams + } | null>(null) + + const [conversationId, setConversationId] = useState('') + const [introduction, setIntroduction] = useState('') + const [controlClearChatMessage, setControlClearChatMessage] = useState(0) + const [prevPromptConfig, setPrevPromptConfig] = useState({ + prompt_template: '', + prompt_variables: [], + }) + const [moreLikeThisConifg, setMoreLikeThisConifg] = useState({ + enabled: false, + }) + const [suggestedQuestionsAfterAnswerConfig, setSuggestedQuestionsAfterAnswerConfig] = useState({ + enabled: false, + }) + const [formattingChanged, setFormattingChanged] = useState(false) + const [inputs, setInputs] = useState({}) + const [query, setQuery] = useState('') + const [completionParams, setCompletionParams] = useState({ + max_tokens: 16, + temperature: 1, // 0-2 + top_p: 1, + presence_penalty: 1, // -2-2 + frequency_penalty: 1, // -2-2 + }) + const [modelConfig, doSetModelConfig] = useState({ + provider: 'openai', + model_id: 'gpt-3.5-turbo', + configs: { + prompt_template: '', + prompt_variables: [] as PromptVariable[], + }, + }) + + const setModelConfig = (newModelConfig: ModelConfig) => { + doSetModelConfig(newModelConfig) + } + + const setModelId = (modelId: string) => { + const newModelConfig = produce(modelConfig, (draft) => { + draft.model_id = modelId + }) + setModelConfig(newModelConfig) + } + + const syncToPublishedConfig = (_pusblisedConfig: any) => { + setModelConfig(_pusblisedConfig.modelConfig) + setCompletionParams(_pusblisedConfig.completionParams) + } + + const [dataSets, setDataSets] = useState([]) + + const [hasSetCustomAPIKEY, setHasSetCustomerAPIKEY] = useState(true) + const [isTrailFinished, setIsTrailFinished] = useState(false) + const hasSetAPIKEY = hasSetCustomAPIKEY || !isTrailFinished + + const [isShowSetAPIKey, { setTrue: showSetAPIKey, setFalse: hideSetAPIkey }] = useBoolean() + + const checkAPIKey = async () => { + const { in_trail, trial_end_reason } = await fetchTenantInfo({ url: '/info' }) + const isTrailFinished = in_trail && trial_end_reason === 'trial_exceeded' + const hasSetCustomAPIKEY = trial_end_reason === 'using_custom' + setHasSetCustomerAPIKEY(hasSetCustomAPIKEY) + setIsTrailFinished(isTrailFinished) + setHasFetchedKey(true) + } + + useEffect(() => { + checkAPIKey() + }, []) + + useEffect(() => { + (fetchAppDetail({ url: '/apps', id: appId }) as any).then(async (res: AppDetailResponse) => { + setMode(res.mode) + const modelConfig = res.model_config + const model = res.model_config.model + + let datasets: any = null + if (modelConfig.agent_mode?.enabled) { + datasets = modelConfig.agent_mode?.tools.filter(({ dataset }: any) => dataset?.enabled) + } + + if (dataSets && datasets?.length && datasets?.length > 0) { + const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasets.map(({ dataset }: any) => dataset.id) } }) + datasets = dataSetsWithDetail + setDataSets(datasets) + } + const config = { + modelConfig: { + provider: model.provider, + model_id: model.name, + configs: { + prompt_template: modelConfig.pre_prompt, + prompt_variables: userInputsFormToPromptVariables(modelConfig.user_input_form) + }, + }, + completionParams: model.completion_params, + } + syncToPublishedConfig(config) + setPusblisedConfig(config) + setIntroduction(modelConfig.opening_statement) + if (modelConfig.more_like_this) { + 
setMoreLikeThisConifg(modelConfig.more_like_this) + } + if (modelConfig.suggested_questions_after_answer) { + setSuggestedQuestionsAfterAnswerConfig(modelConfig.suggested_questions_after_answer) + } + setHasFetchedDetail(true) + }) + }, [appId]) + + const saveAppConfig = async () => { + const modelId = modelConfig.model_id + const promptTemplate = modelConfig.configs.prompt_template + const promptVariables = modelConfig.configs.prompt_variables + + // not save empty key adn name + // const missingNameItem = promptVariables.find(item => item.name.trim() === '') + // if (missingNameItem) { + // notify({ type: 'error', message: t('appDebug.errorMessage.nameOfKeyRequired', { key: missingNameItem.key }) }) + // return + // } + + const postDatasets = dataSets.map(({ id }) => ({ + dataset: { + enabled: true, + id, + } + })) + + // new model config data struct + const data: BackendModelConfig = { + pre_prompt: promptTemplate, + user_input_form: promptVariablesToUserInputsForm(promptVariables), + opening_statement: introduction || '', + more_like_this: moreLikeThisConifg, + suggested_questions_after_answer: suggestedQuestionsAfterAnswerConfig, + agent_mode: { + enabled: true, + tools: [...postDatasets] + }, + model: { + provider: modelConfig.provider, + name: modelId, + completion_params: completionParams as any, + }, + } + + await updateAppModelConfig({ url: `/apps/${appId}/model-config`, body: data }) + setPusblisedConfig({ + modelConfig, + completionParams, + }) + notify({ type: 'success', message: t('common.api.success'), duration: 3000 }) + } + + const [showConfirm, setShowConfirm] = useState(false) + const resetAppConfig = () => { + // debugger + syncToPublishedConfig(pusblisedConfig) + setShowConfirm(false) + } + + const [showUseGPT4Confirm, setShowUseGPT4Confirm] = useState(false) + const [showSetAPIKeyModal, setShowSetAPIKeyModal] = useState(false) + + if (isLoading) { + return
+ +
+ } + + return ( + + <> +
+
+
{t('appDebug.pageTitle')}
+
+ {/* Model and Parameters */} + { + setCompletionParams(newParams) + }} + disabled={!hasSetAPIKEY} + canUseGPT4={hasSetCustomAPIKEY} + onShowUseGPT4Confirm={() => { + setShowUseGPT4Confirm(true) + }} + /> +
+ + +
+
+
+
+ +
+
+ +
+
+
+ {showConfirm && ( + setShowConfirm(false)} + onConfirm={resetAppConfig} + onCancel={() => setShowConfirm(false)} + /> + )} + {showUseGPT4Confirm && ( + setShowUseGPT4Confirm(false)} + onConfirm={() => { + setShowSetAPIKeyModal(true) + setShowUseGPT4Confirm(false) + }} + onCancel={() => setShowUseGPT4Confirm(false)} + /> + )} + { + showSetAPIKeyModal && ( + { + setShowSetAPIKeyModal(false) + }} /> + ) + } + {isShowSetAPIKey && { + await checkAPIKey() + hideSetAPIkey() + }} />} + +
+ ) +} +export default React.memo(Configuration) diff --git a/web/app/components/app/configuration/prompt-value-panel/index.tsx b/web/app/components/app/configuration/prompt-value-panel/index.tsx new file mode 100644 index 0000000000..b86094556c --- /dev/null +++ b/web/app/components/app/configuration/prompt-value-panel/index.tsx @@ -0,0 +1,210 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import { useContext } from 'use-context-selector' +import { + PlayIcon, +} from '@heroicons/react/24/solid' +import ConfigContext from '@/context/debug-configuration' +import type { PromptVariable } from '@/models/debug' +import { AppType } from '@/types/app' +import Select from '@/app/components/base/select' +import { DEFAULT_VALUE_MAX_LEN } from '@/config' +import VarIcon from '../base/icons/var-icon' +import Button from '@/app/components/base/button' + +export type IPromptValuePanelProps = { + appType: AppType + value?: string + onChange?: (value: string) => void + onSend?: () => void +} + +const starIcon = ( + + + + + +) + +const PromptValuePanel: FC = ({ + appType, + value, + onChange, + onSend, +}) => { + const { t } = useTranslation() + const { modelConfig, inputs, setInputs } = useContext(ConfigContext) + const promptTemplate = modelConfig.configs.prompt_template + const promptVariables = modelConfig.configs.prompt_variables.filter(({ key, name }) => { + return key && key?.trim() && name && name?.trim() + }) + + const promptVariableObj = (() => { + const obj: Record = {} + promptVariables.forEach((input) => { + obj[input.key] = true + }) + return obj + })() + + const handleInputValueChange = (key: string, value: string) => { + if (!(key in promptVariableObj)) + return + + const newInputs = { ...inputs } + promptVariables.forEach((input) => { + if (input.key === key) + newInputs[key] = value + }) + setInputs(newInputs) + } + + const promptPreview = ( +
+
+
+ {starIcon} +
{t('appDebug.inputs.previewTitle')}
+
+
+ { + (promptTemplate && promptTemplate?.trim()) ? ( +
+
+ ) : ( +
{t('appDebug.inputs.noPrompt')}
+ ) + } +
+
+
+ ) + + return ( +
+ {promptPreview} + +
+
+
+
+
{t('appDebug.inputs.userInputField')}
+
+ {appType === AppType.completion && promptVariables.length > 0 && ( +
{t('appDebug.inputs.completionVarTip')}
+ )} +
+ { + promptVariables.length > 0 ? ( +
+ {promptVariables.map(({ key, name, type, options, max_length, required }) => ( +
+
{name || key}
+ {type === 'select' ? ( + { handleInputValueChange(key, e.target.value) }} + maxLength={max_length || DEFAULT_VALUE_MAX_LEN} + /> + )} + +
+ ))} +
+ ) : ( +
{t('appDebug.inputs.noVar')}
+ ) + } +
+ + { + appType === AppType.completion && ( +
+
+
+
+
{t('appDebug.inputs.queryTitle')}
+
+
+ +
+
+
+ {value?.length} +
+ +
+
+
+
+
+ ) + } +
+ ) +} + +export default React.memo(PromptValuePanel) + +function replaceStringWithValuesWithFormat(str: string, promptVariables: PromptVariable[], inputs: Record) { + return str.replace(/\{\{([^}]+)\}\}/g, (match, key) => { + const name = inputs[key] + if (name) { // has set value + return `
${name}
` + } + + const valueObj: PromptVariable | undefined = promptVariables.find(v => v.key === key) + return `
${valueObj ? valueObj.name : match}
` + }) +} + +export function replaceStringWithValues(str: string, promptVariables: PromptVariable[], inputs: Record) { + return str.replace(/\{\{([^}]+)\}\}/g, (match, key) => { + const name = inputs[key] + if (name) { // has set value + return name + } + + const valueObj: PromptVariable | undefined = promptVariables.find(v => v.key === key) + return valueObj ? `{{${valueObj.name}}}` : match + }) +} + +// \n -> br +function format(str: string) { + return str.replaceAll('\n', '
<br />
') +} diff --git a/web/app/components/app/configuration/toolbox/index.tsx b/web/app/components/app/configuration/toolbox/index.tsx new file mode 100644 index 0000000000..0304b3f32e --- /dev/null +++ b/web/app/components/app/configuration/toolbox/index.tsx @@ -0,0 +1,26 @@ +'use client' +import React, { FC } from 'react' +import GroupName from '../base/group-name' + +export interface IToolboxProps { + searchToolConfig: any + sensitiveWordAvoidanceConifg: any +} + +/* +* Include +* 1. Search Tool +* 2. Sensitive word avoidance +*/ +const Toolbox: FC = ({ searchToolConfig, sensitiveWordAvoidanceConifg }) => { + return ( +
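A usage sketch for the exported `replaceStringWithValues` helper above, tracing its three branches: filled inputs are substituted directly, declared-but-empty variables render by display name, and unknown keys pass through unchanged:

```ts
type PromptVariable = { key: string; name: string }

const template = 'Write about {{topic}} for {{audience}} in {{lang}}'
const promptVariables: PromptVariable[] = [
  { key: 'topic', name: 'Topic' },
  { key: 'audience', name: 'Audience' },
]
const inputs: Record<string, string> = { topic: 'TypeScript' }

// replaceStringWithValues(template, promptVariables, inputs)
// => 'Write about TypeScript for {{Audience}} in {{lang}}'
//    filled input; declared var shown by name; unknown key left as-is
```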
+ +
+ {searchToolConfig?.enabled &&
Search Tool
} + {sensitiveWordAvoidanceConifg?.enabled &&
Sensitive word avoidance
} +
+
+ ) +} +export default React.memo(Toolbox) diff --git a/web/app/components/app/log/filter.tsx b/web/app/components/app/log/filter.tsx new file mode 100644 index 0000000000..45c70b5c4d --- /dev/null +++ b/web/app/components/app/log/filter.tsx @@ -0,0 +1,83 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import { + MagnifyingGlassIcon, +} from '@heroicons/react/24/solid' +import useSWR from 'swr' +import dayjs from 'dayjs' +import quarterOfYear from 'dayjs/plugin/quarterOfYear' +import type { QueryParam } from './index' +import { SimpleSelect } from '@/app/components/base/select' +import { fetchAnnotationsCount } from '@/service/log' +dayjs.extend(quarterOfYear) + +const today = dayjs() + +export const TIME_PERIOD_LIST = [ + { value: 0, name: 'today' }, + { value: 7, name: 'last7days' }, + { value: 28, name: 'last4weeks' }, + { value: today.diff(today.subtract(3, 'month'), 'day'), name: 'last3months' }, + { value: today.diff(today.subtract(12, 'month'), 'day'), name: 'last12months' }, + { value: today.diff(today.startOf('month'), 'day'), name: 'monthToDate' }, + { value: today.diff(today.startOf('quarter'), 'day'), name: 'quarterToDate' }, + { value: today.diff(today.startOf('year'), 'day'), name: 'yearToDate' }, + { value: 'all', name: 'allTime' }, +] + +type IFilterProps = { + appId: string + queryParams: QueryParam + setQueryParams: (v: QueryParam) => void +} + +const Filter: FC = ({ appId, queryParams, setQueryParams }: IFilterProps) => { + const { data } = useSWR({ url: `/apps/${appId}/annotations/count` }, fetchAnnotationsCount) + const { t } = useTranslation() + if (!data) + return null + return ( +
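`TIME_PERIOD_LIST` above precomputes each named range as a day count relative to module load time. A sketch of the dayjs arithmetic involved (the `quarterOfYear` plugin is what makes `startOf('quarter')` available):

```ts
import dayjs from 'dayjs'
import quarterOfYear from 'dayjs/plugin/quarterOfYear'

dayjs.extend(quarterOfYear)

const today = dayjs()

// Calendar-aware "days back from today", computed once at module load.
const last3months = today.diff(today.subtract(3, 'month'), 'day') // ~90, varies by month length
const quarterToDate = today.diff(today.startOf('quarter'), 'day') // days since the quarter began

console.log({ last3months, quarterToDate })
```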
+ ({ value: item.value, name: t(`appLog.filter.period.${item.name}`) }))} + className='mt-0 !w-40' + onSelect={(item) => { + setQueryParams({ ...queryParams, period: item.value }) + }} + defaultValue={queryParams.period} /> +
+ { + setQueryParams({ ...queryParams, annotation_status: item.value as string }) + } + } + items={[{ value: 'all', name: t('appLog.filter.annotation.all') }, + { value: 'annotated', name: t('appLog.filter.annotation.annotated', { count: data?.count }) }, + { value: 'not_annotated', name: t('appLog.filter.annotation.not_annotated') }]} + /> +
+
+
+
+ { + setQueryParams({ ...queryParams, keyword: e.target.value }) + }} + /> +
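{/* Keyword changes flow into queryParams and from there into the request query, so search filtering happens server-side. */}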
+
+ ) +} + +export default Filter diff --git a/web/app/components/app/log/index.tsx b/web/app/components/app/log/index.tsx new file mode 100644 index 0000000000..eb4dbdd636 --- /dev/null +++ b/web/app/components/app/log/index.tsx @@ -0,0 +1,146 @@ +'use client' +import type { FC } from 'react' +import React, { useState } from 'react' +import useSWR from 'swr' +import { usePathname } from 'next/navigation' +import { Pagination } from 'react-headless-pagination' +import { omit } from 'lodash-es' +import dayjs from 'dayjs' +import { ArrowLeftIcon, ArrowRightIcon } from '@heroicons/react/24/outline' +import { Trans, useTranslation } from 'react-i18next' +import Link from 'next/link' +import List from './list' +import Filter from './filter' +import s from './style.module.css' +import Loading from '@/app/components/base/loading' +import { fetchChatConversations, fetchCompletionConversations } from '@/service/log' +import { fetchAppDetail } from '@/service/apps' + +export type ILogsProps = { + appId: string +} + +export type QueryParam = { + period?: number | string + annotation_status?: string + keyword?: string +} + +// Custom page count is not currently supported. +const limit = 10 + +const ThreeDotsIcon: FC<{ className?: string }> = ({ className }) => { + return + + +} + +const EmptyElement: FC<{ appUrl: string }> = ({ appUrl }) => { + const { t } = useTranslation() + const pathname = usePathname() + const pathSegments = pathname.split('/') + pathSegments.pop() + return
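{/* Empty state: pathSegments is the current route minus its last segment, presumably so the shareLink/testLink targets can point at sibling routes of the current app. */}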
+
+ {t('appLog.table.empty.element.title')} +
+ , testLink: }} + /> +
+
+
+} + +const Logs: FC = ({ appId }) => { + const { t } = useTranslation() + const [queryParams, setQueryParams] = useState({ period: 7, annotation_status: 'all' }) + const [currPage, setCurrPage] = React.useState(0) + + const query = { + page: currPage + 1, + limit, + ...(queryParams.period !== 'all' + ? { + start: dayjs().subtract(queryParams.period as number, 'day').format('YYYY-MM-DD HH:mm'), + end: dayjs().format('YYYY-MM-DD HH:mm'), + } + : {}), + ...omit(queryParams, ['period']), + } + + // Get the app type first + const { data: appDetail } = useSWR({ url: '/apps', id: appId }, fetchAppDetail) + const isChatMode = appDetail?.mode === 'chat' + + // When the details are obtained, proceed to the next request + const { data: chatConversations, mutate: mutateChatList } = useSWR(() => isChatMode + ? { + url: `/apps/${appId}/chat-conversations`, + params: query, + } + : null, fetchChatConversations) + + const { data: completionConversations, mutate: mutateCompletionList } = useSWR(() => !isChatMode + ? { + url: `/apps/${appId}/completion-conversations`, + params: query, + } + : null, fetchCompletionConversations) + + const total = isChatMode ? chatConversations?.total : completionConversations?.total + + return ( +
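{/* The two useSWR hooks above are mutually exclusive: each returns a null key unless appDetail.mode matches, so only one conversations endpoint is ever fetched. */}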
+
+

{t('appLog.title')}

+

{t('appLog.description')}

+
+
+ + {total === undefined + ? + : total > 0 + ? + : + } + {/* Show Pagination only if the total is more than the limit */} + {(total && total > limit) + ? + + + {t('appLog.table.pagination.previous')} + +
+ +
+ + {t('appLog.table.pagination.next')} + + +
+ : null} +
+
+ ) +} + +export default Logs diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx new file mode 100644 index 0000000000..cfde271d53 --- /dev/null +++ b/web/app/components/app/log/list.tsx @@ -0,0 +1,477 @@ +'use client' +import type { FC } from 'react' +import React, { useEffect, useState } from 'react' +// import type { Log } from '@/models/log' +import useSWR from 'swr' +import { + HandThumbDownIcon, + HandThumbUpIcon, + InformationCircleIcon, + XMarkIcon, +} from '@heroicons/react/24/outline' +import { SparklesIcon } from '@heroicons/react/24/solid' +import { get } from 'lodash-es' +import InfiniteScroll from 'react-infinite-scroll-component' +import dayjs from 'dayjs' +import { createContext, useContext } from 'use-context-selector' +import classNames from 'classnames' +import { useTranslation } from 'react-i18next' +import { EditIconSolid } from '../chat' +import { randomString } from '../../app-sidebar/basic' +import s from './style.module.css' +import type { FeedbackFunc, Feedbacktype, IChatItem, SubmitAnnotationFunc } from '@/app/components/app/chat' +import type { Annotation, ChatConversationFullDetailResponse, ChatConversationGeneralDetail, ChatConversationsResponse, ChatMessage, ChatMessagesRequest, CompletionConversationFullDetailResponse, CompletionConversationGeneralDetail, CompletionConversationsResponse } from '@/models/log' +import type { App } from '@/types/app' +import Loading from '@/app/components/base/loading' +import Drawer from '@/app/components/base/drawer' +import Popover from '@/app/components/base/popover' +import Chat from '@/app/components/app/chat' +import Tooltip from '@/app/components/base/tooltip' +import { ToastContext } from '@/app/components/base/toast' +import { fetchChatConversationDetail, fetchChatMessages, fetchCompletionConversationDetail, updateLogMessageAnnotations, updateLogMessageFeedbacks } from '@/service/log' +import { TONE_LIST } from '@/config' + +type IConversationList = { + logs?: ChatConversationsResponse | CompletionConversationsResponse + appDetail?: App + onRefresh: () => void +} + +const defaultValue = 'N/A' +const emptyText = '[Empty]' + +type IDrawerContext = { + onClose: () => void + appDetail?: App +} + +const DrawerContext = createContext({} as IDrawerContext) + +export const OpenAIIcon: FC<{ className?: string }> = ({ className }) => { + return + + + +} + +/** + * Icon component with numbers + */ +const HandThumbIconWithCount: FC<{ count: number; iconType: 'up' | 'down' }> = ({ count, iconType }) => { + const classname = iconType === 'up' ? 'text-primary-600 bg-primary-50' : 'text-red-600 bg-red-50' + const Icon = iconType === 'up' ? HandThumbUpIcon : HandThumbDownIcon + return
+ + {count > 0 ? count : null} +
+} + +const PARAM_MAP = { + temperature: 'Temperature', + top_p: 'Top P', + presence_penalty: 'Presence Penalty', + max_tokens: 'Max Token', + stop: 'Stop', + frequency_penalty: 'Frequency Penalty', +} + +// Format interface data for easy display +const getFormattedChatList = (messages: ChatMessage[]) => { + const newChatList: IChatItem[] = [] + messages.forEach((item: ChatMessage) => { + newChatList.push({ + id: `question-${item.id}`, + content: item.query, + isAnswer: false, + }) + + newChatList.push({ + id: item.id, + content: item.answer, + feedback: item.feedbacks.find(item => item.from_source === 'user'), // user feedback + adminFeedback: item.feedbacks.find(item => item.from_source === 'admin'), // admin feedback + feedbackDisabled: false, + isAnswer: true, + more: { + time: dayjs.unix(item.created_at).format('hh:mm A'), + tokens: item.answer_tokens, + latency: (item.provider_response_latency / 1000).toFixed(2), + }, + annotation: item.annotation, + }) + }) + return newChatList +} + +// const displayedParams = CompletionParams.slice(0, -2) +const validatedParams = ['temperature', 'top_p', 'presence_penalty', 'frequency_penalty'] + +type IDetailPanel = { + detail: T + onFeedback: FeedbackFunc + onSubmitAnnotation: SubmitAnnotationFunc +} + +function DetailPanel({ detail, onFeedback, onSubmitAnnotation }: IDetailPanel) { + const { onClose, appDetail } = useContext(DrawerContext) + const { t } = useTranslation() + const [items, setItems] = React.useState([]) + const [hasMore, setHasMore] = useState(true) + + const fetchData = async () => { + try { + if (!hasMore) + return + const params: ChatMessagesRequest = { + conversation_id: detail.id, + limit: 4, + } + if (items?.[0]?.id) + params.first_id = items?.[0]?.id.replace('question-', '') + + const messageRes = await fetchChatMessages({ + url: `/apps/${appDetail?.id}/chat-messages`, + params, + }) + const newItems = [...getFormattedChatList(messageRes.data), ...items] + if (messageRes.has_more === false && detail?.model_config?.configs?.introduction) { + newItems.unshift({ + id: 'introduction', + isAnswer: true, + isOpeningStatement: true, + content: detail?.model_config?.configs?.introduction ?? 'hello', + feedbackDisabled: true, + }) + } + setItems(newItems) + setHasMore(messageRes.has_more) + } + catch (err) { + console.error(err) + } + } + + useEffect(() => { + if (appDetail?.id && detail.id && appDetail?.mode === 'chat') + fetchData() + }, [appDetail?.id, detail.id]) + + const isChatMode = appDetail?.mode === 'chat' + + const targetTone = TONE_LIST.find((item) => { + let res = true + validatedParams.forEach((param) => { + res = item.config?.[param] === detail.model_config?.configs?.completion_params?.[param] + }) + return res + })?.name ?? 'custom' + + return (
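{/* Note: the targetTone loop above overwrites `res` on every iteration, so in effect only the last entry of validatedParams (frequency_penalty) decides the preset match. */}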
+ {/* Panel Header */} +
+
+ {isChatMode ? t('appLog.detail.conversationId') : t('appLog.detail.time')} +
{isChatMode ? detail.id : dayjs.unix(detail.created_at).format(t('appLog.dateTimeFormat'))}
+
+
{detail.model_config.model_id}
+ + {targetTone} + + } + htmlContent={
+
+ Tone of responses +
{targetTone}
+
+ {['temperature', 'top_p', 'presence_penalty', 'max_tokens'].map((param, index) => { + return
+ {PARAM_MAP[param]} + {detail?.model_config.model?.completion_params?.[param] || '-'} +
+ })} +
} + /> +
+ +
+
+ {/* Panel Body */} +
+
+ {isChatMode ? t('appLog.detail.promptTemplateBeforeChat') : t('appLog.detail.promptTemplate')} +
+
{detail.model_config?.pre_prompt || emptyText}
+
+ {!isChatMode + ?
+ +
+ : items.length < 8 + ?
+ +
+ :
+ {/* Put the scroll bar always on the bottom */} + {t('appLog.detail.loading')}...
} + // endMessage={
Nothing more to show
} + // below props only if you need pull down functionality + refreshFunction={fetchData} + pullDownToRefresh + pullDownToRefreshThreshold={50} + // pullDownToRefreshContent={ + //
Pull down to refresh
+ // } + // releaseToRefreshContent={ + //
Release to refresh
+ // } + // To put endMessage and loader to the top. + style={{ display: 'flex', flexDirection: 'column-reverse' }} + inverse={true} + > + + +
+ } + ) +} + +/** + * Text App Conversation Detail Component + */ +const CompletionConversationDetailComp: FC<{ appId?: string; conversationId?: string }> = ({ appId, conversationId }) => { + // Text Generator App Session Details Including Message List + const detailParams = ({ url: `/apps/${appId}/completion-conversations/${conversationId}` }) + const { data: conversationDetail, mutate: conversationDetailMutate } = useSWR(() => (appId && conversationId) ? detailParams : null, fetchCompletionConversationDetail) + const { notify } = useContext(ToastContext) + const { t } = useTranslation() + + const handleFeedback = async (mid: string, { rating }: Feedbacktype): Promise => { + try { + await updateLogMessageFeedbacks({ url: `/apps/${appId}/feedbacks`, body: { message_id: mid, rating } }) + conversationDetailMutate() + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + return true + } + catch (err) { + notify({ type: 'error', message: t('common.actionMsg.modificationFailed') }) + return false + } + } + + const handleAnnotation = async (mid: string, value: string): Promise => { + try { + await updateLogMessageAnnotations({ url: `/apps/${appId}/annotations`, body: { message_id: mid, content: value } }) + conversationDetailMutate() + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + return true + } + catch (err) { + notify({ type: 'error', message: t('common.actionMsg.modificationFailed') }) + return false + } + } + + if (!conversationDetail) + return null + + return + detail={conversationDetail} + onFeedback={handleFeedback} + onSubmitAnnotation={handleAnnotation} + /> +} + +/** + * Chat App Conversation Detail Component + */ +const ChatConversationDetailComp: FC<{ appId?: string; conversationId?: string }> = ({ appId, conversationId }) => { + const detailParams = { url: `/apps/${appId}/chat-conversations/${conversationId}` } + const { data: conversationDetail } = useSWR(() => (appId && conversationId) ? 
detailParams : null, fetchChatConversationDetail) + const { notify } = useContext(ToastContext) + const { t } = useTranslation() + + const handleFeedback = async (mid: string, { rating }: Feedbacktype): Promise => { + try { + await updateLogMessageFeedbacks({ url: `/apps/${appId}/feedbacks`, body: { message_id: mid, rating } }) + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + return true + } + catch (err) { + notify({ type: 'error', message: t('common.actionMsg.modificationFailed') }) + return false + } + } + + const handleAnnotation = async (mid: string, value: string): Promise => { + try { + await updateLogMessageAnnotations({ url: `/apps/${appId}/annotations`, body: { message_id: mid, content: value } }) + notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) + return true + } + catch (err) { + notify({ type: 'error', message: t('common.actionMsg.modificationFailed') }) + return false + } + } + + if (!conversationDetail) + return null + + return + detail={conversationDetail} + onFeedback={handleFeedback} + onSubmitAnnotation={handleAnnotation} + /> +} + +/** + * Conversation list component including basic information + */ +const ConversationList: FC = ({ logs, appDetail, onRefresh }) => { + const { t } = useTranslation() + const [showDrawer, setShowDrawer] = useState(false) // Whether to display the chat details drawer + const [currentConversation, setCurrentConversation] = useState() // Currently selected conversation + const isChatMode = appDetail?.mode === 'chat' // Whether the app is a chat app + + // Annotated data needs to be highlighted + const renderTdValue = (value: string | number | null, isEmptyStyle: boolean, isHighlight = false, annotation?: Annotation) => { + return ( + + {`${t('appLog.detail.annotationTip', { user: annotation?.account?.name })} ${dayjs.unix(annotation?.created_at || dayjs().unix()).format('MM-DD hh:mm A')}`} + + } + className={(isHighlight && !isChatMode) ? '' : '!hidden'} + selector={`highlight-${randomString(16)}`} + > +
+ {value || '-'} +
+
+ ) + } + + const onCloseDrawer = () => { + onRefresh() + setShowDrawer(false) + setCurrentConversation(undefined) + } + + if (!logs) + return + + return ( + <> + + + + + + + + + + + + + + {logs.data.map((log) => { + const endUser = log.from_end_user_id?.slice(0, 8) + const leftValue = get(log, isChatMode ? 'summary' : 'message.query') + const rightValue = get(log, isChatMode ? 'message_count' : 'message.answer') + return { + setShowDrawer(true) + setCurrentConversation(log) + }}> + + + + + + + + + })} + +
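{/* Table body: each row loads its conversation into state and opens the detail drawer on click; unread rows carry an indicator dot (the read_at check below). */}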
{t('appLog.table.header.time')}{t('appLog.table.header.endUser')}{isChatMode ? t('appLog.table.header.summary') : t('appLog.table.header.input')}{isChatMode ? t('appLog.table.header.messageCount') : t('appLog.table.header.output')}{t('appLog.table.header.userRate')}{t('appLog.table.header.adminRate')}
{!log.read_at && }{dayjs.unix(log.created_at).format(t('appLog.dateTimeFormat'))}{renderTdValue(endUser || defaultValue, !endUser)} + {renderTdValue(leftValue || t('appLog.table.empty.noChat'), !leftValue, isChatMode && log.annotated)} + + {renderTdValue(rightValue === 0 ? 0 : (rightValue || t('appLog.table.empty.noOutput')), !rightValue, !isChatMode && !!log.annotation?.content, log.annotation)} + + {(!log.user_feedback_stats.like && !log.user_feedback_stats.dislike) + ? renderTdValue(defaultValue, true) + : <> + {!!log.user_feedback_stats.like && } + {!!log.user_feedback_stats.dislike && } + + } + + {(!log.admin_feedback_stats.like && !log.admin_feedback_stats.dislike) + ? renderTdValue(defaultValue, true) + : <> + {!!log.admin_feedback_stats.like && } + {!!log.admin_feedback_stats.dislike && } + + } +
+ + + {isChatMode + ? + : + } + + + + ) +} + +export default ConversationList diff --git a/web/app/components/app/log/style.module.css b/web/app/components/app/log/style.module.css new file mode 100644 index 0000000000..67a9fe3bf5 --- /dev/null +++ b/web/app/components/app/log/style.module.css @@ -0,0 +1,9 @@ +.logTable td { + padding: 7px 8px; + box-sizing: border-box; + max-width: 200px; +} + +.pagination li { + list-style: none; +} diff --git a/web/app/components/app/overview/appCard.tsx b/web/app/components/app/overview/appCard.tsx new file mode 100644 index 0000000000..663383f047 --- /dev/null +++ b/web/app/components/app/overview/appCard.tsx @@ -0,0 +1,211 @@ +'use client' +import React, { useState } from 'react' +import { + Cog8ToothIcon, + DocumentTextIcon, + RocketLaunchIcon, + ShareIcon, +} from '@heroicons/react/24/outline' +import { SparklesIcon } from '@heroicons/react/24/solid' +import { usePathname, useRouter } from 'next/navigation' +import { useTranslation } from 'react-i18next' +import SettingsModal from './settings' +import ShareLink from './share-link' +import CustomizeModal from './customize' +import Tooltip from '@/app/components/base/tooltip' +import AppBasic, { randomString } from '@/app/components/app-sidebar/basic' +import Button from '@/app/components/base/button' +import Tag from '@/app/components/base/tag' +import Switch from '@/app/components/base/switch' +import type { AppDetailResponse } from '@/models/app' + +export type IAppCardProps = { + className?: string + appInfo: AppDetailResponse + cardType?: 'app' | 'api' + customBgColor?: string + onChangeStatus: (val: boolean) => Promise + onSaveSiteConfig?: (params: any) => Promise + onGenerateCode?: () => Promise +} + +// todo: get image url from appInfo +const defaultUrl = 'https://images.unsplash.com/photo-1472099645785-5658abf4ff4e?ixlib=rb-1.2.1&ixid=eyJhcHBfaWQiOjEyMDd9&auto=format&fit=facearea&facepad=2&w=256&h=256&q=80' + +function AppCard({ + appInfo, + cardType = 'app', + customBgColor, + onChangeStatus, + onSaveSiteConfig, + onGenerateCode, + className, +}: IAppCardProps) { + const router = useRouter() + const pathname = usePathname() + const [showSettingsModal, setShowSettingsModal] = useState(false) + const [showShareModal, setShowShareModal] = useState(false) + const [showCustomizeModal, setShowCustomizeModal] = useState(false) + const { t } = useTranslation() + + const OPERATIONS_MAP = { + app: [ + { opName: t('appOverview.overview.appInfo.preview'), opIcon: RocketLaunchIcon }, + { opName: t('appOverview.overview.appInfo.share.entry'), opIcon: ShareIcon }, + { opName: t('appOverview.overview.appInfo.settings.entry'), opIcon: Cog8ToothIcon }, + ], + api: [{ opName: t('appOverview.overview.apiInfo.doc'), opIcon: DocumentTextIcon }], + } + + const isApp = cardType === 'app' + const basicName = isApp ? appInfo?.site?.title : t('appOverview.overview.apiInfo.title') + const runningStatus = isApp ? appInfo.enable_site : appInfo.enable_api + const { app_base_url, access_token } = appInfo.site ?? 
{} + const appUrl = `${app_base_url}/${appInfo.mode}/${access_token}` + const apiUrl = appInfo?.api_base_url + + let bgColor = 'bg-primary-50 bg-opacity-40' + if (cardType === 'api') + bgColor = 'bg-purple-50' + + const genClickFuncByName = (opName: string) => { + switch (opName) { + case t('appOverview.overview.appInfo.preview'): + return () => { + window.open(appUrl, '_blank') + } + case t('appOverview.overview.appInfo.share.entry'): + return () => { + setShowShareModal(true) + } + case t('appOverview.overview.appInfo.settings.entry'): + return () => { + setShowSettingsModal(true) + } + default: + // jump to page develop + return () => { + const pathSegments = pathname.split('/') + pathSegments.pop() + router.push(`${pathSegments.join('/')}/develop`) + } + } + } + + const onClickCustomize = () => { + setShowCustomizeModal(true) + } + + return ( +
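{/* genClickFuncByName maps each localized operation label to its click handler; unmatched labels fall through to the develop-page navigation in the default branch. */}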
+
+
+ +
+ + {runningStatus ? t('appOverview.overview.status.running') : t('appOverview.overview.status.disable')} + + +
+
+
+
+
+ {isApp ? t('appOverview.overview.appInfo.accessibleAddress') : t('appOverview.overview.apiInfo.accessibleAddress')} +
+
+ {isApp ? appUrl : apiUrl} +
+
+
+
+ {OPERATIONS_MAP[cardType].map((op) => { + return ( + + ) + })} +
+
+ {isApp + ? ( +
+
+ + {t('appOverview.overview.appInfo.customize.entry')} +
+
+ ) + : null} + {isApp + ? ( +
+ setShowShareModal(false)} + linkUrl={appUrl} + onGenerateCode={onGenerateCode} + /> + setShowSettingsModal(false)} + onSave={onSaveSiteConfig} + /> + setShowCustomizeModal(false)} + appId={appInfo.id} + mode={appInfo.mode} + /> +
+ ) + : null} +
+ ) +} + +export default AppCard diff --git a/web/app/components/app/overview/appChart.tsx b/web/app/components/app/overview/appChart.tsx new file mode 100644 index 0000000000..a8dae564e6 --- /dev/null +++ b/web/app/components/app/overview/appChart.tsx @@ -0,0 +1,291 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import ReactECharts from 'echarts-for-react' +import type { EChartsOption } from 'echarts' +import useSWR from 'swr' +import dayjs from 'dayjs' +import { get } from 'lodash-es' +import { formatNumber } from '@/utils/format' +import { useTranslation } from 'react-i18next' +import Basic from '@/app/components/app-sidebar/basic' +import Loading from '@/app/components/base/loading' +import type { AppDailyConversationsResponse, AppDailyEndUsersResponse, AppTokenCostsResponse } from '@/models/app' +import { getAppDailyConversations, getAppDailyEndUsers, getAppTokenCosts } from '@/service/apps' +const valueFormatter = (v: string | number) => v + +const COLOR_TYPE_MAP = { + green: { + lineColor: 'rgba(6, 148, 162, 1)', + bgColor: ['rgba(6, 148, 162, 0.2)', 'rgba(67, 174, 185, 0.08)'], + }, + orange: { + lineColor: 'rgba(255, 138, 76, 1)', + bgColor: ['rgba(254, 145, 87, 0.2)', 'rgba(255, 138, 76, 0.1)'], + }, + blue: { + lineColor: 'rgba(28, 100, 242, 1)', + bgColor: ['rgba(28, 100, 242, 0.3)', 'rgba(28, 100, 242, 0.1)'], + }, +} + +const COMMON_COLOR_MAP = { + label: '#9CA3AF', + splitLineLight: '#F3F4F6', + splitLineDark: '#E5E7EB', +} + +type IColorType = 'green' | 'orange' | 'blue' +type IChartType = 'conversations' | 'endUsers' | 'costs' +type IChartConfigType = { colorType: IColorType; showTokens?: boolean } + +const commonDateFormat = 'MMM D, YYYY' + +const CHART_TYPE_CONFIG: Record = { + conversations: { + colorType: 'green', + }, + endUsers: { + colorType: 'orange', + }, + costs: { + colorType: 'blue', + showTokens: true, + }, +} + +const sum = (arr: number[]): number => { + return arr.reduce((acr, cur) => { + return acr + cur + }) +} + +export type PeriodParams = { + name: string + query: { + start: string + end: string + } +} + +export type IBizChartProps = { + period: PeriodParams + id: string +} + +export type IChartProps = { + className?: string + basicInfo: { title: string; explanation: string; timePeriod: string } + yMax?: number + chartType: IChartType + chartData: AppDailyConversationsResponse | AppDailyEndUsersResponse | AppTokenCostsResponse | { data: Array<{ date: string; count: number }> } +} + +const Chart: React.FC = ({ + basicInfo: { title, explanation, timePeriod }, + chartType = 'conversations', + chartData, + yMax, + className, +}) => { + const { t } = useTranslation() + const statistics = chartData.data + const statisticsLen = statistics.length + const extraDataForMarkLine = new Array(statisticsLen >= 2 ? 
statisticsLen - 2 : statisticsLen).fill('1') + extraDataForMarkLine.push('') + extraDataForMarkLine.unshift('') + + const xData = statistics.map(({ date }) => date) + const yField = Object.keys(statistics[0]).find(name => name.includes('count')) || '' + const yData = statistics.map((item) => { + // @ts-expect-error field is valid + return item[yField] || 0 + }) + + const options: EChartsOption = { + dataset: { + dimensions: ['date', yField], + source: statistics, + }, + grid: { top: 8, right: 36, bottom: 0, left: 0, containLabel: true }, + tooltip: { + trigger: 'item', + position: 'top', + borderWidth: 0, + }, + xAxis: [{ + type: 'category', + boundaryGap: false, + axisLabel: { + color: COMMON_COLOR_MAP.label, + hideOverlap: true, + overflow: 'break', + formatter(value) { + return dayjs(value).format(commonDateFormat) + }, + }, + axisLine: { show: false }, + axisTick: { show: false }, + splitLine: { + show: true, + lineStyle: { + color: COMMON_COLOR_MAP.splitLineLight, + width: 1, + type: [10, 10], + }, + interval(index) { + return index === 0 || index === xData.length - 1 + }, + }, + }, { + position: 'bottom', + boundaryGap: false, + data: extraDataForMarkLine, + axisLabel: { show: false }, + axisLine: { show: false }, + axisTick: { show: false }, + splitLine: { + show: true, + lineStyle: { + color: COMMON_COLOR_MAP.splitLineDark, + }, + interval(index, value) { + return !!value + }, + }, + }], + yAxis: { + max: yMax ?? 'dataMax', + type: 'value', + axisLabel: { color: COMMON_COLOR_MAP.label, hideOverlap: true }, + splitLine: { + lineStyle: { + color: COMMON_COLOR_MAP.splitLineLight, + }, + }, + }, + series: [ + { + type: 'line', + showSymbol: true, + // symbol: 'circle', + // triggerLineEvent: true, + symbolSize: 4, + lineStyle: { + color: COLOR_TYPE_MAP[CHART_TYPE_CONFIG[chartType].colorType].lineColor, + width: 2, + }, + itemStyle: { + color: COLOR_TYPE_MAP[CHART_TYPE_CONFIG[chartType].colorType].lineColor, + }, + areaStyle: { + color: { + type: 'linear', + x: 0, + y: 0, + x2: 0, + y2: 1, + colorStops: [{ + offset: 0, color: COLOR_TYPE_MAP[CHART_TYPE_CONFIG[chartType].colorType].bgColor[0], + }, { + offset: 1, color: COLOR_TYPE_MAP[CHART_TYPE_CONFIG[chartType].colorType].bgColor[1], + }], + global: false, + }, + }, + tooltip: { + padding: [8, 12, 8, 12], + formatter(params) { + return `
${params.name}
+
${valueFormatter((params.data as any)[yField])} + ${!CHART_TYPE_CONFIG[chartType].showTokens + ? '' + : ` + ( + ~$${get(params.data, 'total_price', 0)} + ) + `} +
` + }, + }, + }, + ], + } + + const sumData = sum(yData) + + return ( +
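{/* The synthetic second x-axis (data: extraDataForMarkLine) appears to draw the darker interior grid lines; its empty first/last entries leave the edge lines to the main axis. */}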
+
+ +
+
+ {t('appOverview.analysis.tokenUsage.consumed')} Tokens + ( + ~{sum(statistics.map(item => parseFloat(get(item, 'total_price', '0')))).toLocaleString('en-US', { style: 'currency', currency: 'USD', minimumFractionDigits: 4 })} + ) + } + textStyle={{ main: `!text-3xl !font-normal ${sumData === 0 ? '!text-gray-300' : ''}` }} /> +
+ +
+ ) +} + +const getDefaultChartData = ({ start, end }: { start: string; end: string }) => { + const diffDays = dayjs(end).diff(dayjs(start), 'day') + return Array.from({ length: diffDays || 1 }, () => ({ date: '', count: 0 })).map((item, index) => { + item.date = dayjs(start).add(index, 'day').format(commonDateFormat) + return item + }) +} + +export const ConversationsChart: FC = ({ id, period }) => { + const { t } = useTranslation() + const { data: response } = useSWR({ url: `/apps/${id}/statistics/daily-conversations`, params: period.query }, getAppDailyConversations) + if (!response) + return + const noDataFlag = !response.data || response.data.length === 0 + return +} + +export const EndUsersChart: FC = ({ id, period }) => { + const { t } = useTranslation() + + const { data: response } = useSWR({ url: `/apps/${id}/statistics/daily-end-users`, id, params: period.query }, getAppDailyEndUsers) + if (!response) + return + const noDataFlag = !response.data || response.data.length === 0 + return +} + +export const CostChart: FC = ({ id, period }) => { + const { t } = useTranslation() + + const { data: response } = useSWR({ url: `/apps/${id}/statistics/token-costs`, params: period.query }, getAppTokenCosts) + if (!response) + return + const noDataFlag = !response.data || response.data.length === 0 + return +} + +export default Chart diff --git a/web/app/components/app/overview/customize/index.tsx b/web/app/components/app/overview/customize/index.tsx new file mode 100644 index 0000000000..a1d46ff0b2 --- /dev/null +++ b/web/app/components/app/overview/customize/index.tsx @@ -0,0 +1,108 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { AppMode } from '@/types/app' +import { ArrowTopRightOnSquareIcon } from '@heroicons/react/24/outline' +import { useTranslation } from 'react-i18next' +import { useContext } from 'use-context-selector' +import I18n from '@/context/i18n' +import Button from '@/app/components/base/button' +import Modal from '@/app/components/base/modal' +import Tag from '@/app/components/base/tag' + +type IShareLinkProps = { + isShow: boolean + onClose: () => void + linkUrl: string + appId: string + mode: AppMode +} + +const StepNum: FC<{ children: React.ReactNode }> = ({ children }) => +
+ {children} +
+ + + +const GithubIcon = ({ className }: { className: string }) => { + return ( + + + + ) +} + +const prefixCustomize = 'appOverview.overview.appInfo.customize' + +const CustomizeModal: FC = ({ + isShow, + onClose, + appId, + mode, +}) => { + const { t } = useTranslation() + const { locale } = useContext(I18n) + const isChatApp = mode === 'chat' + + return +
+ {t(`${prefixCustomize}.way`)} 1 +

{t(`${prefixCustomize}.way1.name`)}

+
+ 1 +
+
{t(`${prefixCustomize}.way1.step1`)}
+
{t(`${prefixCustomize}.way1.step1Tip`)}
+ + + +
+
+
+ 2 +
+
{t(`${prefixCustomize}.way1.step2`)}
+
{t(`${prefixCustomize}.way1.step2Tip`)}
+
+            export const APP_ID = '{appId}'
+ export const API_KEY = {`''`} +
+
+
+
+ 3 +
+
{t(`${prefixCustomize}.way1.step3`)}
+
{t(`${prefixCustomize}.way1.step3Tip`)}
+ + + +
+
+
+
+ {t(`${prefixCustomize}.way`)} 2 +

{t(`${prefixCustomize}.way2.name`)}

+ +
+
+} + +export default CustomizeModal diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx new file mode 100644 index 0000000000..19e76041af --- /dev/null +++ b/web/app/components/app/overview/settings/index.tsx @@ -0,0 +1,145 @@ +'use client' +import type { FC } from 'react' +import React, { useState } from 'react' +import { ChevronRightIcon } from '@heroicons/react/20/solid' +import Link from 'next/link' +import { Trans, useTranslation } from 'react-i18next' +import s from './style.module.css' +import Modal from '@/app/components/base/modal' +import Button from '@/app/components/base/button' +import Switch from '@/app/components/base/switch' +import AppIcon from '@/app/components/base/app-icon' +import { SimpleSelect } from '@/app/components/base/select' +import type { AppDetailResponse } from '@/models/app' +import type { Language } from '@/types/app' + +export type ISettingsModalProps = { + appInfo: AppDetailResponse + isShow: boolean + defaultValue?: string + onClose: () => void + onSave: (params: ConfigParams) => Promise +} + +export type ConfigParams = { + title: string + description: string + default_language: string + prompt_public: boolean +} + +const LANGUAGE_MAP: Record = { + 'en-US': 'English(United States)', + 'zh-Hans': '简体中文', +} + +const prefixSettings = 'appOverview.overview.appInfo.settings' + +const SettingsModal: FC = ({ + appInfo, + isShow = false, + onClose, + onSave, +}) => { + const [isShowMore, setIsShowMore] = useState(false) + const { title, description, copyright, privacy_policy, default_language } = appInfo.site + const [inputInfo, setInputInfo] = useState({ title, desc: description, copyright, privacyPolicy: privacy_policy }) + const [language, setLanguage] = useState(default_language) + const [saveLoading, setSaveLoading] = useState(false) + const { t } = useTranslation() + + const onHide = () => { + onClose() + setTimeout(() => { + setIsShowMore(false) + }, 200) + } + + const onClickSave = async () => { + setSaveLoading(true) + const params = { + title: inputInfo.title, + description: inputInfo.desc, + default_language: language, + prompt_public: false, + copyright: inputInfo.copyright, + privacy_policy: inputInfo.privacyPolicy, + } + await onSave(params) + setSaveLoading(false) + onHide() + } + + const onChange = (field: string) => { + return (e: any) => { + setInputInfo(item => ({ ...item, [field]: e.target.value })) + } + } + + return ( + +
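{/* The fields below edit a local copy (inputInfo); nothing persists until onClickSave posts the assembled params through onSave. */}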
{t(`${prefixSettings}.webName`)}
+
+ + +
+
{t(`${prefixSettings}.webDesc`)}
+

{t(`${prefixSettings}.webDescTip`)}

+ + +
+
+ {query?.length} +
+ +
+ + + + + + ) +} +export default React.memo(ConfigSence) diff --git a/web/app/components/share/text-generation/history/index.tsx b/web/app/components/share/text-generation/history/index.tsx new file mode 100644 index 0000000000..a100c98a17 --- /dev/null +++ b/web/app/components/share/text-generation/history/index.tsx @@ -0,0 +1,79 @@ +'use client' +import React, { useState } from 'react' +import useSWR from 'swr' +import { + ChevronDownIcon, + ChevronUpIcon, +} from '@heroicons/react/24/outline' +import { fetchHistories } from '@/models/history' +import type { History as HistoryItem } from '@/models/history' +import Loading from '@/app/components/base/loading' +import { mockAPI } from '@/test/test_util' + +mockAPI() + +export type IHistoryProps = { + dictionary: any +} + +const HistoryCard = ( + { history }: { history: HistoryItem }, +) => { + return ( +
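{/* HistoryCard renders one fetched history entry; the mockAPI() call above suggests this component currently runs against stubbed /api/histories data. */}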
+
+ {history.source} +
+ + ) +} + +const History = ({ + dictionary, +}: IHistoryProps) => { + const { data, error } = useSWR('http://localhost:3000/api/histories', fetchHistories) + const [showHistory, setShowHistory] = useState(false) + + const DivideLine = () => { + return
+ {/* divider line */} + + +
+ +
+ + + } + {/* agree to our Terms and Privacy Policy. */} +
+ {t('login.tosDesc')} +   + {t('login.tos')} +  &  + {t('login.pp')} +
+ +
+
+ + ) +} + +export default NormalForm diff --git a/web/app/signin/oneMoreStep.tsx b/web/app/signin/oneMoreStep.tsx new file mode 100644 index 0000000000..9f7f72378b --- /dev/null +++ b/web/app/signin/oneMoreStep.tsx @@ -0,0 +1,160 @@ +'use client' +import React, { useEffect, useReducer } from 'react' +import { useTranslation } from 'react-i18next' +import useSWR from 'swr' +import { useRouter } from 'next/navigation' +import Button from '@/app/components/base/button' +import Tooltip from '@/app/components/base/tooltip/index' + +import { SimpleSelect } from '@/app/components/base/select' +import { timezones } from '@/utils/timezone' +import { languageMaps, languages } from '@/utils/language' +import { oneMoreStep } from '@/service/common' +import Toast from '@/app/components/base/toast' + +type IState = { + formState: 'processing' | 'error' | 'success' | 'initial' + invitation_code: string + interface_language: string + timezone: string +} + +const reducer = (state: IState, action: any) => { + switch (action.type) { + case 'invitation_code': + return { ...state, invitation_code: action.value } + case 'interface_language': + return { ...state, interface_language: action.value } + case 'timezone': + return { ...state, timezone: action.value } + case 'formState': + return { ...state, formState: action.value } + case 'failed': + return { + formState: 'initial', + invitation_code: '', + interface_language: 'en-US', + timezone: 'Asia/Shanghai', + } + default: + throw new Error('Unknown action.') + } +} + +const OneMoreStep = () => { + const { t } = useTranslation() + const router = useRouter() + + const [state, dispatch] = useReducer(reducer, { + formState: 'initial', + invitation_code: '', + interface_language: 'en-US', + timezone: 'Asia/Shanghai', + }) + const { data, error } = useSWR(state.formState === 'processing' + ? { + url: '/account/init', + body: { + invitation_code: state.invitation_code, + interface_language: state.interface_language, + timezone: state.timezone, + }, + } + : null, oneMoreStep) + + useEffect(() => { + if (error && error.status === 400) { + Toast.notify({ type: 'error', message: t('login.invalidInvitationCode') }) + dispatch({ type: 'failed', payload: null }) + } + if (data) + router.push('/apps') + }, [data, error]) + + return ( + <> +
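{/* formState is the trigger: dispatching 'processing' activates the conditional useSWR key above and fires /account/init; a 400 response resets the form via the 'failed' action, and success routes to /apps. */}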
+

{t('login.oneMoreStep')}

+

{t('login.createSample')}

+
+ +
+
+
+
+ } + > + {t('login.donthave')} + + +
+ { + dispatch({ type: 'invitation_code', value: e.target.value.trim() }) + }} + /> +
+
+ +
+ +
+ { + dispatch({ type: 'interface_language', value: item.value }) + }} + /> +
+
+
+ + +
+ { + dispatch({ type: 'timezone', value: item.value }) + }} + /> +
+
+
+ +
+
+ + + ) +} + +export default OneMoreStep diff --git a/web/app/signin/page.module.css b/web/app/signin/page.module.css new file mode 100644 index 0000000000..7bf2611f37 --- /dev/null +++ b/web/app/signin/page.module.css @@ -0,0 +1,19 @@ +.githubIcon { + background: center/contain url('./assets/github.svg'); +} + +.googleIcon { + background: center/contain url('./assets/google.svg'); +} + +.logo { + width: 96px; + height: 40px; + background: url(~@/app/components/share/chat/welcome/icons/logo.png) center center no-repeat; + background-size: contain; +} + +.background { + background-image: url('./assets/background.png'); + background-size: cover; +} \ No newline at end of file diff --git a/web/app/signin/page.tsx b/web/app/signin/page.tsx new file mode 100644 index 0000000000..ec413265d1 --- /dev/null +++ b/web/app/signin/page.tsx @@ -0,0 +1,37 @@ +import React from 'react' +import Forms from './forms' +import Header from './_header' +import style from './page.module.css' +import classNames from 'classnames' + +const SignIn = () => { + + return ( + <> +
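{/* Page shell only: the .background and .logo styles come from page.module.css; the imported Forms component presumably carries the actual sign-in logic. */}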
+
+
+ +
+ © {new Date().getFullYear()} Dify, Inc. All rights reserved. +
+
+ +
+ + + ) +} + +export default SignIn diff --git a/web/app/styles/globals.css b/web/app/styles/globals.css new file mode 100644 index 0000000000..f4710b0275 --- /dev/null +++ b/web/app/styles/globals.css @@ -0,0 +1,134 @@ +@tailwind base; +@tailwind components; +@tailwind utilities; + +:root { + --max-width: 1100px; + --border-radius: 12px; + --font-mono: ui-monospace, Menlo, Monaco, "Cascadia Mono", "Segoe UI Mono", + "Roboto Mono", "Oxygen Mono", "Ubuntu Monospace", "Source Code Pro", + "Fira Mono", "Droid Sans Mono", "Courier New", monospace; + + --foreground-rgb: 0, 0, 0; + --background-start-rgb: 214, 219, 220; + --background-end-rgb: 255, 255, 255; + + --primary-glow: conic-gradient(from 180deg at 50% 50%, + #16abff33 0deg, + #0885ff33 55deg, + #54d6ff33 120deg, + #0071ff33 160deg, + transparent 360deg); + --secondary-glow: radial-gradient(rgba(255, 255, 255, 1), + rgba(255, 255, 255, 0)); + + --tile-start-rgb: 239, 245, 249; + --tile-end-rgb: 228, 232, 233; + --tile-border: conic-gradient(#00000080, + #00000040, + #00000030, + #00000020, + #00000010, + #00000010, + #00000080); + + --callout-rgb: 238, 240, 241; + --callout-border-rgb: 172, 175, 176; + --card-rgb: 180, 185, 188; + --card-border-rgb: 131, 134, 135; +} + +/* @media (prefers-color-scheme: dark) { + :root { + --foreground-rgb: 255, 255, 255; + --background-start-rgb: 0, 0, 0; + --background-end-rgb: 0, 0, 0; + + --primary-glow: radial-gradient(rgba(1, 65, 255, 0.4), rgba(1, 65, 255, 0)); + --secondary-glow: linear-gradient(to bottom right, + rgba(1, 65, 255, 0), + rgba(1, 65, 255, 0), + rgba(1, 65, 255, 0.3)); + + --tile-start-rgb: 2, 13, 46; + --tile-end-rgb: 2, 5, 19; + --tile-border: conic-gradient(#ffffff80, + #ffffff40, + #ffffff30, + #ffffff20, + #ffffff10, + #ffffff10, + #ffffff80); + + --callout-rgb: 20, 20, 20; + --callout-border-rgb: 108, 108, 108; + --card-rgb: 100, 100, 100; + --card-border-rgb: 200, 200, 200; + } +} */ + +* { + box-sizing: border-box; + padding: 0; + margin: 0; +} + +html, +body { + max-width: 100vw; + overflow-x: hidden; +} + +body { + color: rgb(var(--foreground-rgb)); + user-select: none; + /* background: linear-gradient( + to bottom, + transparent, + rgb(var(--background-end-rgb)) + ) + rgb(var(--background-start-rgb)); */ +} + +a { + color: inherit; + text-decoration: none; + outline: none; +} + +button:focus-within { + outline: none; +} + +/* @media (prefers-color-scheme: dark) { + html { + color-scheme: dark; + } +} */ + +/* CSS Utils */ +.h1 { + padding-bottom: 1.5rem; + line-height: 1.5; + font-size: 1.125rem; + color: #111928; +} + +.h2 { + font-size: 14px; + font-weight: 500; + color: #111928; + line-height: 1.5; +} + +.link { + @apply text-blue-600 cursor-pointer hover:opacity-80 transition-opacity duration-200 ease-in-out; +} + +.text-gradient { + background: linear-gradient(91.58deg, #2250F2 -29.55%, #0EBCF3 75.22%); + -webkit-background-clip: text; + -webkit-text-fill-color: transparent; + background-clip: text; + text-fill-color: transparent; +} \ No newline at end of file diff --git a/web/app/styles/markdown.scss b/web/app/styles/markdown.scss new file mode 100644 index 0000000000..fdfaae0cf9 --- /dev/null +++ b/web/app/styles/markdown.scss @@ -0,0 +1,1042 @@ +@mixin light { + color-scheme: light; + --color-prettylights-syntax-comment: #6e7781; + --color-prettylights-syntax-constant: #0550ae; + --color-prettylights-syntax-entity: #8250df; + --color-prettylights-syntax-storage-modifier-import: #24292f; + --color-prettylights-syntax-entity-tag: #116329; + 
--color-prettylights-syntax-keyword: #cf222e; + --color-prettylights-syntax-string: #0a3069; + --color-prettylights-syntax-variable: #953800; + --color-prettylights-syntax-brackethighlighter-unmatched: #82071e; + --color-prettylights-syntax-invalid-illegal-text: #f6f8fa; + --color-prettylights-syntax-invalid-illegal-bg: #82071e; + --color-prettylights-syntax-carriage-return-text: #f6f8fa; + --color-prettylights-syntax-carriage-return-bg: #cf222e; + --color-prettylights-syntax-string-regexp: #116329; + --color-prettylights-syntax-markup-list: #3b2300; + --color-prettylights-syntax-markup-heading: #0550ae; + --color-prettylights-syntax-markup-italic: #24292f; + --color-prettylights-syntax-markup-bold: #24292f; + --color-prettylights-syntax-markup-deleted-text: #82071e; + --color-prettylights-syntax-markup-deleted-bg: #ffebe9; + --color-prettylights-syntax-markup-inserted-text: #116329; + --color-prettylights-syntax-markup-inserted-bg: #dafbe1; + --color-prettylights-syntax-markup-changed-text: #953800; + --color-prettylights-syntax-markup-changed-bg: #ffd8b5; + --color-prettylights-syntax-markup-ignored-text: #eaeef2; + --color-prettylights-syntax-markup-ignored-bg: #0550ae; + --color-prettylights-syntax-meta-diff-range: #8250df; + --color-prettylights-syntax-brackethighlighter-angle: #57606a; + --color-prettylights-syntax-sublimelinter-gutter-mark: #8c959f; + --color-prettylights-syntax-constant-other-reference-link: #0a3069; + --color-fg-default: #24292f; + --color-fg-muted: #57606a; + --color-fg-subtle: #6e7781; + --color-canvas-default: transparent; + --color-canvas-subtle: #f6f8fa; + --color-border-default: #d0d7de; + --color-border-muted: hsla(210, 18%, 87%, 1); + --color-neutral-muted: rgba(175, 184, 193, 0.2); + --color-accent-fg: #0969da; + --color-accent-emphasis: #0969da; + --color-attention-subtle: #fff8c5; + --color-danger-fg: #cf222e; +} + +.markdown-body { + -ms-text-size-adjust: 100%; + -webkit-text-size-adjust: 100%; + margin: 0; + color: #101828; + background-color: var(--color-canvas-default); + font-size: 14px; + font-weight: 400; + line-height: 1.5; + word-wrap: break-word; + user-select: text; +} + +.light { + @include light; +} + +:root { + @include light; +} + +@media (prefers-color-scheme: light) { + :root { + @include light; + } +} + +.markdown-body .octicon { + display: inline-block; + fill: currentColor; + vertical-align: text-bottom; +} + +.markdown-body h1:hover .anchor .octicon-link:before, +.markdown-body h2:hover .anchor .octicon-link:before, +.markdown-body h3:hover .anchor .octicon-link:before, +.markdown-body h4:hover .anchor .octicon-link:before, +.markdown-body h5:hover .anchor .octicon-link:before, +.markdown-body h6:hover .anchor .octicon-link:before { + width: 16px; + height: 16px; + content: " "; + display: inline-block; + background-color: currentColor; + -webkit-mask-image: url("data:image/svg+xml,"); + mask-image: url("data:image/svg+xml,"); +} + +.markdown-body details, +.markdown-body figcaption, +.markdown-body figure { + display: block; +} + +.markdown-body summary { + display: list-item; +} + +.markdown-body [hidden] { + display: none !important; +} + +.markdown-body a { + background-color: transparent; + color: var(--color-accent-fg); + text-decoration: none; +} + +.markdown-body abbr[title] { + border-bottom: none; + text-decoration: underline dotted; +} + +.markdown-body b, +.markdown-body strong { + font-weight: var(--base-text-weight-semibold, 600); +} + +.markdown-body dfn { + font-style: italic; +} + +.markdown-body mark { + 
background-color: var(--color-attention-subtle); + color: var(--color-fg-default); +} + +.markdown-body small { + font-size: 90%; +} + +.markdown-body sub, +.markdown-body sup { + font-size: 75%; + line-height: 0; + position: relative; + vertical-align: baseline; +} + +.markdown-body sub { + bottom: -0.25em; +} + +.markdown-body sup { + top: -0.5em; +} + +.markdown-body img { + border-style: none; + max-width: 100%; + box-sizing: content-box; + background-color: var(--color-canvas-default); +} + +.markdown-body code, +.markdown-body kbd, +.markdown-body pre, +.markdown-body samp { + font-family: monospace; + font-size: 1em; +} + +.markdown-body figure { + margin: 1em 40px; +} + +.markdown-body hr { + box-sizing: content-box; + overflow: hidden; + background: transparent; + border-bottom: 1px solid var(--color-border-muted); + height: 0.25em; + padding: 0; + margin: 24px 0; + background-color: var(--color-border-default); + border: 0; +} + +.markdown-body input { + font: inherit; + margin: 0; + overflow: visible; + font-family: inherit; + font-size: inherit; + line-height: inherit; +} + +.markdown-body [type="button"], +.markdown-body [type="reset"], +.markdown-body [type="submit"] { + -webkit-appearance: button; +} + +.markdown-body [type="checkbox"], +.markdown-body [type="radio"] { + box-sizing: border-box; + padding: 0; +} + +.markdown-body [type="number"]::-webkit-inner-spin-button, +.markdown-body [type="number"]::-webkit-outer-spin-button { + height: auto; +} + +.markdown-body [type="search"]::-webkit-search-cancel-button, +.markdown-body [type="search"]::-webkit-search-decoration { + -webkit-appearance: none; +} + +.markdown-body ::-webkit-input-placeholder { + color: inherit; + opacity: 0.54; +} + +.markdown-body ::-webkit-file-upload-button { + -webkit-appearance: button; + font: inherit; +} + +.markdown-body a:hover { + text-decoration: underline; +} + +.markdown-body ::placeholder { + color: var(--color-fg-subtle); + opacity: 1; +} + +.markdown-body hr::before { + display: table; + content: ""; +} + +.markdown-body hr::after { + display: table; + clear: both; + content: ""; +} + +.markdown-body table { + border-spacing: 0; + border-collapse: collapse; + display: block; + width: max-content; + max-width: 100%; + overflow: auto; +} + +.markdown-body td, +.markdown-body th { + padding: 0; +} + +.markdown-body details summary { + cursor: pointer; +} + +.markdown-body details:not([open])>*:not(summary) { + display: none !important; +} + +.markdown-body a:focus, +.markdown-body [role="button"]:focus, +.markdown-body input[type="radio"]:focus, +.markdown-body input[type="checkbox"]:focus { + outline: 2px solid var(--color-accent-fg); + outline-offset: -2px; + box-shadow: none; +} + +.markdown-body a:focus:not(:focus-visible), +.markdown-body [role="button"]:focus:not(:focus-visible), +.markdown-body input[type="radio"]:focus:not(:focus-visible), +.markdown-body input[type="checkbox"]:focus:not(:focus-visible) { + outline: solid 1px transparent; +} + +.markdown-body a:focus-visible, +.markdown-body [role="button"]:focus-visible, +.markdown-body input[type="radio"]:focus-visible, +.markdown-body input[type="checkbox"]:focus-visible { + outline: 2px solid var(--color-accent-fg); + outline-offset: -2px; + box-shadow: none; +} + +.markdown-body a:not([class]):focus, +.markdown-body a:not([class]):focus-visible, +.markdown-body input[type="radio"]:focus, +.markdown-body input[type="radio"]:focus-visible, +.markdown-body input[type="checkbox"]:focus, +.markdown-body 
input[type="checkbox"]:focus-visible { + outline-offset: 0; +} + +.markdown-body kbd { + display: inline-block; + padding: 3px 5px; + font: 11px ui-monospace, SFMono-Regular, SF Mono, Menlo, Consolas, + Liberation Mono, monospace; + line-height: 10px; + color: var(--color-fg-default); + vertical-align: middle; + background-color: var(--color-canvas-subtle); + border: solid 1px var(--color-neutral-muted); + border-bottom-color: var(--color-neutral-muted); + border-radius: 6px; + box-shadow: inset 0 -1px 0 var(--color-neutral-muted); +} + +.markdown-body h1, +.markdown-body h2, +.markdown-body h3, +.markdown-body h4, +.markdown-body h5, +.markdown-body h6 { + margin-top: 24px; + margin-bottom: 16px; + font-weight: var(--base-text-weight-semibold, 600); + line-height: 1.25; +} + + +.markdown-body p { + margin-top: 0; + margin-bottom: 10px; +} + +.markdown-body blockquote { + margin: 0; + padding: 0 8px; + border-left: 2px solid #2970FF; +} + +.markdown-body ul, +.markdown-body ol { + margin-top: 0; + margin-bottom: 0; + padding-left: 2em; +} + +.markdown-body ol { + list-style: decimal; +} + +.markdown-body ul { + list-style: disc; +} + +.markdown-body ol ol, +.markdown-body ul ol { + list-style-type: lower-roman; +} + +.markdown-body ul ul ol, +.markdown-body ul ol ol, +.markdown-body ol ul ol, +.markdown-body ol ol ol { + list-style-type: lower-alpha; +} + +.markdown-body dd { + margin-left: 0; +} + +.markdown-body tt, +.markdown-body code, +.markdown-body samp { + font-family: ui-monospace, SFMono-Regular, SF Mono, Menlo, Consolas, + Liberation Mono, monospace; + font-size: 12px; +} + +.markdown-body pre { + margin-top: 0; + margin-bottom: 0; + font-family: ui-monospace, SFMono-Regular, SF Mono, Menlo, Consolas, + Liberation Mono, monospace; + font-size: 12px; + word-wrap: normal; +} + +.markdown-body .octicon { + display: inline-block; + overflow: visible !important; + vertical-align: text-bottom; + fill: currentColor; +} + +.markdown-body input::-webkit-outer-spin-button, +.markdown-body input::-webkit-inner-spin-button { + margin: 0; + -webkit-appearance: none; + appearance: none; +} + +.markdown-body::before { + display: table; + content: ""; +} + +.markdown-body::after { + display: table; + clear: both; + content: ""; +} + +.markdown-body>*:first-child { + margin-top: 0 !important; +} + +.markdown-body>*:last-child { + margin-bottom: 0 !important; +} + +.markdown-body a:not([href]) { + color: inherit; + text-decoration: none; +} + +.markdown-body .absent { + color: var(--color-danger-fg); +} + +.markdown-body .anchor { + float: left; + padding-right: 4px; + margin-left: -20px; + line-height: 1; +} + +.markdown-body .anchor:focus { + outline: none; +} + +.markdown-body p, +.markdown-body blockquote, +.markdown-body ul, +.markdown-body ol, +.markdown-body dl, +.markdown-body table, +.markdown-body pre, +.markdown-body details { + margin-top: 0; + margin-bottom: 16px; +} + +.markdown-body blockquote> :first-child { + margin-top: 0; +} + +.markdown-body blockquote> :last-child { + margin-bottom: 0; +} + +.markdown-body h1 .octicon-link, +.markdown-body h2 .octicon-link, +.markdown-body h3 .octicon-link, +.markdown-body h4 .octicon-link, +.markdown-body h5 .octicon-link, +.markdown-body h6 .octicon-link { + color: var(--color-fg-default); + vertical-align: middle; + visibility: hidden; +} + +.markdown-body h1:hover .anchor, +.markdown-body h2:hover .anchor, +.markdown-body h3:hover .anchor, +.markdown-body h4:hover .anchor, +.markdown-body h5:hover .anchor, +.markdown-body h6:hover .anchor 
{ + text-decoration: none; +} + +.markdown-body h1:hover .anchor .octicon-link, +.markdown-body h2:hover .anchor .octicon-link, +.markdown-body h3:hover .anchor .octicon-link, +.markdown-body h4:hover .anchor .octicon-link, +.markdown-body h5:hover .anchor .octicon-link, +.markdown-body h6:hover .anchor .octicon-link { + visibility: visible; +} + +.markdown-body h1 tt, +.markdown-body h1 code, +.markdown-body h2 tt, +.markdown-body h2 code, +.markdown-body h3 tt, +.markdown-body h3 code, +.markdown-body h4 tt, +.markdown-body h4 code, +.markdown-body h5 tt, +.markdown-body h5 code, +.markdown-body h6 tt, +.markdown-body h6 code { + padding: 0 0.2em; + font-size: inherit; +} + +.markdown-body summary h1, +.markdown-body summary h2, +.markdown-body summary h3, +.markdown-body summary h4, +.markdown-body summary h5, +.markdown-body summary h6 { + display: inline-block; +} + +.markdown-body summary h1 .anchor, +.markdown-body summary h2 .anchor, +.markdown-body summary h3 .anchor, +.markdown-body summary h4 .anchor, +.markdown-body summary h5 .anchor, +.markdown-body summary h6 .anchor { + margin-left: -40px; +} + +.markdown-body summary h1, +.markdown-body summary h2 { + padding-bottom: 0; + border-bottom: 0; +} + +.markdown-body ul.no-list, +.markdown-body ol.no-list { + padding: 0; + list-style-type: none; +} + +.markdown-body ol[type="a"] { + list-style-type: lower-alpha; +} + +.markdown-body ol[type="A"] { + list-style-type: upper-alpha; +} + +.markdown-body ol[type="i"] { + list-style-type: lower-roman; +} + +.markdown-body ol[type="I"] { + list-style-type: upper-roman; +} + +.markdown-body ol[type="1"] { + list-style-type: decimal; +} + +.markdown-body div>ol:not([type]) { + list-style-type: decimal; +} + +.markdown-body ul ul, +.markdown-body ul ol, +.markdown-body ol ol, +.markdown-body ol ul { + margin-top: 0; + margin-bottom: 0; +} + +.markdown-body li>p { + margin-top: 16px; +} + +.markdown-body li+li { + margin-top: 0.25em; +} + +.markdown-body dl { + padding: 0; +} + +.markdown-body dl dt { + padding: 0; + margin-top: 16px; + font-size: 1em; + font-style: italic; + font-weight: var(--base-text-weight-semibold, 600); +} + +.markdown-body dl dd { + padding: 0 16px; + margin-bottom: 16px; +} + +.markdown-body table th { + font-weight: var(--base-text-weight-semibold, 600); +} + +.markdown-body table th, +.markdown-body table td { + padding: 6px 13px; + border: 1px solid var(--color-border-default); +} + +.markdown-body table tr { + background-color: var(--color-canvas-default); + border-top: 1px solid var(--color-border-muted); +} + +.markdown-body table tr:nth-child(2n) { + background-color: var(--color-canvas-subtle); +} + +.markdown-body table img { + background-color: transparent; +} + +.markdown-body img[align="right"] { + padding-left: 20px; +} + +.markdown-body img[align="left"] { + padding-right: 20px; +} + +.markdown-body .emoji { + max-width: none; + vertical-align: text-top; + background-color: transparent; +} + +.markdown-body span.frame { + display: block; + overflow: hidden; +} + +.markdown-body span.frame>span { + display: block; + float: left; + width: auto; + padding: 7px; + margin: 13px 0 0; + overflow: hidden; + border: 1px solid var(--color-border-default); +} + +.markdown-body span.frame span img { + display: block; + float: left; +} + +.markdown-body span.frame span span { + display: block; + padding: 5px 0 0; + clear: both; + color: var(--color-fg-default); +} + +.markdown-body span.align-center { + display: block; + overflow: hidden; + clear: both; +} + 
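/* The span.frame, span.align-* and span.float-* rules in this section reproduce GitHub's legacy image-alignment markup inside rendered markdown. */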
+.markdown-body span.align-center>span { + display: block; + margin: 13px auto 0; + overflow: hidden; + text-align: center; +} + +.markdown-body span.align-center span img { + margin: 0 auto; + text-align: center; +} + +.markdown-body span.align-right { + display: block; + overflow: hidden; + clear: both; +} + +.markdown-body span.align-right>span { + display: block; + margin: 13px 0 0; + overflow: hidden; + text-align: right; +} + +.markdown-body span.align-right span img { + margin: 0; + text-align: right; +} + +.markdown-body span.float-left { + display: block; + float: left; + margin-right: 13px; + overflow: hidden; +} + +.markdown-body span.float-left span { + margin: 13px 0 0; +} + +.markdown-body span.float-right { + display: block; + float: right; + margin-left: 13px; + overflow: hidden; +} + +.markdown-body span.float-right>span { + display: block; + margin: 13px auto 0; + overflow: hidden; + text-align: right; +} + +.markdown-body code, +.markdown-body tt { + padding: 0.2em 0.4em; + margin: 0; + font-size: 85%; + white-space: break-spaces; + background-color: var(--color-neutral-muted); + border-radius: 6px; +} + +.markdown-body code br, +.markdown-body tt br { + display: none; +} + +.markdown-body del code { + text-decoration: inherit; +} + +.markdown-body samp { + font-size: 85%; +} + +.markdown-body pre code { + font-size: 100%; +} + +.markdown-body pre>code { + padding: 0; + margin: 0; + word-break: normal; + white-space: pre; + background: transparent; + border: 0; +} + +.markdown-body .highlight { + margin-bottom: 16px; +} + +.markdown-body .highlight pre { + margin-bottom: 0; + word-break: normal; +} + +.markdown-body .highlight pre, +.markdown-body pre { + padding: 16px; + background: #fff; + overflow: auto; + font-size: 85%; + line-height: 1.45; + border-radius: 6px; +} + +.markdown-body pre code, +.markdown-body pre tt { + display: inline-block; + max-width: 100%; + padding: 0; + margin: 0; + overflow-x: scroll; + line-height: inherit; + word-wrap: normal; + background-color: transparent; + border: 0; +} + +.markdown-body .csv-data td, +.markdown-body .csv-data th { + padding: 5px; + overflow: hidden; + font-size: 12px; + line-height: 1; + text-align: left; + white-space: nowrap; +} + +.markdown-body .csv-data .blob-num { + padding: 10px 8px 9px; + text-align: right; + background: var(--color-canvas-default); + border: 0; +} + +.markdown-body .csv-data tr { + border-top: 0; +} + +.markdown-body .csv-data th { + font-weight: var(--base-text-weight-semibold, 600); + background: var(--color-canvas-subtle); + border-top: 0; +} + +.markdown-body [data-footnote-ref]::before { + content: "["; +} + +.markdown-body [data-footnote-ref]::after { + content: "]"; +} + +.markdown-body .footnotes { + font-size: 12px; + color: var(--color-fg-muted); + border-top: 1px solid var(--color-border-default); +} + +.markdown-body .footnotes ol { + padding-left: 16px; +} + +.markdown-body .footnotes ol ul { + display: inline-block; + padding-left: 16px; + margin-top: 16px; +} + +.markdown-body .footnotes li { + position: relative; +} + +.markdown-body .footnotes li:target::before { + position: absolute; + top: -8px; + right: -8px; + bottom: -8px; + left: -24px; + pointer-events: none; + content: ""; + border: 2px solid var(--color-accent-emphasis); + border-radius: 6px; +} + +.markdown-body .footnotes li:target { + color: var(--color-fg-default); +} + +.markdown-body .footnotes .data-footnote-backref g-emoji { + font-family: monospace; +} + +.markdown-body .pl-c { + color: 
var(--color-prettylights-syntax-comment); +} + +.markdown-body .pl-c1, +.markdown-body .pl-s .pl-v { + color: var(--color-prettylights-syntax-constant); +} + +.markdown-body .pl-e, +.markdown-body .pl-en { + color: var(--color-prettylights-syntax-entity); +} + +.markdown-body .pl-smi, +.markdown-body .pl-s .pl-s1 { + color: var(--color-prettylights-syntax-storage-modifier-import); +} + +.markdown-body .pl-ent { + color: var(--color-prettylights-syntax-entity-tag); +} + +.markdown-body .pl-k { + color: var(--color-prettylights-syntax-keyword); +} + +.markdown-body .pl-s, +.markdown-body .pl-pds, +.markdown-body .pl-s .pl-pse .pl-s1, +.markdown-body .pl-sr, +.markdown-body .pl-sr .pl-cce, +.markdown-body .pl-sr .pl-sre, +.markdown-body .pl-sr .pl-sra { + color: var(--color-prettylights-syntax-string); +} + +.markdown-body .pl-v, +.markdown-body .pl-smw { + color: var(--color-prettylights-syntax-variable); +} + +.markdown-body .pl-bu { + color: var(--color-prettylights-syntax-brackethighlighter-unmatched); +} + +.markdown-body .pl-ii { + color: var(--color-prettylights-syntax-invalid-illegal-text); + background-color: var(--color-prettylights-syntax-invalid-illegal-bg); +} + +.markdown-body .pl-c2 { + color: var(--color-prettylights-syntax-carriage-return-text); + background-color: var(--color-prettylights-syntax-carriage-return-bg); +} + +.markdown-body .pl-sr .pl-cce { + font-weight: bold; + color: var(--color-prettylights-syntax-string-regexp); +} + +.markdown-body .pl-ml { + color: var(--color-prettylights-syntax-markup-list); +} + +.markdown-body .pl-mh, +.markdown-body .pl-mh .pl-en, +.markdown-body .pl-ms { + font-weight: bold; + color: var(--color-prettylights-syntax-markup-heading); +} + +.markdown-body .pl-mi { + font-style: italic; + color: var(--color-prettylights-syntax-markup-italic); +} + +.markdown-body .pl-mb { + font-weight: bold; + color: var(--color-prettylights-syntax-markup-bold); +} + +.markdown-body .pl-md { + color: var(--color-prettylights-syntax-markup-deleted-text); + background-color: var(--color-prettylights-syntax-markup-deleted-bg); +} + +.markdown-body .pl-mi1 { + color: var(--color-prettylights-syntax-markup-inserted-text); + background-color: var(--color-prettylights-syntax-markup-inserted-bg); +} + +.markdown-body .pl-mc { + color: var(--color-prettylights-syntax-markup-changed-text); + background-color: var(--color-prettylights-syntax-markup-changed-bg); +} + +.markdown-body .pl-mi2 { + color: var(--color-prettylights-syntax-markup-ignored-text); + background-color: var(--color-prettylights-syntax-markup-ignored-bg); +} + +.markdown-body .pl-mdr { + font-weight: bold; + color: var(--color-prettylights-syntax-meta-diff-range); +} + +.markdown-body .pl-ba { + color: var(--color-prettylights-syntax-brackethighlighter-angle); +} + +.markdown-body .pl-sg { + color: var(--color-prettylights-syntax-sublimelinter-gutter-mark); +} + +.markdown-body .pl-corl { + text-decoration: underline; + color: var(--color-prettylights-syntax-constant-other-reference-link); +} + +.markdown-body g-emoji { + display: inline-block; + min-width: 1ch; + font-family: "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol"; + font-size: 1em; + font-style: normal !important; + font-weight: var(--base-text-weight-normal, 400); + line-height: 1; + vertical-align: -0.075em; +} + +.markdown-body g-emoji img { + width: 1em; + height: 1em; +} + +.markdown-body .task-list-item { + list-style-type: none; +} + +.markdown-body .task-list-item label { + font-weight: var(--base-text-weight-normal, 
400); +} + +.markdown-body .task-list-item.enabled label { + cursor: pointer; +} + +.markdown-body .task-list-item+.task-list-item { + margin-top: 4px; +} + +.markdown-body .task-list-item .handle { + display: none; +} + +.markdown-body .task-list-item-checkbox { + margin: 0 0.2em 0.25em -1.4em; + vertical-align: middle; +} + +.markdown-body .contains-task-list:dir(rtl) .task-list-item-checkbox { + margin: 0 -1.6em 0.25em 0.2em; +} + +.markdown-body .contains-task-list { + position: relative; +} + +.markdown-body .contains-task-list:hover .task-list-item-convert-container, +.markdown-body .contains-task-list:focus-within .task-list-item-convert-container { + display: block; + width: auto; + height: 24px; + overflow: visible; + clip: auto; +} + +.markdown-body ::-webkit-calendar-picker-indicator { + filter: invert(50%); +} \ No newline at end of file diff --git a/web/config/index.ts b/web/config/index.ts new file mode 100644 index 0000000000..f5a85b52a4 --- /dev/null +++ b/web/config/index.ts @@ -0,0 +1,104 @@ +const isDevelopment = process.env.NODE_ENV === 'development'; + +export let apiPrefix = ''; +let publicApiPrefix = ''; + +// NEXT_PUBLIC_API_PREFIX=/console/api NEXT_PUBLIC_PUBLIC_API_PREFIX=/api npm run start +if (process.env.NEXT_PUBLIC_API_PREFIX && process.env.NEXT_PUBLIC_PUBLIC_API_PREFIX) { + apiPrefix = process.env.NEXT_PUBLIC_API_PREFIX; + publicApiPrefix = process.env.NEXT_PUBLIC_PUBLIC_API_PREFIX; +} else if ( + globalThis.document?.body?.getAttribute('data-api-prefix') && + globalThis.document?.body?.getAttribute('data-pubic-api-prefix') +) { + // When the bundle is not built with the env inlined, process.env.NEXT_PUBLIC_* cannot be read in the browser, so fall back to data attributes; see https://nextjs.org/docs/basic-features/environment-variables#exposing-environment-variables-to-the-browser + apiPrefix = globalThis.document.body.getAttribute('data-api-prefix') as string + publicApiPrefix = globalThis.document.body.getAttribute('data-pubic-api-prefix') as string +} else { + if (isDevelopment) { + apiPrefix = 'https://cloud.dify.dev/console/api'; + publicApiPrefix = 'https://dev.udify.app/api'; + } else { + // const domainParts = globalThis.location?.host?.split('.'); + // in production env, the host is dify.app . In other env, the host is [dev].dify.app + // const env = domainParts.length === 2 ? 
'ai' : domainParts?.[0]; + apiPrefix = '/console/api'; + publicApiPrefix = `/api`; // a relative prefix avoids cross-origin API issues in browser private mode + } +} + + +export const API_PREFIX: string = apiPrefix; +export const PUBLIC_API_PREFIX: string = publicApiPrefix; + +// mock server +export const MOCK_API_PREFIX = 'http://127.0.0.1:3001' + +const EDITION = process.env.NEXT_PUBLIC_EDITION || globalThis.document?.body?.getAttribute('data-public-edition') +export const IS_CE_EDITION = EDITION === 'SELF_HOSTED' + +export const TONE_LIST = [ + { + id: 1, + name: 'Creative', + config: { + temperature: 0.8, + top_p: 0.9, + presence_penalty: 0.1, + frequency_penalty: 0.1, + }, + }, + { + id: 2, + name: 'Balanced', + config: { + temperature: 0.5, + top_p: 0.85, + presence_penalty: 0.2, + frequency_penalty: 0.3, + }, + }, + { + id: 3, + name: 'Precise', + config: { + temperature: 0.2, + top_p: 0.75, + presence_penalty: 0.5, + frequency_penalty: 0.5, + }, + }, + { + id: 4, + name: 'Custom', + }, +] + +export const LOCALE_COOKIE_NAME = 'locale' + +export const DEFAULT_VALUE_MAX_LEN = 48 + +export const zhRegex = /^[\u4e00-\u9fa5]$/gm +export const emojiRegex = /^[\uD800-\uDBFF][\uDC00-\uDFFF]$/gm +export const emailRegex = /^(([^<>()[\]\\.,;:\s@\"]+(\.[^<>()[\]\\.,;:\s@\"]+)*)|(\".+\"))@((\[[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\])|(([a-zA-Z\-0-9]+\.)+[a-zA-Z]{2,}))$/ +const MAX_ZN_VAR_NAME_LENGHT = 8 +const MAX_EN_VAR_VALUE_LENGHT = 16 +export const getMaxVarNameLength = (value: string) => { + if (zhRegex.test(value)) { + return MAX_ZN_VAR_NAME_LENGHT + } + return MAX_EN_VAR_VALUE_LENGHT +} + +export const MAX_VAR_KEY_LENGHT = 16 + +export const VAR_ITEM_TEMPLATE = { + key: '', + name: '', + type: 'string', + max_length: DEFAULT_VALUE_MAX_LEN, + required: true +} + + + diff --git a/web/context/app-context.ts b/web/context/app-context.ts new file mode 100644 index 0000000000..d31b9fedca --- /dev/null +++ b/web/context/app-context.ts @@ -0,0 +1,27 @@ +'use client' + +import { createContext, useContext } from 'use-context-selector' +import type { App } from '@/types/app' +import type { UserProfileResponse } from '@/models/common' + +export type AppContextValue = { + apps: App[] + mutateApps: () => void + userProfile: UserProfileResponse + mutateUserProfile: () => void +} + +const AppContext = createContext<AppContextValue>({ + apps: [], + mutateApps: () => { }, + userProfile: { + id: '', + name: '', + email: '', + }, + mutateUserProfile: () => { }, +}) + +export const useAppContext = () => useContext(AppContext) + +export default AppContext diff --git a/web/context/dataset-detail.ts b/web/context/dataset-detail.ts new file mode 100644 index 0000000000..b507fbcc4c --- /dev/null +++ b/web/context/dataset-detail.ts @@ -0,0 +1,5 @@ +import { createContext } from 'use-context-selector' + +const DatasetDetailContext = createContext<{ indexingTechnique?: string; }>({}) + +export default DatasetDetailContext diff --git a/web/context/datasets-context.tsx b/web/context/datasets-context.tsx new file mode 100644 index 0000000000..a954d612d4 --- /dev/null +++ b/web/context/datasets-context.tsx @@ -0,0 +1,20 @@ +'use client' + +import { createContext, useContext } from 'use-context-selector' +import type { DataSet } from '@/models/datasets' + +export type DatasetsContextValue = { + datasets: DataSet[] + mutateDatasets: () => void + currentDataset?: DataSet +} + +const DatasetsContext = createContext<DatasetsContextValue>({ + datasets: [], + mutateDatasets: () => {}, + currentDataset: undefined +}) + +export const useDatasetsContext = () => useContext(DatasetsContext) 
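+ +// Illustrative usage sketch (not exercised anywhere in this file): a consumer +// component can read the nearest provider's value via the hook above, e.g. +//   const { datasets, mutateDatasets } = useDatasetsContext() +// `use-context-selector` also exposes `useContextSelector`, which lets a +// consumer subscribe to a single field and skip re-renders when other fields +// of the context value change.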
+ +export default DatasetsContext diff --git a/web/context/debug-configuration.ts b/web/context/debug-configuration.ts new file mode 100644 index 0000000000..bcf7ccd02f --- /dev/null +++ b/web/context/debug-configuration.ts @@ -0,0 +1,89 @@ +import { createContext } from 'use-context-selector' +import type { CompletionParams, Inputs, ModelConfig, PromptConfig, MoreLikeThisConfig, SuggestedQuestionsAfterAnswerConfig } from '@/models/debug' +import type { DataSet } from '@/models/datasets' + +type IDebugConfiguration = { + appId: string + hasSetAPIKEY: boolean + isTrailFinished: boolean + mode: string + conversationId: string | null // after first chat send + setConversationId: (conversationId: string | null) => void + introduction: string + setIntroduction: (introduction: string) => void + controlClearChatMessage: number + setControlClearChatMessage: (controlClearChatMessage: number) => void + prevPromptConfig: PromptConfig + setPrevPromptConfig: (prevPromptConfig: PromptConfig) => void + moreLikeThisConifg: MoreLikeThisConfig, + setMoreLikeThisConifg: (moreLikeThisConfig: MoreLikeThisConfig) => void + suggestedQuestionsAfterAnswerConfig: SuggestedQuestionsAfterAnswerConfig, + setSuggestedQuestionsAfterAnswerConfig: (suggestedQuestionsAfterAnswerConfig: SuggestedQuestionsAfterAnswerConfig) => void + formattingChanged: boolean + setFormattingChanged: (formattingChanged: boolean) => void + inputs: Inputs + setInputs: (inputs: Inputs) => void + query: string // user question + setQuery: (query: string) => void + // Below is draft info + completionParams: CompletionParams + setCompletionParams: (completionParams: CompletionParams) => void + // model_config + modelConfig: ModelConfig + setModelConfig: (modelConfig: ModelConfig) => void + dataSets: DataSet[] + setDataSets: (dataSet: DataSet[]) => void +} + +const DebugConfigurationContext = createContext<IDebugConfiguration>({ + appId: '', + hasSetAPIKEY: false, + isTrailFinished: false, + mode: '', + conversationId: '', + setConversationId: () => { }, + introduction: '', + setIntroduction: () => { }, + controlClearChatMessage: 0, + setControlClearChatMessage: () => { }, + prevPromptConfig: { + prompt_template: '', + prompt_variables: [], + }, + setPrevPromptConfig: () => { }, + moreLikeThisConifg: { + enabled: false, + }, + setMoreLikeThisConifg: () => { }, + suggestedQuestionsAfterAnswerConfig: { + enabled: false, + }, + setSuggestedQuestionsAfterAnswerConfig: () => { }, + formattingChanged: false, + setFormattingChanged: () => { }, + inputs: {}, + setInputs: () => { }, + query: '', + setQuery: () => { }, + completionParams: { + max_tokens: 16, + temperature: 1, // 0-2 + top_p: 1, + presence_penalty: 1, // -2-2 + frequency_penalty: 1, // -2-2 + }, + setCompletionParams: () => { }, + modelConfig: { + provider: 'OPENAI', // 'OPENAI' + model_id: 'gpt-3.5-turbo', // 'gpt-3.5-turbo' + configs: { + prompt_template: '', + prompt_variables: [], + }, + }, + setModelConfig: () => { }, + dataSets: [], + setDataSets: () => { }, +}) + +export default DebugConfigurationContext diff --git a/web/context/i18n.ts b/web/context/i18n.ts new file mode 100644 index 0000000000..1870b65302 --- /dev/null +++ b/web/context/i18n.ts @@ -0,0 +1,18 @@ +import { createContext } from 'use-context-selector' +import type { Locale } from '@/i18n' + +type II18NContext = { + locale: Locale + i18n: Record<string, string>, + setLocaleOnClient: (locale: Locale) => void + // setI8N: (i18n: Record<string, string>) => void, +} + +const I18NContext = createContext<II18NContext>({ + locale: 'en', + i18n: {}, + setLocaleOnClient: (lang: Locale) => { } 
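+ // No-op defaults: the actual locale, resources and setter are supplied by the app-level provider at runtime.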
+ // setI8N: () => {}, +}) + +export default I18NContext diff --git a/web/context/workspace-context.tsx b/web/context/workspace-context.tsx new file mode 100644 index 0000000000..d25fe3c2a4 --- /dev/null +++ b/web/context/workspace-context.tsx @@ -0,0 +1,35 @@ +'use client' + +import { createContext, useContext } from 'use-context-selector' +import useSWR from 'swr' +import { fetchWorkspaces } from '@/service/common' +import type { IWorkspace } from '@/models/common' + +export type WorkspacesContextValue = { + workspaces: IWorkspace[] +} + +const WorkspacesContext = createContext<WorkspacesContextValue>({ + workspaces: [] +}) + +interface IWorkspaceProviderProps { + children: React.ReactNode +} +export const WorkspaceProvider = ({ + children +}: IWorkspaceProviderProps) => { + const { data } = useSWR({ url: '/workspaces' }, fetchWorkspaces) + + return ( + <WorkspacesContext.Provider value={{ workspaces: data?.workspaces || [] }}> + {children} + </WorkspacesContext.Provider> + ) +} + +export const useWorkspacesContext = () => useContext(WorkspacesContext) + +export default WorkspacesContext diff --git a/web/dictionaries/en.json b/web/dictionaries/en.json new file mode 100644 index 0000000000..2ec65fce67 --- /dev/null +++ b/web/dictionaries/en.json @@ -0,0 +1,27 @@ +{ + "common": { + "confrim": "Confirm", + "cancel": "Cancel", + "refresh": "Refresh" + }, + "index": { + "welcome": "Welcome to " + }, + "signin": {}, + "app": { + "overview": { + "title": "Overview", + "To get started,": "To get started,", + "enter your OpenAI API key below": "enter your OpenAI API key below", + "Get your API key from OpenAI dashboard": "Get your API key from OpenAI dashboard", + "Token Usage": "Token Usage" + }, + "logs": { + "title": "Logs", + "description": "You can review and annotate the conversation and response text of the LLM, which will be used for subsequent model fine-tuning." + }, + "textGeneration": { + "history": "History" + } + } +} \ No newline at end of file diff --git a/web/dictionaries/zh-Hans.json b/web/dictionaries/zh-Hans.json new file mode 100644 index 0000000000..e88b03b733 --- /dev/null +++ b/web/dictionaries/zh-Hans.json @@ -0,0 +1,27 @@ +{ + "common": { + "confrim": "确定", + "cancel": "取消", + "refresh": "刷新" + }, + "index": { + "welcome": "欢迎来到 " + }, + "signin": {}, + "app": { + "overview": { + "title": "概览", + "To get started,": "从这里开始", + "enter your OpenAI API key below 👇": "输入你的 OpenAI API 密钥👇", + "Get your API key from OpenAI dashboard": "去 OpenAI 管理面板获取", + "Token Usage": "Token 消耗" + }, + "logs": { + "title": "日志", + "description": "日志记录了应用的运行情况,包括用户的输入和 AI 的回复。" + }, + "textGeneration": { + "history": "历史" + } + } +} \ No newline at end of file diff --git a/web/docker/entrypoint.sh b/web/docker/entrypoint.sh new file mode 100644 index 0000000000..d22db21473 --- /dev/null +++ b/web/docker/entrypoint.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +set -e + +export NEXT_PUBLIC_DEPLOY_ENV=${DEPLOY_ENV} +export NEXT_PUBLIC_EDITION=${EDITION} +export NEXT_PUBLIC_API_PREFIX=${CONSOLE_URL}/console/api +export NEXT_PUBLIC_PUBLIC_API_PREFIX=${APP_URL}/api + +/usr/local/bin/pm2 -v +/usr/local/bin/pm2-runtime start /app/web/pm2.json diff --git a/web/docker/pm2.json b/web/docker/pm2.json new file mode 100644 index 0000000000..7929a8315d --- /dev/null +++ b/web/docker/pm2.json @@ -0,0 +1,12 @@ +{ + "apps": [ + { + "name": "WebApp", + "exec_mode": "cluster", + "instances": 1, + "script": "./node_modules/next/dist/bin/next", + "cwd": "/app/web", + "args": "start" + } + ] + } diff --git a/web/hooks/use-breakpoints.ts b/web/hooks/use-breakpoints.ts new file mode 100644 index 0000000000..1aab56a9fd --- /dev/null +++ 
b/web/hooks/use-breakpoints.ts @@ -0,0 +1,27 @@ +'use client' +import React from 'react' + +export enum MediaType { + mobile = 'mobile', + tablet = 'tablet', + pc = 'pc', +} + +const useBreakpoints = () => { + const [width, setWidth] = React.useState(globalThis.innerWidth); + const media = (() => { + if (width <= 640) return MediaType.mobile; + if (width <= 768) return MediaType.tablet; + return MediaType.pc; + })(); + + React.useEffect(() => { + const handleWindowResize = () => setWidth(window.innerWidth); + window.addEventListener("resize", handleWindowResize); + return () => window.removeEventListener("resize", handleWindowResize); + }, []); + + return media; +} + +export default useBreakpoints \ No newline at end of file diff --git a/web/hooks/use-copy-to-clipboard.ts b/web/hooks/use-copy-to-clipboard.ts new file mode 100644 index 0000000000..76f792c6b4 --- /dev/null +++ b/web/hooks/use-copy-to-clipboard.ts @@ -0,0 +1,29 @@ +import { useState } from 'react' + +type CopiedValue = string | null +type CopyFn = (text: string) => Promise<boolean> + +function useCopyToClipboard(): [CopiedValue, CopyFn] { + const [copiedText, setCopiedText] = useState<CopiedValue>(null) + + const copy: CopyFn = async text => { + if (!navigator?.clipboard) { + console.warn('Clipboard not supported') + return false + } + + try { + await navigator.clipboard.writeText(text) + setCopiedText(text) + return true + } catch (error) { + console.warn('Copy failed', error) + setCopiedText(null) + return false + } + } + + return [copiedText, copy] +} + +export default useCopyToClipboard \ No newline at end of file diff --git a/web/hooks/use-metadata.ts b/web/hooks/use-metadata.ts new file mode 100644 index 0000000000..516047fda5 --- /dev/null +++ b/web/hooks/use-metadata.ts @@ -0,0 +1,391 @@ +"use client"; +import { useTranslation } from "react-i18next"; +import dayjs from "dayjs"; +import { formatNumber, formatFileSize, formatTime } from '@/utils/format' +import type { DocType } from '@/models/datasets' + +export type inputType = 'input' | 'select' | 'textarea' +export type metadataType = DocType | 'originInfo' | 'technicalParameters' + +type MetadataMap = Record< + metadataType, + { + text: string; + allowEdit?: boolean; + icon?: React.ReactNode; + iconName?: string; + subFieldsMap: Record< + string, + { + label: string; + inputType?: inputType; + field?: string; + render?: (value: any, total?: number) => React.ReactNode | string + } + >; + } +>; + +const fieldPrefix = "datasetDocuments.metadata.field"; + +export const useMetadataMap = (): MetadataMap => { + const { t } = useTranslation(); + return { + book: { + text: t("datasetDocuments.metadata.type.book"), + iconName: "bookOpen", + subFieldsMap: { + title: { label: t(`${fieldPrefix}.book.title`) }, + language: { + label: t(`${fieldPrefix}.book.language`), + inputType: "select", + }, + author: { label: t(`${fieldPrefix}.book.author`) }, + publisher: { label: t(`${fieldPrefix}.book.publisher`) }, + publication_date: { label: t(`${fieldPrefix}.book.publicationDate`) }, + isbn: { label: t(`${fieldPrefix}.book.ISBN`) }, + category: { + label: t(`${fieldPrefix}.book.category`), + inputType: "select", + }, + }, + }, + web_page: { + text: t("datasetDocuments.metadata.type.webPage"), + iconName: "globe", + subFieldsMap: { + title: { label: t(`${fieldPrefix}.webPage.title`) }, + url: { label: t(`${fieldPrefix}.webPage.url`) }, + language: { + label: t(`${fieldPrefix}.webPage.language`), + inputType: "select", + }, + ['author/publisher']: { label: t(`${fieldPrefix}.webPage.authorPublisher`) }, + 
publish_date: { label: t(`${fieldPrefix}.webPage.publishDate`) }, + ['topics/keywords']: { label: t(`${fieldPrefix}.webPage.topicsKeywords`) }, + description: { label: t(`${fieldPrefix}.webPage.description`) }, + }, + }, + paper: { + text: t("datasetDocuments.metadata.type.paper"), + iconName: "graduationHat", + subFieldsMap: { + title: { label: t(`${fieldPrefix}.paper.title`) }, + language: { + label: t(`${fieldPrefix}.paper.language`), + inputType: "select", + }, + author: { label: t(`${fieldPrefix}.paper.author`) }, + publish_date: { label: t(`${fieldPrefix}.paper.publishDate`) }, + ['journal/conference_name']: { + label: t(`${fieldPrefix}.paper.journalConferenceName`), + }, + ['volume/issue/page_numbers']: { label: t(`${fieldPrefix}.paper.volumeIssuePage`) }, + doi: { label: t(`${fieldPrefix}.paper.DOI`) }, + ['topics/keywords']: { label: t(`${fieldPrefix}.paper.topicsKeywords`) }, + abstract: { + label: t(`${fieldPrefix}.paper.abstract`), + inputType: "textarea", + }, + }, + }, + social_media_post: { + text: t("datasetDocuments.metadata.type.socialMediaPost"), + iconName: "atSign", + subFieldsMap: { + platform: { label: t(`${fieldPrefix}.socialMediaPost.platform`) }, + ['author/username']: { + label: t(`${fieldPrefix}.socialMediaPost.authorUsername`), + }, + publish_date: { label: t(`${fieldPrefix}.socialMediaPost.publishDate`) }, + post_url: { label: t(`${fieldPrefix}.socialMediaPost.postURL`) }, + ['topics/tags']: { label: t(`${fieldPrefix}.socialMediaPost.topicsTags`) }, + }, + }, + personal_document: { + text: t("datasetDocuments.metadata.type.personalDocument"), + iconName: "file", + subFieldsMap: { + title: { label: t(`${fieldPrefix}.personalDocument.title`) }, + author: { label: t(`${fieldPrefix}.personalDocument.author`) }, + creation_date: { + label: t(`${fieldPrefix}.personalDocument.creationDate`), + }, + last_modified_date: { + label: t(`${fieldPrefix}.personalDocument.lastModifiedDate`), + }, + document_type: { + label: t(`${fieldPrefix}.personalDocument.documentType`), + inputType: "select", + }, + ['tags/category']: { + label: t(`${fieldPrefix}.personalDocument.tagsCategory`), + }, + }, + }, + business_document: { + text: t("datasetDocuments.metadata.type.businessDocument"), + iconName: "briefcase", + subFieldsMap: { + title: { label: t(`${fieldPrefix}.businessDocument.title`) }, + author: { label: t(`${fieldPrefix}.businessDocument.author`) }, + creation_date: { + label: t(`${fieldPrefix}.businessDocument.creationDate`), + }, + last_modified_date: { + label: t(`${fieldPrefix}.businessDocument.lastModifiedDate`), + }, + document_type: { + label: t(`${fieldPrefix}.businessDocument.documentType`), + inputType: "select", + }, + ['department/team']: { + label: t(`${fieldPrefix}.businessDocument.departmentTeam`), + }, + }, + }, + im_chat_log: { + text: t("datasetDocuments.metadata.type.IMChat"), + iconName: "messageTextCircle", + subFieldsMap: { + chat_platform: { label: t(`${fieldPrefix}.IMChat.chatPlatform`) }, + ['chat_participants/group_name']: { + label: t(`${fieldPrefix}.IMChat.chatPartiesGroupName`), + }, + start_date: { label: t(`${fieldPrefix}.IMChat.startDate`) }, + end_date: { label: t(`${fieldPrefix}.IMChat.endDate`) }, + participants: { label: t(`${fieldPrefix}.IMChat.participants`) }, + topicsKeywords: { + label: t(`${fieldPrefix}.IMChat.topicsKeywords`), + inputType: "textarea", + }, + fileType: { label: t(`${fieldPrefix}.IMChat.fileType`) }, + }, + }, + wikipedia_entry: { + text: t("datasetDocuments.metadata.type.wikipediaEntry"), + allowEdit: false, + 
subFieldsMap: { + title: { label: t(`${fieldPrefix}.wikipediaEntry.title`) }, + language: { + label: t(`${fieldPrefix}.wikipediaEntry.language`), + inputType: "select", + }, + web_page_url: { label: t(`${fieldPrefix}.wikipediaEntry.webpageURL`) }, + ['editor/contributor']: { + label: t(`${fieldPrefix}.wikipediaEntry.editorContributor`), + }, + last_edit_date: { + label: t(`${fieldPrefix}.wikipediaEntry.lastEditDate`), + }, + ['summary/introduction']: { + label: t(`${fieldPrefix}.wikipediaEntry.summaryIntroduction`), + inputType: "textarea", + }, + }, + }, + synced_from_notion: { + text: t("datasetDocuments.metadata.type.notion"), + allowEdit: false, + subFieldsMap: { + title: { label: t(`${fieldPrefix}.notion.title`) }, + language: { label: t(`${fieldPrefix}.notion.lang`), inputType: "select" }, + ['author/creator']: { label: t(`${fieldPrefix}.notion.author`) }, + creation_date: { label: t(`${fieldPrefix}.notion.createdTime`) }, + last_modified_date: { + label: t(`${fieldPrefix}.notion.lastModifiedTime`), + }, + notion_page_link: { label: t(`${fieldPrefix}.notion.url`) }, + ['category/tags']: { label: t(`${fieldPrefix}.notion.tag`) }, + description: { label: t(`${fieldPrefix}.notion.desc`) }, + }, + }, + synced_from_github: { + text: t("datasetDocuments.metadata.type.github"), + allowEdit: false, + subFieldsMap: { + repository_name: { label: t(`${fieldPrefix}.github.repoName`) }, + repository_description: { label: t(`${fieldPrefix}.github.repoDesc`) }, + ['repository_owner/organization']: { label: t(`${fieldPrefix}.github.repoOwner`) }, + code_filename: { label: t(`${fieldPrefix}.github.fileName`) }, + code_file_path: { label: t(`${fieldPrefix}.github.filePath`) }, + programming_language: { label: t(`${fieldPrefix}.github.programmingLang`) }, + github_link: { label: t(`${fieldPrefix}.github.url`) }, + open_source_license: { label: t(`${fieldPrefix}.github.license`) }, + commit_date: { label: t(`${fieldPrefix}.github.lastCommitTime`) }, + commit_author: { + label: t(`${fieldPrefix}.github.lastCommitAuthor`), + }, + }, + }, + originInfo: { + text: "", + allowEdit: false, + subFieldsMap: { + name: { label: t(`${fieldPrefix}.originInfo.originalFilename`) }, + "data_source_info.upload_file.size": { + label: t(`${fieldPrefix}.originInfo.originalFileSize`), + render: (value) => formatFileSize(value) + }, + created_at: { + label: t(`${fieldPrefix}.originInfo.uploadDate`), + render: (value) => dayjs.unix(value).format(t('datasetDocuments.metadata.dateTimeFormat') as string) + }, + completed_at: { + label: t(`${fieldPrefix}.originInfo.lastUpdateDate`), + render: (value) => dayjs.unix(value).format(t('datasetDocuments.metadata.dateTimeFormat') as string) + }, + data_source_type: { + label: t(`${fieldPrefix}.originInfo.source`), + render: (value) => t(`datasetDocuments.metadata.source.${value}`) + }, + }, + }, + technicalParameters: { + text: t("datasetDocuments.metadata.type.technicalParameters"), + allowEdit: false, + subFieldsMap: { + 'dataset_process_rule.mode': { + label: t(`${fieldPrefix}.technicalParameters.segmentSpecification`), + render: value => value === 'automatic' ? 
(t('datasetDocuments.embedding.automatic') as string) : (t('datasetDocuments.embedding.custom') as string) + }, + 'dataset_process_rule.rules.segmentation.max_tokens': { + label: t(`${fieldPrefix}.technicalParameters.segmentLength`), + render: value => formatNumber(value) + }, + average_segment_length: { + label: t(`${fieldPrefix}.technicalParameters.avgParagraphLength`), + render: (value) => `${formatNumber(value)} characters` + }, + segment_count: { + label: t(`${fieldPrefix}.technicalParameters.paragraphs`), + render: (value) => `${formatNumber(value)} paragraphs` + }, + hit_count: { + label: t(`${fieldPrefix}.technicalParameters.hitCount`), + render: (value, total) => { + const v = value || 0; + return `${!total ? 0 : ((v / total) * 100).toFixed(2)}% (${v}/${total})` + } + }, + indexing_latency: { + label: t(`${fieldPrefix}.technicalParameters.embeddingTime`), + render: (value) => formatTime(value) + }, + tokens: { + label: t(`${fieldPrefix}.technicalParameters.embeddedSpend`), + render: (value) => `${formatNumber(value)} tokens` + }, + }, + }, + }; +}; + +const langPrefix = "datasetDocuments.metadata.languageMap."; + +export const useLanguages = () => { + const { t } = useTranslation(); + return { + zh: t(langPrefix + "zh"), + en: t(langPrefix + "en"), + es: t(langPrefix + "es"), + fr: t(langPrefix + "fr"), + de: t(langPrefix + "de"), + ja: t(langPrefix + "ja"), + ko: t(langPrefix + "ko"), + ru: t(langPrefix + "ru"), + ar: t(langPrefix + "ar"), + pt: t(langPrefix + "pt"), + it: t(langPrefix + "it"), + nl: t(langPrefix + "nl"), + pl: t(langPrefix + "pl"), + sv: t(langPrefix + "sv"), + tr: t(langPrefix + "tr"), + he: t(langPrefix + "he"), + hi: t(langPrefix + "hi"), + da: t(langPrefix + "da"), + fi: t(langPrefix + "fi"), + no: t(langPrefix + "no"), + hu: t(langPrefix + "hu"), + el: t(langPrefix + "el"), + cs: t(langPrefix + "cs"), + th: t(langPrefix + "th"), + id: t(langPrefix + "id"), + }; +}; + +const bookCategoryPrefix = "datasetDocuments.metadata.categoryMap.book."; + +export const useBookCategories = () => { + const { t } = useTranslation(); + return { + fiction: t(bookCategoryPrefix + "fiction"), + biography: t(bookCategoryPrefix + "biography"), + history: t(bookCategoryPrefix + "history"), + science: t(bookCategoryPrefix + "science"), + technology: t(bookCategoryPrefix + "technology"), + education: t(bookCategoryPrefix + "education"), + philosophy: t(bookCategoryPrefix + "philosophy"), + religion: t(bookCategoryPrefix + "religion"), + socialSciences: t(bookCategoryPrefix + "socialSciences"), + art: t(bookCategoryPrefix + "art"), + travel: t(bookCategoryPrefix + "travel"), + health: t(bookCategoryPrefix + "health"), + selfHelp: t(bookCategoryPrefix + "selfHelp"), + businessEconomics: t(bookCategoryPrefix + "businessEconomics"), + cooking: t(bookCategoryPrefix + "cooking"), + childrenYoungAdults: t(bookCategoryPrefix + "childrenYoungAdults"), + comicsGraphicNovels: t(bookCategoryPrefix + "comicsGraphicNovels"), + poetry: t(bookCategoryPrefix + "poetry"), + drama: t(bookCategoryPrefix + "drama"), + other: t(bookCategoryPrefix + "other"), + }; +}; + +const personalDocCategoryPrefix = + "datasetDocuments.metadata.categoryMap.personalDoc."; + +export const usePersonalDocCategories = () => { + const { t } = useTranslation(); + return { + notes: t(personalDocCategoryPrefix + "notes"), + blogDraft: t(personalDocCategoryPrefix + "blogDraft"), + diary: t(personalDocCategoryPrefix + "diary"), + researchReport: t(personalDocCategoryPrefix + "researchReport"), + bookExcerpt: 
t(personalDocCategoryPrefix + "bookExcerpt"), + schedule: t(personalDocCategoryPrefix + "schedule"), + list: t(personalDocCategoryPrefix + "list"), + projectOverview: t(personalDocCategoryPrefix + "projectOverview"), + photoCollection: t(personalDocCategoryPrefix + "photoCollection"), + creativeWriting: t(personalDocCategoryPrefix + "creativeWriting"), + codeSnippet: t(personalDocCategoryPrefix + "codeSnippet"), + designDraft: t(personalDocCategoryPrefix + "designDraft"), + personalResume: t(personalDocCategoryPrefix + "personalResume"), + other: t(personalDocCategoryPrefix + "other"), + }; +}; + +const businessDocCategoryPrefix = + "datasetDocuments.metadata.categoryMap.businessDoc."; + +export const useBusinessDocCategories = () => { + const { t } = useTranslation(); + return { + meetingMinutes: t(businessDocCategoryPrefix + "meetingMinutes"), + researchReport: t(businessDocCategoryPrefix + "researchReport"), + proposal: t(businessDocCategoryPrefix + "proposal"), + employeeHandbook: t(businessDocCategoryPrefix + "employeeHandbook"), + trainingMaterials: t(businessDocCategoryPrefix + "trainingMaterials"), + requirementsDocument: t(businessDocCategoryPrefix + "requirementsDocument"), + designDocument: t(businessDocCategoryPrefix + "designDocument"), + productSpecification: t(businessDocCategoryPrefix + "productSpecification"), + financialReport: t(businessDocCategoryPrefix + "financialReport"), + marketAnalysis: t(businessDocCategoryPrefix + "marketAnalysis"), + projectPlan: t(businessDocCategoryPrefix + "projectPlan"), + teamStructure: t(businessDocCategoryPrefix + "teamStructure"), + policiesProcedures: t(businessDocCategoryPrefix + "policiesProcedures"), + contractsAgreements: t(businessDocCategoryPrefix + "contractsAgreements"), + emailCorrespondence: t(businessDocCategoryPrefix + "emailCorrespondence"), + other: t(businessDocCategoryPrefix + "other"), + }; +}; diff --git a/web/i18n/client.ts b/web/i18n/client.ts new file mode 100644 index 0000000000..39b6d01656 --- /dev/null +++ b/web/i18n/client.ts @@ -0,0 +1,16 @@ +import Cookies from 'js-cookie' +import type { Locale } from '.' +import { i18n } from '.' 
+import { LOCALE_COOKIE_NAME } from '@/config' +import { changeLanguage } from '@/i18n/i18next-config' + +// same logic as server +export const getLocaleOnClient = (): Locale => { + return Cookies.get(LOCALE_COOKIE_NAME) as Locale || i18n.defaultLocale +} + +export const setLocaleOnClient = (locale: Locale) => { + Cookies.set(LOCALE_COOKIE_NAME, locale) + changeLanguage(locale) + location.reload() +} diff --git a/web/i18n/i18next-config.ts b/web/i18n/i18next-config.ts new file mode 100644 index 0000000000..49d729e7c7 --- /dev/null +++ b/web/i18n/i18next-config.ts @@ -0,0 +1,92 @@ +'use client' +import i18n from 'i18next' +import { initReactI18next } from 'react-i18next' +import commonEn from './lang/common.en' +import commonZh from './lang/common.zh' +import loginEn from './lang/login.en' +import loginZh from './lang/login.zh' +import registerEn from './lang/register.en' +import registerZh from './lang/register.zh' +import layoutEn from './lang/layout.en' +import layoutZh from './lang/layout.zh' +import appEn from './lang/app.en' +import appZh from './lang/app.zh' +import appOverviewEn from './lang/app-overview.en' +import appOverviewZh from './lang/app-overview.zh' +import appDebugEn from './lang/app-debug.en' +import appDebugZh from './lang/app-debug.zh' +import appApiEn from './lang/app-api.en' +import appApiZh from './lang/app-api.zh' +import appLogEn from './lang/app-log.en' +import appLogZh from './lang/app-log.zh' +import shareEn from './lang/share-app.en' +import shareZh from './lang/share-app.zh' +import datasetEn from './lang/dataset.en' +import datasetZh from './lang/dataset.zh' +import datasetDocumentsEn from './lang/dataset-documents.en' +import datasetDocumentsZh from './lang/dataset-documents.zh' +import datasetHitTestingEn from './lang/dataset-hit-testing.en' +import datasetHitTestingZh from './lang/dataset-hit-testing.zh' +import datasetSettingsEn from './lang/dataset-settings.en' +import datasetSettingsZh from './lang/dataset-settings.zh' +import datasetCreationEn from './lang/dataset-creation.en' +import datasetCreationZh from './lang/dataset-creation.zh' +import { getLocaleOnClient } from '@/i18n/client' + +const resources = { + 'en': { + translation: { + common: commonEn, + layout: layoutEn, // page layout + login: loginEn, + register: registerEn, + // app + app: appEn, + appOverview: appOverviewEn, + appDebug: appDebugEn, + appApi: appApiEn, + appLog: appLogEn, + // share + share: shareEn, + dataset: datasetEn, + datasetDocuments: datasetDocumentsEn, + datasetHitTesting: datasetHitTestingEn, + datasetSettings: datasetSettingsEn, + datasetCreation: datasetCreationEn, + }, + }, + 'zh-Hans': { + translation: { + common: commonZh, + layout: layoutZh, + login: loginZh, + register: registerZh, + // app + app: appZh, + appOverview: appOverviewZh, + appDebug: appDebugZh, + appApi: appApiZh, + appLog: appLogZh, + // share + share: shareZh, + dataset: datasetZh, + datasetDocuments: datasetDocumentsZh, + datasetHitTesting: datasetHitTestingZh, + datasetSettings: datasetSettingsZh, + datasetCreation: datasetCreationZh, + }, + }, +} + +i18n.use(initReactI18next) + // init i18next + // for all options read: https://www.i18next.com/overview/configuration-options + .init({ + lng: getLocaleOnClient(), + fallbackLng: 'en', + // debug: true, + resources, + }) + +export const changeLanguage = i18n.changeLanguage +export default i18n diff --git a/web/i18n/i18next-serverside-config.ts b/web/i18n/i18next-serverside-config.ts new file mode 100644 index 0000000000..fe89475f79 --- /dev/null 
+++ b/web/i18n/i18next-serverside-config.ts @@ -0,0 +1,26 @@ +import { createInstance } from 'i18next' +import resourcesToBackend from 'i18next-resources-to-backend' +import { initReactI18next } from 'react-i18next/initReactI18next' +import { Locale } from '.' + +// https://locize.com/blog/next-13-app-dir-i18n/ +const initI18next = async (lng: Locale, ns: string) => { + const i18nInstance = createInstance() + await i18nInstance + .use(initReactI18next) + .use(resourcesToBackend((language: string, namespace: string) => import(`./lang/${namespace}.${language}.ts`))) + .init({ + lng: lng === 'zh-Hans' ? 'zh' : lng, + ns, + fallbackLng: 'en', + }) + return i18nInstance +} + +export async function useTranslation(lng: Locale, ns = '', options: Record<string, any> = {}) { + const i18nextInstance = await initI18next(lng, ns) + return { + t: i18nextInstance.getFixedT(lng, ns, options.keyPrefix), + i18n: i18nextInstance + } +} \ No newline at end of file diff --git a/web/i18n/index.ts b/web/i18n/index.ts new file mode 100644 index 0000000000..914b8ae112 --- /dev/null +++ b/web/i18n/index.ts @@ -0,0 +1,6 @@ +export const i18n = { + defaultLocale: 'en', + locales: ['en', 'zh-Hans'], +} as const + +export type Locale = typeof i18n['locales'][number] diff --git a/web/i18n/lang/app-api.en.ts b/web/i18n/lang/app-api.en.ts new file mode 100644 index 0000000000..cf188e0106 --- /dev/null +++ b/web/i18n/lang/app-api.en.ts @@ -0,0 +1,76 @@ +const translation = { + apiServer: "API Server", + apiKey: "API Key", + status: "Status", + disabled: "Disabled", + ok: "In Service", + copy: "Copy", + copied: "Copied", + never: "Never", + apiKeyModal: { + apiSecretKey: "API Secret key", + apiSecretKeyTips: "To prevent API abuse, protect your API Key. Avoid using it as plain text in front-end code. :)", + createNewSecretKey: "Create new Secret key", + secretKey: "Secret Key", + created: "CREATED", + lastUsed: "LAST USED", + generateTips: "Keep this key in a secure and accessible place." + }, + actionMsg: { + deleteConfirmTitle: "Delete this secret key?", + deleteConfirmTips: "This action cannot be undone.", + ok: "OK" + }, + completionMode: { + title: "Completion App API", + info: "For high-quality text generation, such as articles, summaries, and translations, use the completion-messages API with user input. Text generation relies on the model parameters and prompt templates set in Dify Prompt Engineering.", + createCompletionApi: "Create Completion Message", + createCompletionApiTip: "Create a Completion Message to support the question-and-answer mode.", + inputsTips: "(Optional) Provide user input fields as key-value pairs, corresponding to variables in Prompt Eng. Key is the variable name, Value is the parameter value. If the field type is Select, the submitted Value must be one of the preset choices.", + queryTips: "User input text content.", + blocking: "Blocking mode: waits for execution to complete and then returns the result. (The request may be interrupted if the process is long.)", + streaming: "Streaming mode: returns the result incrementally, implemented with SSE (Server-Sent Events).", + messageFeedbackApi: "Message feedback (like)", + messageFeedbackApiTip: "Rate received messages on behalf of end-users with likes or dislikes. This data is visible in the Logs & Annotations page and used for future model fine-tuning.", + messageIDTip: "Message ID", + ratingTip: "'like' or 'dislike'; null revokes the rating", + parametersApi: "Obtain application parameter information", + parametersApiTip: "Retrieve configured Input parameters, including variable names, field names, types, and default values. Typically used for displaying these fields in a form or filling in default values after the client loads." + }, + chatMode: { + title: "Chat App API", + info: "For versatile conversational apps using a Q&A format, call the chat-messages API to initiate dialogue. Maintain ongoing conversations by passing the returned conversation_id. Response parameters and templates depend on Dify Prompt Eng. settings.", + createChatApi: "Create chat message", + createChatApiTip: "Create a new conversation message or continue an existing dialogue.", + inputsTips: "(Optional) Provide user input fields as key-value pairs, corresponding to variables in Prompt Eng. Key is the variable name, Value is the parameter value. If the field type is Select, the submitted Value must be one of the preset choices.", + queryTips: "User input/question content", + blocking: "Blocking mode: waits for execution to complete and then returns the result. (The request may be interrupted if the process is long.)", + streaming: "Streaming mode: returns the result incrementally, implemented with SSE (Server-Sent Events).", + conversationIdTip: "(Optional) Conversation ID: leave empty for first-time conversation; pass conversation_id from context to continue dialogue.", + messageFeedbackApi: "End-user message feedback (like)", + messageFeedbackApiTip: "Rate received messages on behalf of end-users with likes or dislikes. This data is visible in the Logs & Annotations page and used for future model fine-tuning.", + messageIDTip: "Message ID", + ratingTip: "'like' or 'dislike'; null revokes the rating", + chatMsgHistoryApi: "Get the chat history messages", + chatMsgHistoryApiTip: "The first page returns the latest `limit` messages, in reverse order.", + chatMsgHistoryConversationIdTip: "Conversation ID", + chatMsgHistoryFirstId: "ID of the first chat record on the current page. The default is none.", + chatMsgHistoryLimit: "How many chats are returned in one request", + conversationsListApi: "Get conversation list", + conversationsListApiTip: "Gets the session list of the current user. By default, the last 20 sessions are returned.", + conversationsListFirstIdTip: "The ID of the last record on the current page, default none.", + conversationsListLimitTip: "How many chats are returned in one request", + conversationRenamingApi: "Conversation renaming", + conversationRenamingApiTip: "Rename conversations; the name is displayed in multi-session client interfaces.", + conversationRenamingNameTip: "New name", + parametersApi: "Obtain application parameter information", + parametersApiTip: "Retrieve configured Input parameters, including variable names, field names, types, and default values. Typically used for displaying these fields in a form or filling in default values after the client loads." 
+ }, + develop: { + requestBody: "Request Body", + pathParams: "Path Params", + query: "Query" + } +} + +export default translation diff --git a/web/i18n/lang/app-api.zh.ts b/web/i18n/lang/app-api.zh.ts new file mode 100644 index 0000000000..49dc184dda --- /dev/null +++ b/web/i18n/lang/app-api.zh.ts @@ -0,0 +1,76 @@ +const translation = { + apiServer: "API 服务器", + apiKey: "API 密钥", + status: "状态", + disabled: "已停用", + ok: "运行中", + copy: "复制", + copied: "已复制", + never: "从未", + apiKeyModal: { + apiSecretKey: "API 密钥", + apiSecretKeyTips: "如果不想你的应用 API 被滥用,请保护好你的 API Key :) 最佳实践是避免在前端代码中明文引用。", + createNewSecretKey: "创建密钥", + secretKey: "密钥", + created: "创建时间", + lastUsed: "最后使用", + generateTips: "请将此密钥保存在安全且可访问的地方。" + }, + actionMsg: { + deleteConfirmTitle: "删除此密钥?", + deleteConfirmTips: "删除密钥无法撤销,正在使用中的应用会受影响。", + ok: "好的" + }, + completionMode: { + title: "文本生成型应用 API", + info: "可用于生成高质量文本的应用,例如生成文章、摘要、翻译等,通过调用 completion-messages 接口,发送用户输入得到生成文本结果。用于生成文本的模型参数和提示词模版取决于开发者在 Dify 提示词编排页的设置。", + createCompletionApi: "创建文本补全消息", + createCompletionApiTip: "创建文本补全消息,支持一问一答模式。", + inputsTips: "(选填)以键值对方式提供用户输入字段,与提示词编排中的变量对应。Key 为变量名称,Value 是参数值。如果字段类型为 Select,传入的 Value 需为预设选项之一。", + queryTips: "用户输入的文本正文。", + blocking: "blocking 阻塞型,等待执行完毕后返回结果。(请求若流程较长可能会被中断)", + streaming: "streaming 流式返回。基于 SSE(**[Server-Sent Events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events)**)实现流式返回。", + messageFeedbackApi: "消息反馈(点赞)", + messageFeedbackApiTip: "代表最终用户对返回消息进行评价,可以点赞与点踩,该数据将在“日志与标注”页中可见,并用于后续的模型微调。", + messageIDTip: "消息 ID", + ratingTip: "like 或 dislike, 空值为撤销", + parametersApi: "获取应用配置信息", + parametersApiTip: "获取已配置的 Input 参数,包括变量名、字段名称、类型与默认值。通常用于客户端加载后显示这些字段的表单或填入默认值。" + }, + chatMode: { + title: "对话型应用 API", + info: "可用于大部分场景的对话型应用,采用一问一答模式与用户持续对话。要开始一个对话请调用 chat-messages 接口,通过继续传入返回的 conversation_id 可持续保持该会话。", + createChatApi: "发送对话消息", + createChatApiTip: "创建会话消息,或基于此前的对话继续发送消息。", + inputsTips: "(选填)以键值对方式提供用户输入字段,与提示词编排中的变量对应。Key 为变量名称,Value 是参数值。如果字段类型为 Select,传入的 Value 需为预设选项之一。", + queryTips: " 用户输入/提问内容", + blocking: "blocking 阻塞型,等待执行完毕后返回结果。(请求若流程较长可能会被中断)", + streaming: "streaming 流式返回。基于 SSE(**[Server-Sent Events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events)**)实现流式返回。", + conversationIdTip: "(选填)会话标识符,首次对话可为空,如果要继续对话请传入上下文返回的 conversation_id", + messageFeedbackApi: "消息反馈(点赞)", + messageFeedbackApiTip: "代表最终用户对返回消息进行评价,可以点赞与点踩,该数据将在“日志与标注”页中可见,并用于后续的模型微调。", + messageIDTip: "消息 ID", + ratingTip: "like 或 dislike, 空值为撤销", + chatMsgHistoryApi: "获取会话历史消息", + chatMsgHistoryApiTip: "滚动加载形式返回历史聊天记录,第一页返回最新 `limit` 条,即:倒序返回。", + chatMsgHistoryConversationIdTip: "会话 ID", + chatMsgHistoryFirstId: "当前页第一条聊天记录的 ID,默认 none", + chatMsgHistoryLimit: "一次请求返回多少条聊天记录", + conversationsListApi: "获取会话列表", + conversationsListApiTip: "获取当前用户的会话列表,默认返回最近的 20 条。", + conversationsListFirstIdTip: " 当前页最前面一条记录的 ID,默认 none", + conversationsListLimitTip: "一次请求返回多少条记录", + conversationRenamingApi: "会话重命名", + conversationRenamingApiTip: "对会话进行重命名,会话名称用于显示在支持多会话的客户端上。", + conversationRenamingNameTip: "新的名称", + parametersApi: "获取应用配置信息", + parametersApiTip: "获取已配置的 Input 参数,包括变量名、字段名称、类型与默认值。通常用于客户端加载后显示这些字段的表单或填入默认值。" + }, + develop: { + requestBody: "Request Body", + pathParams: "Path Params", + query: "Query" + } +} + +export default translation diff --git a/web/i18n/lang/app-debug.en.ts b/web/i18n/lang/app-debug.en.ts new file mode 100644 index 0000000000..a7b87a7351 --- /dev/null +++ b/web/i18n/lang/app-debug.en.ts @@ -0,0 
+1,139 @@ +const translation = { + pageTitle: "Prompt Engineering", + operation: { + applyConfig: "Publish", + resetConfig: "Reset", + addFeature: "Add Feature", + stopResponding: "Stop responding", + }, + notSetAPIKey: { + title: "LLM provider key has not been set", + trailFinished: "Trial finished", + description: "The LLM provider key has not been set; it must be set before debugging.", + settingBtn: "Go to settings", + }, + trailUseGPT4Info: { + title: 'gpt-4 is not supported yet', + description: 'To use gpt-4, please set your API Key.', + }, + feature: { + groupChat: { + title: 'Chat enhance', + description: 'Adding pre-conversation settings to apps can enhance the user experience.' + }, + groupExperience: { + title: 'Experience enhance', + }, + conversationOpener: { + title: "Conversation opener", + description: "In a chat app, the first sentence the AI proactively says to the user is usually used as a welcome message." + }, + suggestedQuestionsAfterAnswer: { + title: 'Follow-up', + description: 'Suggesting follow-up questions gives users a better chat experience.', + resDes: '3 suggested questions for the user to ask next.', + tryToAsk: 'Try to ask', + }, + moreLikeThis: { + title: "More like this", + description: "Generate multiple texts at once, then edit them and continue generating", + generateNumTip: "Number of texts generated each time", + tip: "Using this feature will incur additional token overhead" + }, + dataSet: { + title: "Context", + noData: "You can import datasets as context", + words: "Words", + textBlocks: "Text Blocks", + selectTitle: "Select reference dataset", + selected: "Datasets selected", + noDataSet: "No dataset found", + toCreate: "Go to create", + notSupportSelectMulti: 'Currently only one dataset is supported' + } + }, + resetConfig: { + title: "Confirm reset?", + message: + "Reset discards changes, restoring the last published configuration.", + }, + errorMessage: { + nameOfKeyRequired: "Name for key {{key}} is required", + valueOfVarRequired: "Variable values cannot be empty", + queryRequired: "Request text is required.", + waitForResponse: + "Please wait for the response to the previous message to complete.", + }, + chatSubTitle: "Pre Prompt", + completionSubTitle: "Prefix Prompt", + promptTip: + "Prompts guide AI responses with instructions and constraints. Insert variables like {{input}}. This prompt won't be visible to users.", + formattingChangedTitle: "Formatting changed", + formattingChangedText: + "Modifying the formatting will reset the debug area. Are you sure?", + variableTitle: "Variables", + variableTip: + "Users fill variables in a form, automatically replacing variables in the prompt.", + notSetVar: "Variables allow users to fill in a form whose values are inserted into the prompt or opening remarks. Try entering \"{{input}}\" in the prompt.", + autoAddVar: "Undefined variables are referenced in the pre-prompt. Do you want to add them to the user input form?", + variableTable: { + key: "Variable Key", + name: "User Input Field Name", + optional: "Optional", + type: "Input Type", + action: "Actions", + typeString: "String", + typeSelect: "Select", + }, + varKeyError: { + canNoBeEmpty: "Variable key cannot be empty", + tooLong: "Variable key {{key}} is too long. It cannot be longer than 16 characters", + notValid: "Variable key {{key}} is invalid. It can only contain letters, numbers, and underscores", + notStartWithNumber: "Variable key {{key}} cannot start with a number", + }, + variableConig: { + modalTitle: "Field settings", + description: "Setting for variable {{varName}}", + fieldType: 'Field type', + string: 'Text', + select: 'Select', + notSet: 'Not set, try typing {{input}} in the prefix prompt', + stringTitle: "Form text box settings", + maxLength: "Max length", + options: "Options", + addOption: "Add option", + }, + openingStatement: { + title: "Opening remarks", + add: "Add", + writeOpner: "Write remarks", + placeholder: "Write your remarks message here", + noDataPlaceHolder: + "Starting the conversation with the user can help AI establish a closer connection with them in conversational applications.", + varTip: 'You can use variables, try type {{variable}}', + tooShort: "At least 20 words of initial prompt are required to generate opening remarks for the conversation.", + notIncludeKey: "The initial prompt does not include the variable: {{key}}. Please add it to the initial prompt.", + }, + modelConfig: { + model: "Model", + setTone: "Set tone of responses", + title: "Model and Parameters", + }, + inputs: { + title: "Debugging and Previewing", + noPrompt: "Try writing a prompt in the pre-prompt input", + userInputField: "User Input Field", + noVar: "Fill in the value of the variable, which will be automatically substituted into the prompt each time a new session starts.", + chatVarTip: + "Fill in the value of the variable, which will be automatically substituted into the prompt each time a new session starts", + completionVarTip: + "Fill in the value of the variable, which will be automatically substituted into the prompt each time a question is submitted.", + previewTitle: "Prompt preview", + queryTitle: "Query content", + queryPlaceholder: "Please enter the request text.", + run: "RUN", + }, + result: "Output Text", +}; + +export default translation; diff --git a/web/i18n/lang/app-debug.zh.ts b/web/i18n/lang/app-debug.zh.ts new file mode 100644 index 0000000000..98db2a2bea --- /dev/null +++ b/web/i18n/lang/app-debug.zh.ts @@ -0,0 +1,134 @@ +const translation = { + pageTitle: "提示词编排", + operation: { + applyConfig: "发布", + resetConfig: "重置", + addFeature: "添加功能", + stopResponding: "停止响应", + }, + notSetAPIKey: { + title: "LLM 提供者的密钥未设置", + trailFinished: "试用已结束", + description: "在调试之前需要设置 LLM 提供者的密钥。", + settingBtn: "去设置", + }, + trailUseGPT4Info: { + title: '当前不支持使用 gpt-4', + description: '使用 gpt-4,请设置 API Key', + }, + feature: { + groupChat: { + title: '聊天增强', + description: '为聊天型应用添加预对话设置,可以提升用户体验。' + }, + groupExperience: { + title: '体验增强', + }, + conversationOpener: { + title: "对话开场白", + description: "在对话型应用中,让 AI 主动说第一段话可以拉近与用户间的距离。" + }, + suggestedQuestionsAfterAnswer: { + title: '下一步问题建议', + description: '设置下一步问题建议可以让用户更好的对话。', + resDes: '回答结束后系统会给出 3 个建议', + tryToAsk: '试着问问', + }, + moreLikeThis: { + title: "更多类似的", + description: '一次生成多条文本,可在此基础上编辑并继续生成', + generateNumTip: "每次生成数", + tip: "使用此功能将会额外消耗 tokens" + }, + dataSet: { + title: "上下文", + noData: "您可以导入数据集作为上下文", + words: "词", + textBlocks: "文本块", + selectTitle: "选择引用数据集", + selected: "个数据集被选中", + noDataSet: "未找到数据集", + toCreate: "去创建", + notSupportSelectMulti: '目前只支持引用一个数据集' + } + }, + resetConfig: { + title: "确认重置?", + message: "重置将丢失当前页面所有修改,恢复至上次发布时的配置", + }, + errorMessage: { + nameOfKeyRequired: "变量 {{key}} 对应的名称必填", + valueOfVarRequired: "变量值必填", + queryRequired: "主要文本必填", + waitForResponse: "请等待上条信息响应完成", + }, + chatSubTitle: "对话前提示词", 
+ completionSubTitle: "前缀提示词", + promptTip: + "提示词用于对 AI 的回复做出一系列指令和约束。可插入表单变量,例如 {{input}}。这段提示词不会被最终用户所看到。", + formattingChangedTitle: "编排已改变", + formattingChangedText: "修改编排将重置调试区域,确定吗?", + variableTitle: "变量", + notSetVar: "变量能使用户输入表单引入提示词或开场白,你可以试试在提示词中输入 {{input}}", + variableTip: + "变量将以表单形式让用户在对话前填写,用户填写的表单内容将自动替换提示词中的变量。", + autoAddVar: "提示词中引用了未定义的变量,是否自动添加到用户输入表单中?", + variableTable: { + key: "变量 Key", + name: "字段名称", + optional: "可选", + type: "类型", + action: "操作", + typeString: "文本", + typeSelect: "下拉选项", + }, + varKeyError: { + canNoBeEmpty: "变量不能为空", + tooLong: "变量: {{key}} 长度太长。不能超过 16 个字符", + notValid: "变量: {{key}} 非法。只能包含英文字符,数字和下划线", + notStartWithNumber: "变量: {{key}} 不能以数字开头", + }, + variableConig: { + modalTitle: "变量设置", + description: "设置变量 {{varName}}", + fieldType: '字段类型', + string: '文本', + select: '下拉选项', + notSet: '未设置,在 Prompt 中输入 {{input}} 试试', + stringTitle: "文本框设置", + maxLength: "最大长度", + options: "选项", + addOption: "添加选项", + }, + openingStatement: { + title: "对话开场白", + add: "添加开场白", + writeOpner: "编写开场白", + placeholder: "请在这里输入开场白", + noDataPlaceHolder: + "在对话型应用中,让 AI 主动说第一段话可以拉近与用户间的距离。", + varTip: '你可以使用变量, 试试输入 {{variable}}', + tooShort: "对话前提示词至少 20 字才能生成开场白", + notIncludeKey: "前缀提示词中不包含变量 {{key}}。请在前缀提示词中添加该变量", + }, + modelConfig: { + model: "语言模型", + setTone: "模型设置", + title: "模型及参数", + }, + inputs: { + title: "调试与预览", + noPrompt: "尝试在对话前提示框中编写一些提示词", + userInputField: "用户输入", + noVar: "填入变量的值,每次启动新会话时该变量将自动替换提示词中的变量。", + chatVarTip: "填入变量的值,该值将在每次开启一个新会话时自动替换到提示词中", + completionVarTip: "填入变量的值,该值将在每次提交问题时自动替换到提示词中", + previewTitle: "提示词预览", + queryTitle: "查询内容", + queryPlaceholder: "请输入文本内容", + run: "运行", + }, + result: "结果", +}; + +export default translation; diff --git a/web/i18n/lang/app-log.en.ts b/web/i18n/lang/app-log.en.ts new file mode 100644 index 0000000000..9df8cee527 --- /dev/null +++ b/web/i18n/lang/app-log.en.ts @@ -0,0 +1,67 @@ +const translation = { + title: 'Logs & Annotations', + description: 'The logs record the running status of the application, including user inputs and AI replies.', + dateTimeFormat: 'MM/DD/YYYY hh:mm A', + table: { + header: { + time: 'Time', + endUser: 'End User', + input: 'Input', + output: 'Output', + summary: 'Summary', + messageCount: 'Message Count', + userRate: 'User Rate', + adminRate: 'Op. Rate', + }, + pagination: { + previous: 'Prev', + next: 'Next', + }, + empty: { + noChat: 'No conversation yet', + noOutput: 'No output', + element: { + title: 'Is anyone there?', + content: 'Observe and annotate interactions between end-users and AI applications here to continuously improve AI accuracy. 
You can try sharing or testing the Web App yourself, then return to this page.', + }, + }, + }, + detail: { + time: 'Time', + conversationId: 'Conversation ID', + promptTemplate: 'Prompt Template', + promptTemplateBeforeChat: 'Prompt Template Before Chat · As System Message', + annotationTip: 'Improvements Marked by {{user}}', + timeConsuming: '', + second: 's', + tokenCost: 'Token spent', + loading: 'loading', + operation: { + like: 'like', + dislike: 'dislike', + addAnnotation: 'Add Improvement', + editAnnotation: 'Edit Improvement', + annotationPlaceholder: 'Enter the expected answer you want the AI to reply with; this can be used for future model fine-tuning and continuous improvement of text generation quality.', + }, + }, + filter: { + period: { + today: 'Today', + last7days: 'Last 7 Days', + last4weeks: 'Last 4 weeks', + last3months: 'Last 3 months', + last12months: 'Last 12 months', + monthToDate: 'Month to date', + quarterToDate: 'Quarter to date', + yearToDate: 'Year to date', + allTime: 'All time', + }, + annotation: { + all: 'All', + annotated: 'Annotated Improvements ({{count}} items)', + not_annotated: 'Not Annotated', + }, + }, +} + +export default translation diff --git a/web/i18n/lang/app-log.zh.ts b/web/i18n/lang/app-log.zh.ts new file mode 100644 index 0000000000..4e8874d714 --- /dev/null +++ b/web/i18n/lang/app-log.zh.ts @@ -0,0 +1,67 @@ +const translation = { + title: '日志与标注', + description: '日志记录了应用的运行情况,包括用户的输入和 AI 的回复。', + dateTimeFormat: 'YYYY-MM-DD HH:mm', + table: { + header: { + time: '时间', + endUser: '用户', + input: '输入', + output: '输出', + summary: '摘要', + messageCount: '消息数', + userRate: '用户反馈', + adminRate: '管理员反馈', + }, + pagination: { + previous: '上一页', + next: '下一页', + }, + empty: { + noChat: '未开始的对话', + noOutput: '无输出', + element: { + title: '这里有人吗', + content: '在这里观测和标注最终用户和 AI 应用程序之间的交互,以不断提高 AI 的准确性。您可以试试 WebApp 或分享出去,然后返回此页面。', + }, + }, + }, + detail: { + time: '时间', + conversationId: '对话 ID', + promptTemplate: '前缀提示词', + promptTemplateBeforeChat: '对话前提示词 · 以系统消息提交', + annotationTip: '{{user}} 标记的改进回复', + timeConsuming: '耗时', + second: ' 秒', + tokenCost: '花费 Token', + loading: '加载中', + operation: { + like: '赞同', + dislike: '反对', + addAnnotation: '标记改进回复', + editAnnotation: '编辑改进回复', + annotationPlaceholder: '输入你希望 AI 回复的预期答案,这在今后可用于模型微调,持续改进文本生成质量。', + }, + }, + filter: { + period: { + today: '今天', + last7days: '过去 7 天', + last4weeks: '过去 4 周', + last3months: '过去 3 月', + last12months: '过去 12 月', + monthToDate: '本月至今', + quarterToDate: '本季度至今', + yearToDate: '本年至今', + allTime: '所有时间', + }, + annotation: { + all: '全部', + annotated: '已标注改进({{count}} 项)', + not_annotated: '未标注', + }, + }, +} + +export default translation diff --git a/web/i18n/lang/app-overview.en.ts b/web/i18n/lang/app-overview.en.ts new file mode 100644 index 0000000000..97136db98c --- /dev/null +++ b/web/i18n/lang/app-overview.en.ts @@ -0,0 +1,102 @@ +const translation = { + welcome: { + firstStepTip: 'To get started,', + enterKeyTip: 'enter your OpenAI API Key below', + getKeyTip: 'Get your API Key from OpenAI dashboard', + placeholder: 'Your OpenAI API Key (e.g. sk-xxxx)', + }, + overview: { + title: 'Overview', + appInfo: { + explanation: 'Ready-to-use AI WebApp', + accessibleAddress: 'Public URL', + preview: 'Preview', + share: { + entry: 'Share', + explanation: 'Share the following URL to invite more people to access the application.', + shareUrl: 'Share URL', + copyLink: 'Copy Link', + regenerate: 'Regenerate', + }, + preUseReminder: 'Please enable WebApp before continuing.', + 
settings: { + entry: 'Settings', + title: 'WebApp Settings', + webName: 'WebApp Name', + webDesc: 'WebApp Description', + webDescTip: 'This text will be displayed on the client side, providing basic guidance on how to use the application.', + webDescPlaceholder: 'Enter the description of the WebApp', + language: 'Language', + more: { + entry: 'Show more settings', + copyright: 'Copyright', + copyRightPlaceholder: 'Enter the name of the author or organization', + privacyPolicy: 'Privacy Policy', + privacyPolicyPlaceholder: 'Enter the privacy policy', + privacyPolicyTip: 'Helps visitors understand the data the application collects; see Dify\'s Privacy Policy.', + }, + }, + customize: { + way: 'way', + entry: 'Want to customize your WebApp?', + title: 'Customize AI WebApp', + explanation: 'You can customize the frontend of the Web App to fit your scenario and style needs.', + way1: { + name: 'Fork the client code, modify it and deploy to Vercel (recommended)', + step1: 'Fork the client code and modify it', + step1Tip: 'Click here to fork the source code into your GitHub account and modify the code', + step1Operation: 'Dify-WebClient', + step2: 'Configure the Web', + step2Tip: 'Copy the Web API and APP ID, then paste them into the client code config/index.ts', + step3: 'Deploy to Vercel', + step3Tip: 'Click here to import the repository into Vercel and deploy', + step3Operation: 'Import repository', + }, + way2: { + name: 'Write client-side code to call the API and deploy it to a server', + operation: 'Documentation', + }, + }, + }, + apiInfo: { + title: 'Backend service API', + explanation: 'Easily integrated into your application', + accessibleAddress: 'API Token', + doc: 'API Reference', + }, + status: { + running: 'In service', + disable: 'Disabled', + }, + }, + analysis: { + title: 'Analysis', + totalMessages: { + title: 'Total Messages', + explanation: 'Daily AI interactions count; prompt engineering/debugging excluded.', + }, + activeUsers: { + title: 'Active Users', + explanation: 'Unique users engaging in Q&A with AI; prompt engineering/debugging excluded.', + }, + tokenUsage: { + title: 'Token Usage', + explanation: 'Reflects the daily token usage of the language model for the application, useful for cost control purposes.', + consumed: 'Consumed', + }, + avgSessionInteractions: { + title: 'Avg. Session Interactions', + explanation: 'Continuous user-AI communication count; for conversation-based apps.', + }, + userSatisfactionRate: { + title: 'User Satisfaction Rate', + explanation: 'The number of likes per 1,000 messages. This indicates the proportion of answers that users are highly satisfied with.', + }, + avgResponseTime: { + title: 'Avg. 
Response Time', + explanation: 'Time (ms) for AI to process/respond; for text-based apps.', + }, + }, +} + +export default translation diff --git a/web/i18n/lang/app-overview.zh.ts b/web/i18n/lang/app-overview.zh.ts new file mode 100644 index 0000000000..3a045602f9 --- /dev/null +++ b/web/i18n/lang/app-overview.zh.ts @@ -0,0 +1,102 @@ +const translation = { + welcome: { + firstStepTip: '开始之前,', + enterKeyTip: '请先在下方输入你的 OpenAI API Key', + getKeyTip: '从 OpenAI 获取你的 API Key', + placeholder: '你的 OpenAI API Key(例如 sk-xxxx)', + }, + overview: { + title: '概览', + appInfo: { + explanation: '开箱即用的 AI WebApp', + accessibleAddress: '公开访问 URL', + preview: '预览', + share: { + entry: '分享', + explanation: '将以下网址分享出去,让更多人访问该应用', + shareUrl: '分享 URL', + copyLink: '复制链接', + regenerate: '重新生成', + }, + preUseReminder: '使用前请先打开开关', + settings: { + entry: '设置', + title: 'WebApp 设置', + webName: 'WebApp 名称', + webDesc: 'WebApp 描述', + webDescTip: '以下文字将展示在客户端中,对应用进行说明和使用上的基本引导', + webDescPlaceholder: '请输入 WebApp 的描述', + language: '语言', + more: { + entry: '展示更多设置', + copyright: '版权', + copyRightPlaceholder: '请输入作者或组织名称', + privacyPolicy: '隐私政策', + privacyPolicyPlaceholder: '请输入隐私政策', + privacyPolicyTip: '帮助访问者了解该应用收集的数据,可参考 Dify 的隐私政策。', + }, + }, + customize: { + way: '方法', + entry: '想要进一步自定义 WebApp?', + title: '定制化 AI WebApp', + explanation: '你可以定制化 Web App 前端以符合你的情景与风格需求', + way1: { + name: 'Fork 客户端代码修改后部署到 Vercel(推荐)', + step1: 'Fork 客户端代码并修改', + step1Tip: '点击此处 Fork 源码到你的 GitHub 中,然后修改代码', + step1Operation: 'Dify-WebClient', + step2: '配置 Web APP', + step2Tip: '复制 Web API 密钥和 APP ID,填入客户端代码 config/index.ts 中', + step3: '部署到 Vercel 中', + step3Tip: '点击此处将仓库导入到 Vercel 中部署', + step3Operation: '导入仓库', + }, + way2: { + name: '编写客户端调用 API 并部署到服务器中', + operation: '查看文档', + }, + }, + }, + apiInfo: { + title: '后端服务 API', + explanation: '可集成至你的应用的后端即服务', + accessibleAddress: 'API 访问凭据', + doc: '查阅 API 文档', + }, + status: { + running: '运行中', + disable: '已停用', + }, + }, + analysis: { + title: '分析', + totalMessages: { + title: '全部消息数', + explanation: '反映 AI 每天的互动总次数,每回答用户一个问题算一条 Message。提示词编排和调试的消息不计入。', + }, + activeUsers: { + title: '活跃用户数', + explanation: '与 AI 有效互动,即有一问一答以上的唯一用户数。提示词编排和调试的会话不计入。', + }, + tokenUsage: { + title: '费用消耗', + explanation: '反映每日该应用请求语言模型的 Tokens 花费,用于成本控制。', + consumed: '耗费', + }, + avgSessionInteractions: { + title: '平均会话互动数', + explanation: '反映每个会话用户的持续沟通次数,如果用户与 AI 问答了 10 轮,即为 10。该指标反映了用户粘性。仅在对话型应用提供。', + }, + userSatisfactionRate: { + title: '用户满意度', + explanation: '每 1000 条消息的点赞数。反映了用户对回答十分满意的比例。', + }, + avgResponseTime: { + title: '平均响应时间', + explanation: '衡量 AI 应用处理和回复用户请求所花费的平均时间,单位为毫秒,反映性能和用户体验。仅在文本型应用提供。', + }, + }, +} + +export default translation diff --git a/web/i18n/lang/app.en.ts b/web/i18n/lang/app.en.ts new file mode 100644 index 0000000000..73c2a5b1b6 --- /dev/null +++ b/web/i18n/lang/app.en.ts @@ -0,0 +1,40 @@ +const translation = { + createApp: 'Create new App', + modes: { + completion: 'Text Generator', + chat: 'Chat App', + }, + createFromConfigFile: 'Create app from config file', + deleteAppConfirmTitle: 'Delete this app?', + deleteAppConfirmContent: + 'Deleting the app is irreversible. 
Users will no longer be able to access your app, and all prompt configurations and logs will be permanently deleted.', + appDeleted: 'App deleted', + appDeleteFailed: 'Failed to delete app', + join: 'Join the community', + communityIntro: + 'Discuss with team members, contributors and developers on different channels.', + roadmap: 'See our roadmap', + newApp: { + startToCreate: 'Let\'s start with your new app', + captionName: 'Give your app a name', + captionAppType: 'What kind of app do you want?', + previewDemo: 'Preview demo', + chatApp: 'Chat App', + chatAppIntro: + 'I want to build a chat-based application. This app uses a question-and-answer format, allowing for multiple rounds of continuous conversation.', + completeApp: 'Text Generator', + completeAppIntro: + 'I want to create an application that generates high-quality text based on prompts, such as generating articles, summaries, translations, and more.', + showTemplates: 'I want to choose from a template', + hideTemplates: 'Go back to mode selection', + Create: 'Create', + Cancel: 'Cancel', + nameNotEmpty: 'Name cannot be empty', + appTemplateNotSelected: 'Please select a template', + appTypeRequired: 'Please select an app type', + appCreated: 'App created', + appCreateFailed: 'Failed to create app', + }, +} + +export default translation diff --git a/web/i18n/lang/app.zh.ts b/web/i18n/lang/app.zh.ts new file mode 100644 index 0000000000..03fde52c35 --- /dev/null +++ b/web/i18n/lang/app.zh.ts @@ -0,0 +1,39 @@ +const translation = { + createApp: '创建应用', + modes: { + completion: '文本生成型', + chat: '对话型', + }, + createFromConfigFile: '通过导入应用配置文件创建', + deleteAppConfirmTitle: '确认删除应用?', + deleteAppConfirmContent: + '删除应用将无法撤销。用户将不能访问你的应用,所有 Prompt 编排配置和日志均将一并被删除。', + appDeleted: '应用已删除', + appDeleteFailed: '应用删除失败', + join: '参与社区', + communityIntro: '与团队成员、贡献者和开发者在不同频道中交流', + roadmap: '产品路线图', + newApp: { + startToCreate: '开始创建一个新应用', + captionName: '给应用起个名字', + captionAppType: '想要哪种应用类型?', + previewDemo: '预览 Demo', + chatApp: '对话型应用', + chatAppIntro: + '我要构建一个聊天场景的应用。该应用采用一问一答模式与用户持续对话。', + completeApp: '文本生成应用', + completeAppIntro: + '我要构建一个根据提示生成高质量文本的应用,例如生成文章、摘要、翻译等', + showTemplates: '我想从范例模板中选择', + hideTemplates: '返回应用类型选择', + Create: '创建', + Cancel: '取消', + nameNotEmpty: '名称不能为空', + appTemplateNotSelected: '请选择应用模版', + appTypeRequired: '请选择应用类型', + appCreated: '应用已创建', + appCreateFailed: '应用创建失败', + }, +} + +export default translation diff --git a/web/i18n/lang/common.en.ts b/web/i18n/lang/common.en.ts new file mode 100644 index 0000000000..6771cbc42c --- /dev/null +++ b/web/i18n/lang/common.en.ts @@ -0,0 +1,205 @@ +const translation = { + api: { + success: 'Success', + saved: 'Saved', + create: 'Created', + remove: 'Removed', + }, + operation: { + confirm: 'Confirm', + cancel: 'Cancel', + clear: 'Clear', + save: 'Save', + edit: 'Edit', + add: 'Add', + refresh: 'Restart', + search: 'Search', + change: 'Change', + remove: 'Remove', + send: 'Send', + copy: 'Copy', + lineBreak: 'Line break', + sure: 'I\'m sure', + }, + placeholder: { + input: 'Please enter', + select: 'Please select', + }, + unit: { + char: 'chars', + }, + actionMsg: { + modifiedSuccessfully: 'Modified successfully', + modificationFailed: 'Modification failed', + copySuccessfully: 'Copied successfully', + }, + model: { + params: { + temperature: 'Temperature', + temperatureTip: + 'Controls randomness: Lowering results in less random completions. 
As the temperature approaches zero, the model will become deterministic and repetitive.', + topP: 'Top P', + topPTip: + 'Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.', + presencePenalty: 'Presence penalty', + presencePenaltyTip: + 'How much to penalize new tokens based on whether they appear in the text so far. Increases the model\'s likelihood to talk about new topics.', + frequencyPenalty: 'Frequency penalty', + frequencyPenaltyTip: + 'How much to penalize new tokens based on their existing frequency in the text so far. Decreases the model\'s likelihood to repeat the same line verbatim.', + maxToken: 'Max tokens', + maxTokenTip: + 'The maximum number of generated tokens is 2,048 or 4,000, depending on the model. The prompt and completion share this limit. One token is roughly 4 characters of English text.', + }, + tone: { + Creative: 'Creative', + Balanced: 'Balanced', + Precise: 'Precise', + Custom: 'Custom', + }, + }, + menus: { + status: 'beta', + apps: 'Apps', + plugins: 'Plugins', + pluginsTips: 'Integrate third-party plugins or create ChatGPT-compatible AI-Plugins.', + datasets: 'Datasets', + datasetsTips: 'COMING SOON: Import your own text data or write data in real-time via Webhook for LLM context enhancement.', + newApp: 'New App', + newDataset: 'Create dataset', + }, + userProfile: { + settings: 'Settings', + workspace: 'Workspace', + createWorkspace: 'Create Workspace', + helpCenter: 'Help Document', + about: 'About', + logout: 'Log out', + }, + settings: { + accountGroup: 'ACCOUNT', + workplaceGroup: 'WORKPLACE', + account: "My account", + members: "Members", + integrations: "Integrations", + language: "Language", + provider: "Model Provider" + }, + account: { + avatar: 'Avatar', + name: 'Name', + email: 'Email', + langGeniusAccount: 'Dify account', + langGeniusAccountTip: 'Your Dify account and associated user data.', + editName: 'Edit Name', + showAppLength: 'Show {{length}} apps', + }, + members: { + team: 'Team', + invite: 'Invite', + name: 'NAME', + lastActive: 'LAST ACTIVE', + role: 'ROLES', + pending: 'Pending...', + owner: 'Owner', + admin: 'Admin', + adminTip: 'Can build apps & manage team settings', + normal: 'Normal', + normalTip: 'Can only use apps, cannot build apps', + inviteTeamMember: 'Invite team member', + inviteTeamMemberTip: 'The other person will receive an email. 
If they\'re already a Dify user, they can access your team data directly after signing in.', + email: 'Email', + emailInvalid: 'Invalid Email Format', + emailPlaceholder: 'Input Email', + sendInvite: 'Send Invite', + invitationSent: 'Invitation sent', + invitationSentTip: 'The invitation is sent, and they can sign in to Dify to access your team data.', + ok: 'OK', + removeFromTeam: 'Remove from team', + removeFromTeamTip: 'Will remove team access', + setAdmin: 'Set as administrator', + setMember: 'Set as ordinary member', + disinvite: 'Cancel the invitation', + deleteMember: 'Delete Member', + you: '(You)', + }, + integrations: { + connected: 'Connected', + google: 'Google', + googleAccount: 'Login with Google account', + github: 'GitHub', + githubAccount: 'Login with GitHub account', + connect: 'Connect' + }, + language: { + displayLanguage: 'Display Language', + timezone: 'Time Zone', + }, + provider: { + apiKey: "API Key", + enterYourKey: "Enter your API key here", + invalidKey: "Invalid OpenAI API key", + validating: "Validating key...", + saveFailed: "Failed to save API key", + apiKeyExceedBill: "This API key has no quota available, please read", + addKey: 'Add Key', + comingSoon: 'Coming Soon', + editKey: 'Edit', + invalidApiKey: 'Invalid API key', + azure: { + resourceName: 'Resource Name', + resourceNamePlaceholder: 'The name of your Azure OpenAI Resource.', + deploymentId: 'Deployment ID', + deploymentIdPlaceholder: 'The deployment name you chose when you deployed the model.', + apiVersion: 'API Version', + apiVersionPlaceholder: 'The API version to use for this operation.', + apiKey: 'API Key', + apiKeyPlaceholder: 'Enter your API key here', + helpTip: 'Learn about Azure OpenAI Service', + }, + openaiHosted: { + openaiHosted: 'Hosted OpenAI', + onTrial: 'ON TRIAL', + exhausted: 'QUOTA EXHAUSTED', + desc: 'The OpenAI hosting service provided by Dify allows you to use models such as GPT-3.5. Before your trial quota is used up, you need to set up other model providers.', + callTimes: 'Call times', + usedUp: 'Trial quota used up. 
Please add your own model provider.', + useYourModel: 'Currently using own Model Provider.', + close: 'Close', + }, + encrypted: { + front: 'Your API KEY will be encrypted and stored using', + back: ' technology.', + } + }, + about: { + changeLog: 'Changelog', + updateNow: 'Update now', + nowAvailable: 'Dify {{version}} is now available.', + latestAvailable: 'Dify {{version}} is the latest version available.', + }, + appMenus: { + overview: 'Overview', + promptEng: 'Prompt Eng.', + apiAccess: 'API Access', + logAndAnn: 'Logs & Ann.', + }, + environment: { + testing: 'TESTING', + development: 'DEVELOPMENT', + }, + appModes: { + completionApp: 'Text Generator', + chatApp: 'Chat App', + }, + datasetMenus: { + documents: 'Documents', + hitTesting: 'Hit Testing', + settings: 'Settings', + emptyTip: 'The dataset has not been linked yet; please go to an application or plugin to complete the link.', + viewDoc: 'View documentation', + relatedApp: 'linked apps', + }, +} + +export default translation diff --git a/web/i18n/lang/common.zh.ts b/web/i18n/lang/common.zh.ts new file mode 100644 index 0000000000..a2f03a1bc1 --- /dev/null +++ b/web/i18n/lang/common.zh.ts @@ -0,0 +1,206 @@ +const translation = { + api: { + success: '成功', + saved: '已保存', + create: '已创建', + remove: '已移除', + }, + operation: { + confirm: '确认', + cancel: '取消', + clear: '清空', + save: '保存', + edit: '编辑', + add: '添加', + refresh: '重新开始', + search: '搜索', + change: '更改', + remove: '移除', + send: '发送', + copy: '复制', + lineBreak: '换行', + sure: '我确定', + }, + placeholder: { + input: '请输入', + select: '请选择', + }, + unit: { + char: '个字符', + }, + actionMsg: { + modifiedSuccessfully: '修改成功', + modificationFailed: '修改失败', + copySuccessfully: '复制成功', + }, + model: { + params: { + temperature: '多样性', + temperatureTip: + '较高的 Temperature 设置将导致更多样和创造性的输出,而较低的 Temperature 将产生更保守的输出并且类似于训练数据。', + topP: '采样范围', + topPTip: + 'Top P 值越低,输出与训练文本越相似;Top P 值越高,输出越有创意和变化。它可用于使输出更适合特定用例。', + presencePenalty: '词汇控制', + presencePenaltyTip: + 'Presence penalty 是根据新词是否出现在目前的文本中来对其进行惩罚。正值将提高模型谈论新话题的可能性。', + frequencyPenalty: '重复控制', + frequencyPenaltyTip: + 'Frequency penalty 是根据重复词在目前文本中的出现频率来对其进行惩罚。正值将不太可能重复常用单词和短语。', + maxToken: '最大 Token', + maxTokenTip: + '生成的最大令牌数为 2,048 或 4,000,取决于模型。提示和完成共享令牌数限制。一个令牌约等于 4 个英文字符。', + }, + tone: { + Creative: '创意', + Balanced: '平衡', + Precise: '精确', + Custom: '自定义', + }, + }, + menus: { + status: 'beta', + apps: '应用', + plugins: '插件', + pluginsTips: '集成第三方插件或创建与 ChatGPT 兼容的 AI 插件。', + datasets: '数据集', + datasetsTips: '即将到来: 上传自己的长文本数据,或通过 Webhook 集成自己的数据源', + newApp: '创建应用', + newDataset: '创建数据集', + }, + userProfile: { + settings: '设置', + workspace: '工作空间', + createWorkspace: '创建工作空间', + helpCenter: '帮助文档', + about: '关于', + logout: '登出', + }, + settings: { + accountGroup: '账户', + workplaceGroup: '工作空间', + account: "我的账户", + members: "成员", + integrations: "集成", + language: "语言", + provider: "模型供应商" + }, + account: { + avatar: '头像', + name: '用户名', + email: '邮箱', + edit: '编辑', + langGeniusAccount: 'Dify 账号', + langGeniusAccountTip: '您的 Dify 账号和相关的用户数据。', + editName: '编辑名字', + showAppLength: '显示 {{length}} 个应用', + }, + members: { + team: '团队', + invite: '邀请', + name: '姓名', + lastActive: '上次活动时间', + role: '角色', + pending: '待定...', + owner: '所有者', + admin: '管理员', + adminTip: '能够建立应用程序和管理团队设置', + normal: '普通成员', + normalTip: '只能使用应用程序,不能建立应用程序', + inviteTeamMember: '邀请团队成员', + inviteTeamMemberTip: '对方会收到一封邮件。如果他已经是 Dify 用户则可直接在登录后访问你的团队数据。', + email: '邮箱', + emailInvalid: '邮箱格式无效', + emailPlaceholder: '输入邮箱', + sendInvite: '发送邀请', + 
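The parameter tips in common.en.ts above describe standard LLM sampling controls: temperature, top_p, presence and frequency penalties, and the max-token budget. As a sketch of where each control would land in an OpenAI-style request; the `openai` Node SDK and the model name are assumptions for illustration, not code this commit contains:

```
// Illustrative only: maps the tunable parameters described above onto an
// OpenAI-style chat completion call. Assumes the official `openai` Node SDK.
import OpenAI from 'openai'

const client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY })

async function complete(prompt: string) {
  return client.chat.completions.create({
    model: 'gpt-3.5-turbo',  // placeholder model name
    messages: [{ role: 'user', content: prompt }],
    temperature: 0.7,        // lower values give less random completions
    top_p: 0.5,              // nucleus sampling over half the probability mass
    presence_penalty: 0.5,   // nudges the model toward new topics
    frequency_penalty: 0.5,  // discourages repeating the same line verbatim
    max_tokens: 256,         // completion budget, shared with the prompt
  })
}
```

The preset tones above (Creative, Balanced, Precise) would then just be named bundles of these numbers.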
invitationSent: '邀请已发送', + invitationSentTip: '邀请已发送,对方登录 Dify 后即可访问你的团队数据。', + ok: '好的', + removeFromTeam: '移除团队', + removeFromTeamTip: '将取消团队访问', + setAdmin: '设为管理员', + setMember: '设为普通成员', + disinvite: '取消邀请', + deleteMember: '删除成员', + you: '(你)', + }, + integrations: { + connected: '登录方式', + google: 'Google', + googleAccount: 'Google 账号登录', + github: 'GitHub', + githubAccount: 'GitHub 账号登录', + connect: '绑定' + }, + language: { + displayLanguage: '界面语言', + timezone: '时区', + }, + provider: { + apiKey: "API 密钥", + enterYourKey: "输入你的 API 密钥", + invalidKey: '无效的 OpenAI API 密钥', + validating: "验证密钥中...", + saveFailed: "API 密钥保存失败", + apiKeyExceedBill: "此 API KEY 已没有可用配额,请阅读", + addKey: '添加密钥', + comingSoon: '即将推出', + editKey: '编辑', + invalidApiKey: '无效的 API 密钥', + azure: { + resourceName: 'Resource Name', + resourceNamePlaceholder: 'The name of your Azure OpenAI Resource.', + deploymentId: 'Deployment ID', + deploymentIdPlaceholder: 'The deployment name you chose when you deployed the model.', + apiVersion: 'API Version', + apiVersionPlaceholder: 'The API version to use for this operation.', + apiKey: 'API Key', + apiKeyPlaceholder: 'Enter your API key here', + helpTip: '了解 Azure OpenAI Service', + }, + openaiHosted: { + openaiHosted: '托管 OpenAI', + onTrial: '体验', + exhausted: '超出限额', + desc: '由 Dify 提供的托管 OpenAI 服务,你可以使用 GPT-3.5 等模型。在体验额度消耗完毕前,你需要设置其它模型供应商。', + callTimes: '调用次数', + usedUp: '试用额度已用完,请在下方添加自己的模型供应商', + useYourModel: '当前正在使用你自己的模型供应商。', + close: '关闭', + }, + encrypted: { + front: '密钥将使用 ', + back: ' 技术进行加密和存储。', + } + }, + about: { + changeLog: '更新日志', + updateNow: '现在更新', + nowAvailable: 'Dify {{version}} 现已可用。', + latestAvailable: 'Dify {{version}} 已是最新版本。', + }, + appMenus: { + overview: '概览', + promptEng: '提示词编排', + apiAccess: '访问 API', + logAndAnn: '日志与标注', + }, + environment: { + testing: '测试环境', + development: '开发环境', + }, + appModes: { + completionApp: '文本生成型应用', + chatApp: '对话型应用', + }, + datasetMenus: { + documents: '文档', + hitTesting: '命中测试', + settings: '设置', + emptyTip: '数据集尚未关联,请前往应用程序或插件完成关联。', + viewDoc: '查看文档', + relatedApp: '个关联应用', + }, +} + +export default translation diff --git a/web/i18n/lang/dataset-creation.en.ts b/web/i18n/lang/dataset-creation.en.ts new file mode 100644 index 0000000000..2a0ecf574f --- /dev/null +++ b/web/i18n/lang/dataset-creation.en.ts @@ -0,0 +1,108 @@ +const translation = { + steps: { + header: { + creation: 'Create Dataset', + update: 'Add data', + }, + one: 'Choose data source', + two: 'Text Preprocessing and Cleaning', + three: 'Execute and finish', + }, + error: { + unavailable: 'This dataset is not available', + }, + stepOne: { + filePreview: 'File Preview', + dataSourceType: { + file: 'Import from text file', + notion: 'Sync from Notion', + web: 'Sync from website', + }, + uploader: { + title: 'Upload text file', + button: 'Drag and drop file, or', + browse: 'Browse', + tip: 'Supports txt, html, markdown, and pdf.', + validation: { + typeError: 'File type not supported', + size: 'File too large. 
Maximum is 15MB', + count: 'Multiple files not supported', + }, + cancel: 'Cancel', + change: 'Change', + failed: 'Upload failed', + }, + button: 'Next', + emptyDatasetCreation: 'I want to create an empty dataset', + modal: { + title: 'Create an empty dataset', + tip: 'An empty dataset will contain no documents, and you can upload documents any time.', + input: 'Dataset name', + placeholder: 'Please input', + nameNotEmpty: 'Name cannot be empty', + nameLengthInvaild: 'Name must be between 1 and 40 characters', + cancelButton: 'Cancel', + confirmButton: 'Create', + failed: 'Creation failed', + }, + }, + stepTwo: { + segmentation: 'Segmentation settings', + auto: 'Automatic', + autoDescription: 'Automatically set segmentation and preprocessing rules. Unfamiliar users are recommended to select this.', + custom: 'Custom', + customDescription: 'Customize segmentation rules, segmentation length, and preprocessing rules, etc.', + separator: 'Segment identifier', + separatorPlaceholder: 'For example, newline (\\\\n) or special separator (such as "***")', + maxLength: 'Maximum segment length', + rules: 'Text preprocessing rules', + removeExtraSpaces: 'Replace consecutive spaces, newlines and tabs', + removeUrlEmails: 'Delete all URLs and email addresses', + removeStopwords: 'Remove stopwords such as "a", "an", "the"', + preview: 'Confirm & Preview', + reset: 'Reset', + indexMode: 'Index mode', + qualified: 'High Quality', + recommend: 'Recommended', + qualifiedTip: 'Call OpenAI\'s embedding interface for processing to provide higher accuracy when users query.', + warning: 'Please set up the model provider API key first.', + click: 'Go to settings', + economical: 'Economical', + economicalTip: 'Uses offline vector engines, keyword indexes, etc.; accuracy is lower but no tokens are spent', + emstimateCost: 'Estimation', + emstimateSegment: 'Estimated segments', + segmentCount: 'segments', + calculating: 'Calculating...', + fileName: 'Preprocess document', + lastStep: 'Previous step', + nextStep: 'Save & Process', + sideTipTitle: 'Why segment and preprocess?', + sideTipP1: 'When processing text data, segmentation and cleaning are two important preprocessing steps.', + sideTipP2: 'Segmentation splits long text into paragraphs so models can understand better. This improves the quality and relevance of model results.', + sideTipP3: 'Cleaning removes unnecessary characters and formats, making datasets cleaner and easier to parse.', + sideTipP4: 'Proper segmentation and cleaning improve model performance, providing more accurate and valuable results.', + previewTitle: 'Preview', + characters: 'characters', + indexSettedTip: 'To change the index method, please go to the ', + datasetSettingLink: 'dataset settings.', + }, + stepThree: { + creationTitle: '🎉 Dataset created', + creationContent: 'We automatically named the dataset; you can modify it at any time', + label: 'Dataset name', + additionTitle: '🎉 Document uploaded', + additionP1: 'The document has been uploaded to the dataset', + additionP2: ', you can find it in the document list of the dataset.', + stop: 'Stop processing', + resume: 'Resume processing', + navTo: 'Go to document', + sideTipTitle: 'What\'s next', + sideTipContent: 'After the document finishes indexing, the dataset can be integrated into the application as context; you can find the context setting in the prompt orchestration page. 
You can also create it as an independent ChatGPT indexing plugin for release.', + modelTitle: 'Are you sure you want to stop embedding?', + modelContent: 'If you need to resume processing later, you will continue from where you left off.', + modelButtonConfirm: "Confirm", + modelButtonCancel: 'Cancel' + }, +} + +export default translation diff --git a/web/i18n/lang/dataset-creation.zh.ts b/web/i18n/lang/dataset-creation.zh.ts new file mode 100644 index 0000000000..4fbaea9661 --- /dev/null +++ b/web/i18n/lang/dataset-creation.zh.ts @@ -0,0 +1,108 @@ +const translation = { + steps: { + header: { + creation: '创建数据集', + update: '上传文件', + }, + one: '选择数据源', + two: '文本分段与清洗', + three: '处理并完成', + }, + error: { + unavailable: '该数据集不可用', + }, + stepOne: { + filePreview: '文件预览', + dataSourceType: { + file: '导入已有文本', + notion: '同步自 Notion 内容', + web: '同步自 Web 站点', + }, + uploader: { + title: '上传文本文件', + button: '拖拽文件至此,或者', + browse: '选择文件', + tip: '已支持 TXT, HTML, Markdown, PDF', + validation: { + typeError: '文件类型不支持', + size: '文件太大了,不能超过 15MB', + count: '暂不支持多个文件', + }, + cancel: '取消', + change: '更改文件', + failed: '上传失败', + }, + button: '下一步', + emptyDatasetCreation: '创建一个空数据集', + modal: { + title: '创建空数据集', + tip: '空数据集中还没有文档,你可以在今后任何时候上传文档至该数据集。', + input: '数据集名称', + placeholder: '请输入数据集名称', + nameNotEmpty: '名称不能为空', + nameLengthInvaild: '名称长度不能超过 40 个字符', + cancelButton: '取消', + confirmButton: '创建', + failed: '创建失败', + }, + }, + stepTwo: { + segmentation: '分段设置', + auto: '自动分段与清洗', + autoDescription: '自动设置分段规则与预处理规则,如果不了解这些参数建议选择此项', + custom: '自定义', + customDescription: '自定义分段规则、分段长度以及预处理规则等参数', + separator: '分段标识符', + separatorPlaceholder: '例如换行符(\n)或特定的分隔符(如 "***")', + maxLength: '分段最大长度', + rules: '文本预处理规则', + removeExtraSpaces: '替换掉连续的空格、换行符和制表符', + removeUrlEmails: '删除所有 URL 和电子邮件地址', + removeStopwords: '去除停用词,例如 “a”,“an”,“the” 等', + preview: '确认并预览', + reset: '重置', + indexMode: '索引方式', + qualified: '高质量', + recommend: '推荐', + qualifiedTip: '调用 OpenAI 的嵌入接口进行处理,以在用户查询时提供更高的准确度', + warning: '请先完成模型供应商的 API KEY 设置。', + click: '前往设置', + economical: '经济', + economicalTip: '使用离线的向量引擎、关键词索引等方式,降低了准确度但无需花费 Token', + emstimateCost: '执行嵌入预估消耗', + emstimateSegment: '预估分段数', + segmentCount: '段', + calculating: '计算中...', + fileName: '预处理文档', + lastStep: '上一步', + nextStep: '保存并处理', + sideTipTitle: '为什么要分段和预处理?', + sideTipP1: '在处理文本数据时,分段和清洗是两个重要的预处理步骤。', + sideTipP2: '分段的目的是将长文本拆分成较小的段落,以便模型更有效地处理和理解。这有助于提高模型生成的结果的质量和相关性。', + sideTipP3: '清洗则是对文本进行预处理,删除不必要的字符、符号或格式,使数据集更加干净、整洁,便于模型解析。', + sideTipP4: '通过对数据集进行适当的分段和清洗,可以提高模型在实际应用中的表现,从而为用户提供更准确、更有价值的结果。', + previewTitle: '分段预览', + characters: '字符', + indexSettedTip: '要更改索引方法,请转到', + datasetSettingLink: '数据集设置。', + }, + stepThree: { + creationTitle: '🎉 数据集已创建', + creationContent: '我们自动为该数据集起了个名称,您也可以随时修改', + label: '数据集名称', + additionTitle: '🎉 文档已上传', + additionP1: '文档已上传至数据集:', + additionP2: ',你可以在数据集的文档列表中找到它。', + stop: '停止处理', + resume: '恢复处理', + navTo: '前往文档', + sideTipTitle: '接下来做什么', + sideTipContent: '当文档完成索引处理后,数据集即可集成至应用内作为上下文使用,你可以在提示词编排页找到上下文设置。你也可以创建成可独立使用的 ChatGPT 索引插件发布。', + modelTitle: '确认停止索引过程吗?', + modelContent: '如果您需要稍后恢复处理,则从停止处继续。', + modelButtonConfirm: "确认停止", + modelButtonCancel: '取消' + }, +} + +export default translation diff --git a/web/i18n/lang/dataset-documents.en.ts b/web/i18n/lang/dataset-documents.en.ts new file mode 100644 index 0000000000..1892ecdebc --- /dev/null +++ b/web/i18n/lang/dataset-documents.en.ts @@ -0,0 +1,314 @@ +const translation = { + list: { + title: "Documents", + desc: "All files of the dataset are shown here, and 
the entire dataset can be linked to Dify citations or indexed via the Chat plugin.", + addFile: "Add file", + table: { + header: { + fileName: "FILE NAME", + words: "WORDS", + hitCount: "HIT COUNT", + uploadTime: "UPLOAD TIME", + status: "STATUS", + action: "ACTION", + }, + }, + action: { + uploadFile: 'Upload new file', + settings: 'Segment settings', + archive: 'Archive', + delete: "Delete", + enableWarning: 'An archived file cannot be enabled', + }, + index: { + enable: 'Enable', + disable: 'Disable', + all: 'All', + enableTip: 'The file can be indexed', + disableTip: 'The file cannot be indexed', + }, + status: { + queuing: 'Queuing', + indexing: 'Indexing', + parsed: 'Parsed', + error: 'Error', + available: 'Available', + enabled: 'Enabled', + disabled: 'Disabled', + archived: 'Archived', + }, + empty: { + title: "There are no documents yet", + upload: { + tip: "You can upload files, sync from a website, or sync from web apps like Notion, GitHub, etc.", + }, + sync: { + tip: "Dify will periodically download files from your Notion and complete processing.", + }, + }, + delete: { + title: 'Are you sure you want to delete?', + content: 'If you need to resume processing later, you will continue from where you left off.' + } + }, + metadata: { + title: "Metadata", + desc: "Labeling metadata for documents allows AI to access them in a timely manner and exposes the source of references for users.", + dateTimeFormat: 'MMMM D, YYYY hh:mm A', + docTypeSelectTitle: "Please select a document type", + docTypeChangeTitle: "Change document type", + docTypeSelectWarning: + "If the document type is changed, the metadata you have already filled in will no longer be preserved", + firstMetaAction: "Let's go", + placeholder: { + add: 'Add ', + select: 'Select ', + }, + source: { + upload_file: 'Upload File', + notion: "Sync from Notion", + github: "Sync from GitHub", + }, + type: { + book: "Book", + webPage: "Web Page", + paper: "Paper", + socialMediaPost: "Social Media Post", + personalDocument: "Personal Document", + businessDocument: "Business Document", + IMChat: "IM Chat", + wikipediaEntry: "Wikipedia Entry", + notion: "Sync from Notion", + github: "Sync from GitHub", + technicalParameters: 'Technical Parameters', + }, + field: { + processRule: { + processDoc: 'Process Document', + segmentRule: 'Segment Rule', + segmentLength: 'Segment Length', + processClean: 'Text Process Clean', + }, + book: { + title: "Title", + language: "Language", + author: "Author", + publisher: "Publisher", + publicationDate: "Publication Date", + ISBN: "ISBN", + category: "Category", + }, + webPage: { + title: "Title", + url: "URL", + language: "Language", + authorPublisher: "Author/Publisher", + publishDate: "Publish Date", + topicsKeywords: "Topics/Keywords", + description: "Description", + }, + paper: { + title: "Title", + language: "Language", + author: "Author", + publishDate: "Publish Date", + journalConferenceName: "Journal/Conference Name", + volumeIssuePage: "Volume/Issue/Page", + DOI: "DOI", + topicsKeywords: "Topics/Keywords", + abstract: "Abstract", + }, + socialMediaPost: { + platform: "Platform", + authorUsername: "Author/Username", + publishDate: "Publish Date", + postURL: "Post URL", + topicsTags: "Topics/Tags", + }, + personalDocument: { + title: "Title", + author: "Author", + creationDate: "Creation Date", + lastModifiedDate: "Last 
Modified Date", + documentType: "Document Type", + tagsCategory: "Tags/Category", + }, + businessDocument: { + title: "Title", + author: "Author", + creationDate: "Creation Date", + lastModifiedDate: "Last Modified Date", + documentType: "Document Type", + departmentTeam: "Department/Team", + }, + IMChat: { + chatPlatform: "Chat Platform", + chatPartiesGroupName: "Chat Parties/Group Name", + participants: "Participants", + startDate: "Start Date", + endDate: "End Date", + topicsKeywords: "Topics/Keywords", + fileType: "File Type", + }, + wikipediaEntry: { + title: "Title", + language: "Language", + webpageURL: "Webpage URL", + editorContributor: "Editor/Contributor", + lastEditDate: "Last Edit Date", + summaryIntroduction: "Summary/Introduction", + }, + notion: { + title: "Title", + language: "Language", + author: "Author", + createdTime: "Created Time", + lastModifiedTime: "Last Modified Time", + url: "URL", + tag: "Tag", + description: "Description", + }, + github: { + repoName: "Repo Name", + repoDesc: "Repo Description", + repoOwner: "Repo Owner", + fileName: "File Name", + filePath: "File Path", + programmingLang: "Programming Language", + url: "URL", + license: "License", + lastCommitTime: "Last Commit Time", + lastCommitAuthor: "Last Commit Author", + }, + originInfo: { + originalFilename: "Original filename", + originalFileSize: "Original file size", + uploadDate: "Upload date", + lastUpdateDate: "Last update date", + source: "Source", + }, + technicalParameters: { + segmentSpecification: 'Segment specification', + segmentLength: 'Segment length', + avgParagraphLength: 'Avg. paragraph length', + paragraphs: 'Paragraphs', + hitCount: 'Hit count', + embeddingTime: 'Embedding time', + embeddedSpend: 'Embedded spend' + } + }, + languageMap: { + zh: "Chinese", + en: "English", + es: "Spanish", + fr: "French", + de: "German", + ja: "Japanese", + ko: "Korean", + ru: "Russian", + ar: "Arabic", + pt: "Portuguese", + it: "Italian", + nl: "Dutch", + pl: "Polish", + sv: "Swedish", + tr: "Turkish", + he: "Hebrew", + hi: "Hindi", + da: "Danish", + fi: "Finnish", + no: "Norwegian", + hu: "Hungarian", + el: "Greek", + cs: "Czech", + th: "Thai", + id: "Indonesian", + }, + categoryMap: { + book: { + fiction: "Fiction", + biography: "Biography", + history: "History", + science: "Science", + technology: "Technology", + education: "Education", + philosophy: "Philosophy", + religion: "Religion", + socialSciences: "Social Sciences", + art: "Art", + travel: "Travel", + health: "Health", + selfHelp: "Self-Help", + businessEconomics: "Business & Economics", + cooking: "Cooking", + childrenYoungAdults: "Children & Young Adults", + comicsGraphicNovels: "Comics & Graphic Novels", + poetry: "Poetry", + drama: "Drama", + other: "Other", + }, + personalDoc: { + notes: "Notes", + blogDraft: "Blog Draft", + diary: "Diary", + researchReport: "Research Report", + bookExcerpt: "Book Excerpt", + schedule: "Schedule", + list: "List", + projectOverview: "Project Overview", + photoCollection: "Photo Collection", + creativeWriting: "Creative Writing", + codeSnippet: "Code Snippet", + designDraft: "Design Draft", + personalResume: "Personal Resume", + other: "Other", + }, + businessDoc: { + meetingMinutes: "Meeting Minutes", + researchReport: "Research Report", + proposal: "Proposal", + employeeHandbook: "Employee Handbook", + trainingMaterials: "Training Materials", + requirementsDocument: "Requirements Document", + designDocument: "Design Document", + productSpecification: "Product Specification", + financialReport: "Financial Report", + marketAnalysis: "Market Analysis", + projectPlan: "Project Plan", + teamStructure: "Team Structure", + policiesProcedures: "Policies & Procedures", + contractsAgreements: "Contracts & Agreements", + 
emailCorrespondence: "Email Correspondence", + other: "Other", + }, + }, + }, + embedding: { + processing: 'Embedding processing...', + paused: 'Embedding paused', + completed: 'Embedding completed', + error: 'Embedding error', + docName: 'Preprocessing document', + mode: 'Segmentation rule', + segmentLength: 'Segmentation length', + textCleaning: 'Text pre-definition and cleaning', + segments: 'Paragraphs', + highQuality: 'High-quality mode', + economy: 'Economy mode', + estimate: 'Estimated consumption', + stop: 'Stop processing', + resume: 'Resume processing', + automatic: 'Automatic', + custom: 'Custom', + previewTip: 'Paragraph preview will be available after embedding is complete' + }, + segment: { + paragraphs: 'Paragraphs', + keywords: 'Key Words', + characters: 'characters', + hitCount: 'hit count', + vectorHash: 'Vector hash: ', + } +}; + +export default translation; diff --git a/web/i18n/lang/dataset-documents.zh.ts b/web/i18n/lang/dataset-documents.zh.ts new file mode 100644 index 0000000000..d295a896bb --- /dev/null +++ b/web/i18n/lang/dataset-documents.zh.ts @@ -0,0 +1,313 @@ +const translation = { + list: { + title: "文档", + desc: "数据集的所有文件都在这里显示,整个数据集都可以链接到 Dify 引用或通过 Chat 插件进行索引。", + addFile: "添加文件", + table: { + header: { + fileName: "文件名", + words: "字符数", + hitCount: "命中次数", + uploadTime: "上传时间", + status: "状态", + action: "操作", + }, + }, + action: { + uploadFile: '上传新文件', + settings: '分段设置', + archive: '归档', + delete: "删除", + enableWarning: '归档的文件无法启用', + }, + index: { + enable: '启用中', + disable: '禁用中', + all: '全部', + enableTip: '该文件可以被索引', + disableTip: '该文件无法被索引', + }, + status: { + queuing: '排队中', + indexing: '索引中', + parsed: '已解析', + error: '错误', + available: '可用', + enabled: '已启用', + disabled: '已禁用', + archived: '已归档', + }, + empty: { + title: "还没有文档", + upload: { + tip: "您可以上传文件,从网站同步,或者从网络应用程序(如概念、GitHub 等)同步。", + }, + sync: { + tip: "Dify 会定期从您的 Notion 中下载文件并完成处理。", + }, + }, + delete: { + title: '确定删除吗?', + content: '如果您需要稍后恢复处理,您将从您离开的地方继续' + } + }, + metadata: { + title: "元数据", + desc: "标记文档的元数据允许 AI 及时访问它们并为用户公开参考来源。", + dateTimeFormat: 'YYYY-MM-DD HH:mm', + docTypeSelectTitle: "请选择一种文档类型", + docTypeChangeTitle: "更换文档类型", + docTypeSelectWarning: "如果更改文档类型,将不再保留现在填充的元数据", + firstMetaAction: "开始", + placeholder: { + add: '输入', + select: '选择', + }, + source: { + upload_file: '文件上传', + notion: "从 Notion 同步的文档", + github: "从 Github 同步的代码", + }, + type: { + book: "书籍", + webPage: "网页", + paper: "论文", + socialMediaPost: "社交媒体帖子", + personalDocument: "个人文档", + businessDocument: "商务文档", + IMChat: "IM 聊天记录", + wikipediaEntry: "维基百科条目", + notion: "从 Notion 同步的文档", + github: "从 Github 同步的代码", + technicalParameters: '技术参数', + }, + field: { + processRule: { + processDoc: '预处理文档', + segmentRule: '分段规则', + segmentLength: '分段长度', + processClean: '文本预处理与清洗', + }, + book: { + title: "标题", + language: "语言", + author: "作者", + publisher: "出版商", + publicationDate: "出版日期", + ISBN: "ISBN", + category: "类别", + }, + webPage: { + title: "标题", + url: "网址", + language: "语言", + authorPublisher: "作者/出版商", + publishDate: "发布日期", + topicsKeywords: "主题/关键词", + description: "描述", + }, + paper: { + title: "标题", + language: "语言", + author: "作者", + publishDate: "发布日期", + journalConferenceName: "期刊/会议名称", + volumeIssuePage: "卷/期/页码", + DOI: "DOI", + topicsKeywords: "主题/关键词", + abstract: "摘要", + }, + socialMediaPost: { + platform: "平台", + authorUsername: "作者/用户名", + publishDate: "发布日期", + postURL: "帖子网址", + topicsTags: "主题/标签", + }, + personalDocument: { + title: "标题", + author: "作者", + creationDate: 
"创建日期", + lastModifiedDate: "最后修改日期", + documentType: "文档类型", + tagsCategory: "标签/类别", + }, + businessDocument: { + title: "标题", + author: "作者", + creationDate: "创建日期", + lastModifiedDate: "最后修改日期", + documentType: "文档类型", + departmentTeam: "部门/团队", + }, + IMChat: { + chatPlatform: "聊天平台", + chatPartiesGroupName: "聊天参与方/群组名称", + participants: "参与者", + startDate: "开始日期", + endDate: "结束日期", + topicsKeywords: "主题/关键词", + fileType: "文件类型", + }, + wikipediaEntry: { + title: "标题", + language: "语言", + webpageURL: "网页网址", + editorContributor: "编辑/贡献者", + lastEditDate: "最后编辑日期", + summaryIntroduction: "摘要/介绍", + }, + notion: { + title: "标题", + language: "语言", + author: "作者", + createdTime: "创建时间", + lastModifiedTime: "最后修改时间", + url: "网址", + tag: "标签", + description: "描述", + }, + github: { + repoName: "仓库名", + repoDesc: "仓库描述", + repoOwner: "仓库所有者", + fileName: "文件名", + filePath: "文件路径", + programmingLang: "编程语言", + url: "网址", + license: "许可证", + lastCommitTime: "最后提交时间", + lastCommitAuthor: "最后提交者", + }, + originInfo: { + originalFilename: "原始文件名称", + originalFileSize: "原始文件大小", + uploadDate: "上传日期", + lastUpdateDate: "最后更新日期", + source: "来源", + }, + technicalParameters: { + segmentSpecification: '分段规则', + segmentLength: '段落长度', + avgParagraphLength: '平均段落长度', + paragraphs: '段落数量', + hitCount: '命中次数', + embeddingTime: '嵌入时间', + embeddedSpend: '嵌入花费', + } + }, + languageMap: { + zh: "中文", + en: "英文", + es: "西班牙语", + fr: "法语", + de: "德语", + ja: "日语", + ko: "韩语", + ru: "俄语", + ar: "阿拉伯语", + pt: "葡萄牙语", + it: "意大利语", + nl: "荷兰语", + pl: "波兰语", + sv: "瑞典语", + tr: "土耳其语", + he: "希伯来语", + hi: "印地语", + da: "丹麦语", + fi: "芬兰语", + no: "挪威语", + hu: "匈牙利语", + el: "希腊语", + cs: "捷克语", + th: "泰语", + id: "印度尼西亚语", + }, + categoryMap: { + book: { + fiction: "小说", + biography: "传记", + history: "历史", + science: "科学", + technology: "技术", + education: "教育", + philosophy: "哲学", + religion: "宗教", + socialSciences: "社会科学", + art: "艺术", + travel: "旅行", + health: "健康", + selfHelp: "自助", + businessEconomics: "商业/经济", + cooking: "烹饪", + childrenYoungAdults: "儿童/青少年", + comicsGraphicNovels: "漫画/图形小说", + poetry: "诗歌", + drama: "戏剧", + other: "其他", + }, + personalDoc: { + notes: "笔记", + blogDraft: "博客草稿", + diary: "日记", + researchReport: "研究报告", + bookExcerpt: "书籍摘录", + schedule: "日程安排", + list: "列表", + projectOverview: "项目概述", + photoCollection: "照片集", + creativeWriting: "创意写作", + codeSnippet: "代码片段", + designDraft: "设计草稿", + personalResume: "个人简历", + other: "其他", + }, + businessDoc: { + meetingMinutes: "会议纪要", + researchReport: "研究报告", + proposal: "提案", + employeeHandbook: "员工手册", + trainingMaterials: "培训材料", + requirementsDocument: "需求文档", + designDocument: "设计文档", + productSpecification: "产品规格", + financialReport: "财务报告", + marketAnalysis: "市场分析", + projectPlan: "项目计划", + teamStructure: "团队结构", + policiesProcedures: "政策和流程", + contractsAgreements: "合同和协议", + emailCorrespondence: "邮件往来", + other: "其他", + }, + }, + }, + embedding: { + processing: '嵌入处理中...', + paused: '嵌入已停止', + completed: '嵌入已完成', + error: '嵌入发生错误', + docName: '预处理文档', + mode: '分段规则', + segmentLength: '分段长度', + textCleaning: '文本预定义与清洗', + segments: '段落', + highQuality: '高质量模式', + economy: '经济模式', + estimate: '预估消耗', + stop: '停止处理', + resume: '恢复处理', + automatic: '自动', + custom: '自定义', + previewTip: '段落预览将在嵌入完成后可用' + }, + segment: { + paragraphs: '段落', + keywords: '关键词', + characters: '字符', + hitCount: '命中次数', + vectorHash: '向量哈希:', + } +}; + +export default translation; diff --git a/web/i18n/lang/dataset-hit-testing.en.ts 
b/web/i18n/lang/dataset-hit-testing.en.ts new file mode 100644 index 0000000000..4946b59d83 --- /dev/null +++ b/web/i18n/lang/dataset-hit-testing.en.ts @@ -0,0 +1,28 @@ +const translation = { + title: "Hit Testing", + desc: "Test the hit performance of the dataset against a given query text.", + dateTimeFormat: 'MM/DD/YYYY hh:mm A', + recents: 'Recents', + table: { + header: { + source: "Source", + text: "Text", + time: "Time", + }, + }, + input: { + title: 'Source text', + placeholder: 'Please enter a text; a short declarative sentence is recommended.', + countWarning: 'Up to 200 characters.', + indexWarning: 'High quality dataset only.', + testing: 'Testing', + }, + hit: { + title: "HIT PARAGRAPHS", + emptyTip: 'Hit Testing results will show here', + }, + noRecentTip: 'No recent query results here', + viewChart: 'View VECTOR CHART', +}; + +export default translation; diff --git a/web/i18n/lang/dataset-hit-testing.zh.ts b/web/i18n/lang/dataset-hit-testing.zh.ts new file mode 100644 index 0000000000..0ef23cbd9b --- /dev/null +++ b/web/i18n/lang/dataset-hit-testing.zh.ts @@ -0,0 +1,28 @@ +const translation = { + title: '命中测试', + desc: '基于给定的查询文本测试数据集的命中效果。', + dateTimeFormat: 'YYYY-MM-DD HH:mm', + recents: '最近查询', + table: { + header: { + source: "数据源", + text: "文本", + time: "时间", + }, + }, + input: { + title: '源文本', + placeholder: '请输入文本,建议使用简短的陈述句。', + countWarning: '不超过 200 个字符', + indexWarning: '仅支持高质量模式数据集', + testing: '测试', + }, + hit: { + title: "命中段落", + emptyTip: '命中测试结果将展示在这里', + }, + noRecentTip: '最近无查询结果', + viewChart: '查看向量图表', +} + +export default translation diff --git a/web/i18n/lang/dataset-settings.en.ts b/web/i18n/lang/dataset-settings.en.ts new file mode 100644 index 0000000000..1337383ad4 --- /dev/null +++ b/web/i18n/lang/dataset-settings.en.ts @@ -0,0 +1,22 @@ +const translation = { + title: 'Dataset settings', + desc: 'Here you can modify the properties and behavior of the dataset.', + form: { + name: 'Dataset Name', + nameError: 'Name cannot be empty', + desc: 'Dataset description', + descPlaceholder: 'Describe what is in this dataset. A detailed description allows AI to access the content of the dataset in a timely manner. If empty, Dify will use the default hit strategy.', + descWrite: 'Learn how to write a good dataset description.', + permissions: 'Permissions', + permissionsOnlyMe: 'Only me', + permissionsAllMember: 'All team members', + indexMethod: 'Index Method', + indexMethodHighQuality: 'High Quality', + indexMethodHighQualityTip: 'Call OpenAI\'s embedding interface for processing to provide higher accuracy when users query.', + indexMethodEconomy: 'Economical', + indexMethodEconomyTip: 'Uses offline vector engines, keyword indexes, etc.; 
accuracy is lower but no tokens are spent', + save: 'Save', + }, +} + +export default translation diff --git a/web/i18n/lang/dataset-settings.zh.ts b/web/i18n/lang/dataset-settings.zh.ts new file mode 100644 index 0000000000..0818836d79 --- /dev/null +++ b/web/i18n/lang/dataset-settings.zh.ts @@ -0,0 +1,22 @@ +const translation = { + title: '数据集设置', + desc: '在这里您可以修改数据集的工作方式以及其它设置。', + form: { + name: '数据集名称', + nameError: '名称不能为空', + desc: '数据集描述', + descPlaceholder: '描述这个数据集中的内容。详细的描述可以让 AI 及时访问数据集的内容。如果为空,Dify 将使用默认的命中策略。', + descWrite: '了解如何编写更好的数据集描述。', + permissions: '可见权限', + permissionsOnlyMe: '只有我', + permissionsAllMember: '所有团队成员', + indexMethod: '索引模式', + indexMethodHighQuality: '高质量', + indexMethodHighQualityTip: '调用 OpenAI 的嵌入接口进行处理,以在用户查询时提供更高的准确度', + indexMethodEconomy: '经济', + indexMethodEconomyTip: '使用离线的向量引擎、关键词索引等方式,降低了准确度但无需花费 Token', + save: '保存', + }, +} + +export default translation diff --git a/web/i18n/lang/dataset.en.ts b/web/i18n/lang/dataset.en.ts new file mode 100644 index 0000000000..94d98c2206 --- /dev/null +++ b/web/i18n/lang/dataset.en.ts @@ -0,0 +1,21 @@ +const translation = { + documentCount: ' docs', + wordCount: 'k words', + appCount: ' linked apps', + createDataset: 'Create Dataset', + createDatasetIntro: 'Import your own text data or write data in real-time via Webhook for LLM context enhancement.', + deleteDatasetConfirmTitle: 'Delete this dataset?', + deleteDatasetConfirmContent: + 'Deleting the dataset is irreversible. Users will no longer be able to access your dataset, and all prompt configurations and logs will be permanently deleted.', + datasetDeleted: 'Dataset deleted', + datasetDeleteFailed: 'Failed to delete dataset', + didYouKnow: 'Did you know?', + intro1: 'The dataset can be integrated into the Dify application ', + intro2: 'as a context', + intro3: ',', + intro4: 'or it ', + intro5: 'can be created', + intro6: ' as a standalone ChatGPT index plugin to publish', +} + +export default translation diff --git a/web/i18n/lang/dataset.zh.ts b/web/i18n/lang/dataset.zh.ts new file mode 100644 index 0000000000..b74eabab7e --- /dev/null +++ b/web/i18n/lang/dataset.zh.ts @@ -0,0 +1,21 @@ +const translation = { + documentCount: ' 文档', + wordCount: '千字符', + appCount: ' 关联应用', + createDataset: '创建数据集', + createDatasetIntro: '导入您自己的文本数据或通过 Webhook 实时写入数据以增强 LLM 的上下文。', + deleteDatasetConfirmTitle: '要删除数据集吗?', + deleteDatasetConfirmContent: + '删除数据集是不可逆的。用户将无法再访问您的数据集,所有的提示配置和日志将被永久删除。', + datasetDeleted: '数据集已删除', + datasetDeleteFailed: '删除数据集失败', + didYouKnow: '你知道吗?', + intro1: '数据集可以被集成到 Dify 应用中', + intro2: '作为上下文', + intro3: ',', + intro4: '或可以', + intro5: '创建', + intro6: '为独立的 ChatGPT 插件发布使用', +} + +export default translation diff --git a/web/i18n/lang/layout.en.ts b/web/i18n/lang/layout.en.ts new file mode 100644 index 0000000000..928649474b --- /dev/null +++ b/web/i18n/lang/layout.en.ts @@ -0,0 +1,4 @@ +const translation = { +} + +export default translation diff --git a/web/i18n/lang/layout.zh.ts b/web/i18n/lang/layout.zh.ts new file mode 100644 index 0000000000..928649474b --- /dev/null +++ b/web/i18n/lang/layout.zh.ts @@ -0,0 +1,4 @@ +const translation = { +} + +export default translation diff --git a/web/i18n/lang/login.en.ts b/web/i18n/lang/login.en.ts new file mode 100644 index 0000000000..ae40e01ea1 --- /dev/null +++ b/web/i18n/lang/login.en.ts @@ -0,0 +1,41 @@ +const translation = { + "pageTitle": "Hey, let's get started!👋", + "welcome": "Welcome to Dify, please log in to continue.", + "email": "Email address", + "password": "Password", + 
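The dataset-settings form above collects a name, an optional description, a visibility permission, and one of two index methods. A hypothetical shape for that payload; every field name here is an assumption for illustration, not a confirmed API contract:

```
// Hypothetical payload shape for the dataset settings form above; the
// field names are illustrative assumptions, not confirmed Dify API fields.
type DatasetPermission = 'only_me' | 'all_team_members'
type IndexMethod = 'high_quality' | 'economy'

type DatasetSettingsPayload = {
  name: string             // required: "Name cannot be empty"
  description?: string     // when empty, the default hit strategy applies
  permission: DatasetPermission
  indexMethod: IndexMethod // high_quality calls OpenAI embeddings; economy
                           // uses offline vector engines and keyword indexes
}
```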
"name": "Name", + "forget": "Forgot your password?", + "signBtn": "Sign in", + "installBtn": "Setting", + "setAdminAccount": "Setting up an admin account", + "setAdminAccountDesc": "Maximum privileges for admin account, which can be used to create applications and manage LLM providers, etc.", + "createAndSignIn": "Create and sign in", + "oneMoreStep": "One more step", + "createSample": "Based on this information, we’ll create sample application for you", + "invitationCode": "Invitation Code", + "interfaceLanguage": "Interface Dify", + "timezone": "Time zone", + "go": "Go to Dify", + "sendUsMail": "Email us your introduction, and we'll handle the invitation request.", + "acceptPP": "I have read and accept the privacy policy", + "reset": "Please run following command to reset your password", + "withGitHub": "Continue with GitHub", + "withGoogle": "Continue with Google", + "rightTitle": "Unlock the full potential of LLM", + "rightDesc": "Effortlessly build visually captivating, operable, and improvable AI applications.", + "tos": "Terms of Service", + "pp": "Privacy Policy", + "tosDesc": "By signing up, you agree to our", + "donthave": "Don't have?", + "invalidInvitationCode": "Invalid invitation code", + "accountAlreadyInited": "Account already inited", + "error": { + "emailEmpty": "Email address is required", + "emailInValid": "Please enter a valid email address", + "nameEmpty": "Name is required", + "passwordEmpty": "Password is required", + "passwordInvalid": "Password must contain letters and numbers, and the length must be greater than 8", + } +} + +export default translation diff --git a/web/i18n/lang/login.zh.ts b/web/i18n/lang/login.zh.ts new file mode 100644 index 0000000000..a7da82c13f --- /dev/null +++ b/web/i18n/lang/login.zh.ts @@ -0,0 +1,41 @@ +const translation = { + "pageTitle": "嗨,近来可好 👋", + "welcome": "欢迎来到 Dify, 登录以继续", + "email": "邮箱", + "password": "密码", + "name": "用户名", + "forget": "忘记密码?", + "signBtn": "登录", + "installBtn": "设置", + "setAdminAccount": "设置管理员账户", + "setAdminAccountDesc": "管理员拥有的最大权限,可用于创建应用和管理 LLM 供应商等。", + "createAndSignIn": "创建账户", + "oneMoreStep": "还差一步", + "createSample": "基于这些信息,我们将为您创建一个示例应用", + "invitationCode": "邀请码", + "interfaceLanguage": "界面语言", + "timezone": "时区", + "go": "跳转至 Dify", + "sendUsMail": "发封邮件介绍你自己,我们会尽快处理。", + "acceptPP": "我已阅读并接受隐私政策", + "reset": "请运行以下命令重置密码", + "withGitHub": "使用 GitHub 登录", + "withGoogle": "使用 Google 登录", + "rightTitle": "释放大型语言模型的全部潜能", + "rightDesc": "简单构建可视化、可运营、可改进的 AI 应用", + "tos": "使用协议", + "pp": "隐私政策", + "tosDesc": "使用即代表你并同意我们的", + "donthave": "还没有邀请码?", + "invalidInvitationCode": "无效的邀请码", + "accountAlreadyInited": "账户已经初始化", + "error": { + "emailEmpty": "邮箱不能为空", + "emailInValid": "请输入有效的邮箱地址", + "nameEmpty": "用户名不能为空", + "passwordEmpty": "密码不能为空", + "passwordInvalid": "密码必须包含字母和数字,且长度不小于8位", + } +} + +export default translation diff --git a/web/i18n/lang/register.en.ts b/web/i18n/lang/register.en.ts new file mode 100644 index 0000000000..928649474b --- /dev/null +++ b/web/i18n/lang/register.en.ts @@ -0,0 +1,4 @@ +const translation = { +} + +export default translation diff --git a/web/i18n/lang/register.zh.ts b/web/i18n/lang/register.zh.ts new file mode 100644 index 0000000000..928649474b --- /dev/null +++ b/web/i18n/lang/register.zh.ts @@ -0,0 +1,4 @@ +const translation = { +} + +export default translation diff --git a/web/i18n/lang/share-app.en.ts b/web/i18n/lang/share-app.en.ts new file mode 100644 index 0000000000..3e641222b5 --- /dev/null +++ b/web/i18n/lang/share-app.en.ts @@ -0,0 +1,45 @@ 
+const translation = { + common: { + welcome: "Welcome to use", + appUnavailable: "App is unavailable", + appUnkonwError: "App is unavailable" + }, + chat: { + newChat: "New chat", + newChatDefaultName: "New conversation", + powerBy: "Powered by", + prompt: "Prompt", + privatePromptConfigTitle: "Conversation settings", + publicPromptConfigTitle: "Initial Prompt", + configStatusDes: "Before starting, you can modify conversation settings", + configDisabled: + "Previous session settings have been used for this session.", + startChat: "Start Chat", + privacyPolicyLeft: + "Please read the ", + privacyPolicyMiddle: + "privacy policy", + privacyPolicyRight: + " provided by the app developer.", + }, + generation: { + tabs: { + create: "Create", + saved: "Saved", + }, + savedNoData: { + title: "You haven't saved a result yet!", + description: 'Start generating content, and find your saved results here.', + startCreateContent: 'Start creating content' + }, + title: "AI Completion", + queryTitle: "Query content", + queryPlaceholder: "Write your query content...", + run: "RUN", + copy: "Copy", + resultTitle: "AI Completion", + noData: "AI will give you what you want here.", + }, +}; + +export default translation; diff --git a/web/i18n/lang/share-app.zh.ts b/web/i18n/lang/share-app.zh.ts new file mode 100644 index 0000000000..031bedf03a --- /dev/null +++ b/web/i18n/lang/share-app.zh.ts @@ -0,0 +1,41 @@ +const translation = { + common: { + welcome: "欢迎使用", + appUnavailable: "应用不可用", + appUnkonwError: "应用不可用", + }, + chat: { + newChat: "新对话", + newChatDefaultName: "新的对话", + powerBy: "Powered by", + prompt: "提示词", + privatePromptConfigTitle: "对话设置", + publicPromptConfigTitle: "对话前提示词", + configStatusDes: "开始前,您可以修改对话设置", + configDisabled: "此次会话已使用上次会话表单", + startChat: "开始对话", + privacyPolicyLeft: "请阅读由该应用开发者提供的", + privacyPolicyMiddle: "隐私政策", + privacyPolicyRight: "。", + }, + generation: { + tabs: { + create: "创建", + saved: "已保存", + }, + savedNoData: { + title: "您还没有保存结果!", + description: '开始生成内容,您可以在这里找到保存的结果。', + startCreateContent: '开始生成内容' + }, + title: "AI 智能书写", + queryTitle: "查询内容", + queryPlaceholder: "请输入文本内容", + run: "运行", + copy: "拷贝", + resultTitle: "AI 书写", + noData: "AI 会在这里给你惊喜。", + }, +}; + +export default translation; diff --git a/web/i18n/server.ts b/web/i18n/server.ts new file mode 100644 index 0000000000..ac9630dd49 --- /dev/null +++ b/web/i18n/server.ts @@ -0,0 +1,43 @@ +import 'server-only' + +import { cookies, headers } from 'next/headers' +import Negotiator from 'negotiator' +import { match } from '@formatjs/intl-localematcher' +import type { Locale } from '.' +import { i18n } from '.' + +export const getLocaleOnServer = (): Locale => { + // @ts-expect-error locales are readonly + const locales: string[] = i18n.locales + + let languages: string[] | undefined + // get locale from cookie + const localeCookie = cookies().get('locale') + languages = localeCookie?.value ? 
[localeCookie.value] : [] + + if (!languages.length) { + // Negotiator expects plain object so we need to transform headers + const negotiatorHeaders: Record<string, string> = {} + headers().forEach((value, key) => (negotiatorHeaders[key] = value)) + // Use negotiator and intl-localematcher to get best locale + languages = new Negotiator({ headers: negotiatorHeaders }).languages() + } + + // match locale + const matchedLocale = match(languages, locales, i18n.defaultLocale) as Locale + return matchedLocale +} + +// We enumerate all dictionaries here for better linting and typescript support +// We also get the default import for cleaner types +const dictionaries = { + 'en': () => import('@/dictionaries/en.json').then(module => module.default), + 'zh-Hans': () => import('@/dictionaries/zh-Hans.json').then(module => module.default), +} as { [locale: string]: () => Promise<any> } + +export const getDictionary = async (locale: Locale = 'en') => { + try { + return await dictionaries[locale]() + } + catch (e) { console.error('locale not found', locale) } +} diff --git a/web/middleware.ts b/web/middleware.ts new file mode 100644 index 0000000000..5b17cbd673 --- /dev/null +++ b/web/middleware.ts @@ -0,0 +1,39 @@ +import { match } from '@formatjs/intl-localematcher' +import Negotiator from 'negotiator' +import { NextResponse } from 'next/server' +import type { NextRequest } from 'next/server' +import type { Locale } from './i18n' +import { i18n } from './i18n' + +export const getLocale = (request: NextRequest): Locale => { + // @ts-expect-error locales are readonly + const locales: Locale[] = i18n.locales + + let languages: string[] | undefined + // get locale from cookie + const localeCookie = request.cookies.get('locale') + languages = localeCookie?.value ? [localeCookie.value] : [] + + if (!languages.length) { + // Negotiator expects plain object so we need to transform headers + const negotiatorHeaders: Record<string, string> = {} + request.headers.forEach((value, key) => (negotiatorHeaders[key] = value)) + // Use negotiator and intl-localematcher to get best locale + languages = new Negotiator({ headers: negotiatorHeaders }).languages() + } + + // match locale + const matchedLocale = match(languages, locales, i18n.defaultLocale) as Locale + return matchedLocale +} + +export const middleware = async (request: NextRequest) => { + const pathname = request.nextUrl.pathname + if (/\.(css|js(on)?|ico|svg|png)$/.test(pathname)) + return + + const locale = getLocale(request) + const response = NextResponse.next() + response.cookies.set('locale', locale) + return response +} diff --git a/web/models/app.ts b/web/models/app.ts new file mode 100644 index 0000000000..ddafdfbc72 --- /dev/null +++ b/web/models/app.ts @@ -0,0 +1,118 @@ +import type { App, AppTemplate, SiteConfig } from '@/types/app' + +export type AppMode = 'chat' | 'completion' + +/* export type App = { + id: string + name: string + decription: string + mode: AppMode + enable_site: boolean + enable_api: boolean + api_rpm: number + api_rph: number + is_demo: boolean + model_config: AppModelConfig + providers: Array<{ provider: string; token_is_set: boolean }> + site: SiteConfig + created_at: string +} + +export type AppModelConfig = { + provider: string + model_id: string + configs: { + prompt_template: string + prompt_variables: Array<PromptVariable> + completion_params: CompletionParam + } +} + +export type PromptVariable = { + key: string + name: string + description: string + type: string | number + default: string + options: string[] +} + +export type CompletionParam = { + max_tokens: 
number + temperature: number + top_p: number + echo: boolean + stop: string[] + presence_penalty: number + frequency_penalty: number +} + +export type SiteConfig = { + access_token: string + title: string + author: string + support_email: string + default_language: string + customize_domain: string + theme: string + customize_token_strategy: 'must' | 'allow' | 'not_allow' + prompt_public: boolean +} */ + +export type AppListResponse = { + data: App[] +} + +export type AppDetailResponse = App + +export type AppTemplatesResponse = { + data: AppTemplate[] +} + +export type CreateAppResponse = App + +export type UpdateAppNameResponse = App + +export type UpdateAppSiteCodeResponse = { app_id: string } & SiteConfig + +export type AppDailyConversationsResponse = { + data: Array<{ date: string; conversation_count: number }> +} + +export type AppDailyEndUsersResponse = { + data: Array<{ date: string; terminal_count: number }> +} + +export type AppTokenCostsResponse = { + data: Array<{ date: string; token_count: number; total_price: number; currency: number }> +} + +export type UpdateAppModelConfigResponse = { result: string } + +export type ApikeyItemResponse = { + id: string + token: string + last_used_at: string + created_at: string +} + +export type ApikeysListResponse = { + data: ApikeyItemResponse[] +} + +export type CreateApiKeyResponse = { + id: string + token: string + created_at: string +} + +export type ValidateOpenAIKeyResponse = { + result: string + error?: string +} + +export type UpdateOpenAIKeyResponse = ValidateOpenAIKeyResponse + +export type GenerationIntroductionResponse = { + introduction: string +} diff --git a/web/models/common.ts b/web/models/common.ts new file mode 100644 index 0000000000..21a74447e1 --- /dev/null +++ b/web/models/common.ts @@ -0,0 +1,92 @@ +export type CommonResponse = { + result: 'success' | 'fail' +} + +export type OauthResponse = { + redirect_url: string +} + +export type UserProfileResponse = { + id: string + name: string + email: string + interface_language?: string + interface_theme?: string + timezone?: string + last_login_at?: string + last_login_ip?: string + created_at?: string +} + +export type UserProfileOriginResponse = { + json: () => Promise + bodyUsed: boolean + headers: any +} + +export type LangGeniusVersionResponse = { + current_version: string + latest_version: string + version: string + release_date: string + release_notes: string + can_auto_update: boolean + current_env: string +} + +export type TenantInfoResponse = { + name: string + created_at: string + providers: Array<{ + provider: string + provider_name: string + token_is_set: boolean + is_valid: boolean + token_is_valid: boolean + }> + in_trail: boolean + trial_end_reason: null | 'trial_exceeded' | 'using_custom' +} + +export type Member = Pick & { + avatar: string + status: 'pending' | 'active' | 'banned' | 'closed' + role: 'owner' | 'admin' | 'normal' +} + +export type ProviderAzureToken = { + azure_api_base: string + azure_api_key: string + azure_api_type: string + azure_api_version: string +} +export type Provider = { + provider_name: string + provider_type: string + is_valid: boolean + is_enabled: boolean + last_used: string + token?: string | ProviderAzureToken +} + +export type ProviderHosted = Provider & { + quota_type: string + quota_limit: number + quota_used: number +} + +export type AccountIntegrate = { + provider: 'google' | 'github' + created_at: number + is_bound: boolean + link: string +} + +export interface IWorkspace { + id: string + name: string + plan: string 
+ status: string + created_at: number + current: boolean +} diff --git a/web/models/datasets.ts b/web/models/datasets.ts new file mode 100644 index 0000000000..8295e77a55 --- /dev/null +++ b/web/models/datasets.ts @@ -0,0 +1,339 @@ +import { AppMode } from './app' + +export type DataSet = { + id: string + name: string + description: string + permission: 'only_me' | 'all_team_members' + data_source_type: 'upload_file' + indexing_technique: 'high_quality' | 'economy' + created_by: string + updated_by: string + updated_at: number + app_count: number + document_count: number + word_count: number +} + +export type File = { + id: string + name: string + size: number + extension: string + mime_type: string + created_by: string + created_at: number +} + +export type DataSetListResponse = { + data: DataSet[] +} + +export type IndexingEstimateResponse = { + tokens: number + total_price: number + currency: string + total_segments: number + preview: string[] +} + +export interface FileIndexingEstimateResponse extends IndexingEstimateResponse { + total_nodes: number +} + +export type IndexingStatusResponse = { + id: string + indexing_status: DocumentIndexingStatus + processing_started_at: number + parsing_completed_at: number + cleaning_completed_at: number + splitting_completed_at: number + completed_at: any + paused_at: any + error: any + stopped_at: any + completed_segments: number + total_segments: number +} + +export type ProcessMode = 'automatic' | 'custom' + +export type ProcessRuleResponse = { + mode: ProcessMode + rules: Rules +} + +export type Rules = { + pre_processing_rules: PreProcessingRule[] + segmentation: Segmentation +} + +export type PreProcessingRule = { + id: string + enabled: boolean +} + +export type Segmentation = { + separator: string + max_tokens: number +} + +export const DocumentIndexingStatusList = [ + 'waiting', + 'parsing', + 'cleaning', + 'splitting', + 'indexing', + 'paused', + 'error', + 'completed', +] as const + +export type DocumentIndexingStatus = typeof DocumentIndexingStatusList[number] + +export const DisplayStatusList = [ + "queuing", + "indexing", + "paused", + "error", + "available", + "enabled", + "disabled", + "archived", +] as const; + +export type DocumentDisplayStatus = typeof DisplayStatusList[number]; + +export type DataSourceInfo = { + upload_file: { + id: string + name: string + size: number + mime_type: string + created_at: number + created_by: string + extension: string + } +} + +export type InitialDocumentDetail = { + id: string + position: number + dataset_id: string + data_source_type: 'upload_file' + data_source_info: DataSourceInfo + dataset_process_rule_id: string + name: string + created_from: 'api' | 'web' + created_by: string + created_at: number + indexing_status: DocumentIndexingStatus + display_status: DocumentDisplayStatus +} + +export type SimpleDocumentDetail = InitialDocumentDetail & { + enabled: boolean + word_count: number + error?: string | null + archived: boolean + updated_at: number + hit_count: number + dataset_process_rule_id?: string +} + +export type DocumentListResponse = { + data: SimpleDocumentDetail[] + has_more: boolean + total: number + page: number + limit: number +} + +export type CreateDocumentReq = { + original_document_id?: string + indexing_technique?: string; + name: string + data_source: DataSource + process_rule: ProcessRule +} + +export type DataSource = { + type: string + info: string // upload_file_id + name: string +} + +export type ProcessRule = { + mode: string + rules: Rules +} + +export type 
createDocumentResponse = { + dataset?: DataSet + document: InitialDocumentDetail +} + +export type FullDocumentDetail = SimpleDocumentDetail & { + batch: string + created_api_request_id: string + processing_started_at: number + parsing_completed_at: number + cleaning_completed_at: number + splitting_completed_at: number + tokens: number + indexing_latency: number + completed_at: number + paused_by: string + paused_at: number + stopped_at: number + indexing_status: string + disabled_at: number + disabled_by: string + archived_reason: 'rule_modified' | 're_upload' + archived_by: string + archived_at: number + doc_type?: DocType | null + doc_metadata?: DocMetadata | null + segment_count: number + [key: string]: any +} + +export type DocMetadata = { + title: string + language: string + author: string + publisher: string + publicationDate: string + ISBN: string + category: string + [key: string]: string +} + +export const CUSTOMIZABLE_DOC_TYPES = [ + "book", + "web_page", + "paper", + "social_media_post", + "personal_document", + "business_document", + "im_chat_log", +] as const; + +export const FIXED_DOC_TYPES = ["synced_from_github", "synced_from_notion", "wikipedia_entry"] as const; + +export type CustomizableDocType = typeof CUSTOMIZABLE_DOC_TYPES[number]; +export type FixedDocType = typeof FIXED_DOC_TYPES[number]; +export type DocType = CustomizableDocType | FixedDocType; + +export type DocumentDetailResponse = FullDocumentDetail + +export const SEGMENT_STATUS_LIST = ['waiting', 'completed', 'error', 'indexing'] +export type SegmentStatus = typeof SEGMENT_STATUS_LIST[number] + +export type SegmentsQuery = { + last_id?: string + limit: number + // status?: SegmentStatus + hit_count_gte?: number + keyword?: string + enabled?: boolean +} + +export type SegmentDetailModel = { + id: string + position: number + document_id: string + content: string + word_count: number + tokens: number + keywords: string[] + index_node_id: string + index_node_hash: string + hit_count: number + enabled: boolean + disabled_at: number + disabled_by: string + status: SegmentStatus + created_by: string + created_at: number + indexing_at: number + completed_at: number + error: string | null + stopped_at: number +} + +export type SegmentsResponse = { + data: SegmentDetailModel[] + has_more: boolean + limit: number + total: number +} + +export type HitTestingRecord = { + id: string + content: string + source: 'app' | 'hit_testing' | 'plugin' + source_app_id: string + created_by_role: 'account' | 'end_user' + created_by: string + created_at: number +} + +export type HitTesting = { + segment: Segment + score: number + tsne_position: TsnePosition +} + +export type Segment = { + id: string + document: Document + content: string + position: number + word_count: number + tokens: number + keywords: string[] + hit_count: number + index_node_hash: string +} + +export type Document = { + id: string + data_source_type: string + name: string + doc_type: DocType +} + +export type HitTestingRecordsResponse = { + data: HitTestingRecord[] + has_more: boolean + limit: number + total: number + page: number +} + +export type TsnePosition = { + x: number + y: number +} + +export type HitTestingResponse = { + query: { + content: string + tsne_position: TsnePosition + } + records: Array +} + +export type RelatedApp = { + id: string + name: string + mode: AppMode + icon: string + icon_background: string +} + +export type RelatedAppResponse = { + data: Array + total: number +} diff --git a/web/models/debug.ts b/web/models/debug.ts new file 
mode 100644 index 0000000000..cde30cdd71 --- /dev/null +++ b/web/models/debug.ts @@ -0,0 +1,115 @@ +export type Inputs = Record + +export type PromptVariable = { + key: string, + name: string, + type: string, // "string" | "number" | "select", + default?: string | number, + required: boolean, + options?: string[] + max_length?: number +} + +export type CompletionParams = { + max_tokens: number, + temperature: number, + top_p: number, + presence_penalty: number, + frequency_penalty: number, +} + +export type ModelId = "gpt-3.5-turbo" | "text-davinci-003" + +export type PromptConfig = { + prompt_template: string, + prompt_variables: PromptVariable[], +} + +export type MoreLikeThisConfig = { + enabled: boolean +} + +export type SuggestedQuestionsAfterAnswerConfig = MoreLikeThisConfig + +// frontend use. Not the same as backend +export type ModelConfig = { + provider: string, // LLM Provider: for example "OPENAI" + model_id: string, + configs: PromptConfig +} + +export type DebugRequestBody = { + inputs: Inputs, + query: string, + completion_params: CompletionParams, + model_config: ModelConfig +} + +export type DebugResponse = { + id: string, + answer: string, + created_at: string, +} + + +export type DebugResponseStream = { + id: string, + data: string, + created_at: string, +} + + +export type FeedBackRequestBody = { + message_id: string, + rating: 'like' | 'dislike', + content?: string, + from_source: 'api' | 'log' +} + + +export type FeedBackResponse = { + message_id: string, + rating: 'like' | 'dislike' +} + +// Log session list +export type LogSessionListQuery = { + keyword?: string, + start?: string, // format datetime(YYYY-mm-dd HH:ii) + end?: string, // format datetime(YYYY-mm-dd HH:ii) + page: number, + limit: number, // default 20. 1-100 +} + +export type LogSessionListResponse = { + data: { + id: string, + conversation_id: string, + query: string, // user's query question + message: string, // prompt send to LLM + answer: string, + creat_at: string, + }[], + total: number, + page: number, +} + +// log session detail and debug +export type LogSessionDetailResponse = { + id: string, + cnversation_id: string, + model_provider: string, + query: string, + inputs: Record[], + message: string, + message_tokens: number, // number of tokens in message + answer: string, + answer_tokens: number, // number of tokens in answer + provider_response_latency: number, // used time in ms + from_source: 'api' | 'log', +} + +export type SavedMessage = { + id: string, + answer: string +} \ No newline at end of file diff --git a/web/models/history.ts b/web/models/history.ts new file mode 100644 index 0000000000..90d6245cb9 --- /dev/null +++ b/web/models/history.ts @@ -0,0 +1,11 @@ +export type History = { + id: string + source: string + target: string +} +export type HistoryResponse = { + histories: History[] +} + +export const fetchHistories = (url: string) => + fetch(url).then(r => r.json()) diff --git a/web/models/log.ts b/web/models/log.ts new file mode 100644 index 0000000000..9f915a3313 --- /dev/null +++ b/web/models/log.ts @@ -0,0 +1,192 @@ +// Log type contains key:string conversation_id:string created_at:string quesiton:string answer:string +export type Conversation = { + id: string + key: string + conversationId: string + question: string + answer: string + userRate: number + adminRate: number +} + +export type ConversationListResponse = { + logs: Conversation[] +} + +export const fetchLogs = (url: string) => + fetch(url).then(r => r.json()) + +export const CompletionParams = ['temperature', 
'top_p', 'presence_penalty', 'max_token', 'stop', 'frequency_penalty'] as const + +export type CompletionParamType = typeof CompletionParams[number] + +export type CompletionParamsType = { + max_tokens: number + temperature: number + top_p: number + stop: string[] + presence_penalty: number + frequency_penalty: number +} + +export type ModelConfigDetail = { + introduction: string + prompt_template: string + prompt_variables: Array<{ + key: string + name: string + description: string + type: string | number + default: string + options: string[] + }> + completion_params: CompletionParamsType +} + +export type Annotation = { + content: string + account: { + id: string + name: string + email: string + } + created_at?: number +} + +export type MessageContent = { + id: string + conversation_id: string + query: string + inputs: Record + // message: Record + message: string + message_tokens: number + answer_tokens: number + answer: string + provider_response_latency: number + created_at: number + annotation: Annotation + feedbacks: Array<{ + rating: 'like' | 'dislike' | null + content: string | null + from_source?: 'admin' | 'user' + from_end_user_id?: string + }> +} + +export type CompletionConversationGeneralDetail = { + id: string + status: 'normal' | 'finished' + from_source: 'api' | 'console' + from_end_user_id: string + from_account_id: string + read_at: Date + created_at: number + annotation: Annotation + user_feedback_stats: { + like: number + dislike: number + } + admin_feedback_stats: { + like: number + dislike: number + } + model_config: { + provider: string + model_id: string + configs: Pick + } + message: Pick +} + +export type CompletionConversationFullDetailResponse = { + id: string + status: 'normal' | 'finished' + from_source: 'api' | 'console' + from_end_user_id: string + from_account_id: string + // read_at: Date + created_at: number + model_config: { + provider: string + model_id: string + configs: ModelConfigDetail + } + message: MessageContent +} + +export type CompletionConversationsResponse = { + data: Array + has_more: boolean + limit: number + total: number + page: number +} + +export type CompletionConversationsRequest = { + keyword: string + start: string + end: string + annotation_status: string + page: number + limit: number // The default value is 20 and the range is 1-100 +} + +export type ChatConversationGeneralDetail = Omit & { + summary: string + message_count: number + annotated: boolean +} + +export type ChatConversationsResponse = { + data: Array + has_more: boolean + limit: number + total: number + page: number +} + +export type ChatConversationsRequest = CompletionConversationsRequest & { message_count: number } + +export type ChatConversationFullDetailResponse = Omit & { + message_count: number + model_config: { + provider: string + model_id: string + configs: ModelConfigDetail + } +} + +export type ChatMessagesRequest = { + conversation_id: string + first_id?: string + limit: number +} +export type ChatMessage = MessageContent + +export type ChatMessagesResponse = { + data: Array + has_more: boolean + limit: number +} + +export const MessageRatings = ['like', 'dislike', null] as const +export type MessageRating = typeof MessageRatings[number] + +export type LogMessageFeedbacksRequest = { + message_id: string + rating: MessageRating + content?: string +} + +export type LogMessageFeedbacksResponse = { + result: 'success' | 'error' +} + +export type LogMessageAnnotationsRequest = Omit + +export type LogMessageAnnotationsResponse = LogMessageFeedbacksResponse 
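
The request/response shapes above are consumed by `updateLogMessageFeedbacks` in `web/service/log.ts` further down. A minimal sketch of the round trip, assuming the stripped `Fetcher` generics pair request and response types as their names suggest; the message id, app id, and URL here are illustrative placeholders, not values from this commit:

```ts
import type { LogMessageFeedbacksRequest } from '@/models/log'
import { updateLogMessageFeedbacks } from '@/service/log'

// Both ids are hypothetical; `rating` is constrained by MessageRating
// to 'like' | 'dislike' | null.
const body: LogMessageFeedbacksRequest = {
  message_id: 'msg-123',
  rating: 'like',
  content: 'Answer cited the right source.',
}

updateLogMessageFeedbacks({ url: '/apps/app-123/feedbacks', body })
  .then(res => console.log(res.result)) // 'success' | 'error'
```
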
+ +export type AnnotationsCountResponse = { + count: number +} diff --git a/web/models/share.ts b/web/models/share.ts new file mode 100644 index 0000000000..03eace12fe --- /dev/null +++ b/web/models/share.ts @@ -0,0 +1,19 @@ +import { Locale } from '@/i18n' + +export type ResponseHolder = {} + +export type ConversationItem = { + id: string + name: string + inputs: Record | null + introduction: string, +} + +export type SiteInfo = { + title: string + description: string + default_language: Locale + prompt_public: boolean + copyright?: string + privacy_policy?: string +} \ No newline at end of file diff --git a/web/models/user.ts b/web/models/user.ts new file mode 100644 index 0000000000..5451980902 --- /dev/null +++ b/web/models/user.ts @@ -0,0 +1,17 @@ +export type User = { + id: string + firstName: string + lastName: string + name: string + phone: string + username: string + email: string + avatar: string +} + +export type UserResponse = { + users: User[] +} + +export const fetchUsers = (url: string) => + fetch(url).then(r => r.json()) diff --git a/web/next.config.js b/web/next.config.js new file mode 100644 index 0000000000..18ac3e1975 --- /dev/null +++ b/web/next.config.js @@ -0,0 +1,43 @@ +const withMDX = require('@next/mdx')({ + extension: /\.mdx?$/, + options: { + // If you use remark-gfm, you'll need to use next.config.mjs + // as the package is ESM only + // https://github.com/remarkjs/remark-gfm#install + remarkPlugins: [], + rehypePlugins: [], + // If you use `MDXProvider`, uncomment the following line. + // providerImportSource: "@mdx-js/react", + }, +}) + +/** @type {import('next').NextConfig} */ +const nextConfig = { + productionBrowserSourceMaps: false, // enable browser source map generation during the production build + // Configure pageExtensions to include md and mdx + pageExtensions: ['ts', 'tsx', 'js', 'jsx', 'md', 'mdx'], + experimental: { + appDir: true, + }, + // fix all before production. Now it slow the develop speed. + eslint: { + // Warning: This allows production builds to successfully complete even if + // your project has ESLint errors. 
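+    // (Editor's note) `next lint` / `npm run lint` still reports these
+    // errors; the flag below only unblocks `next build`.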
+ ignoreDuringBuilds: true, + }, + typescript: { + // https://nextjs.org/docs/api-reference/next.config.js/ignoring-typescript-errors + ignoreBuildErrors: true, + }, + async redirects() { + return [ + { + source: '/', + destination: '/apps', + permanent: false, + }, + ] + }, +} + +module.exports = withMDX(nextConfig) diff --git a/web/package.json b/web/package.json new file mode 100644 index 0000000000..c9138c20fe --- /dev/null +++ b/web/package.json @@ -0,0 +1,82 @@ +{ + "name": "dify-web", + "version": "0.2.0", + "private": true, + "scripts": { + "dev": "next dev", + "build": "next build", + "start": "next start", + "lint": "next lint", + "fix": "next lint --fix" + }, + "dependencies": { + "@formatjs/intl-localematcher": "^0.2.32", + "@headlessui/react": "^1.7.13", + "@heroicons/react": "^2.0.16", + "@mdx-js/loader": "^2.3.0", + "@mdx-js/react": "^2.3.0", + "@next/mdx": "^13.2.4", + "@tailwindcss/line-clamp": "^0.4.2", + "@types/crypto-js": "^4.1.1", + "@types/lodash-es": "^4.17.7", + "@types/node": "18.15.0", + "@types/react": "18.0.28", + "@types/react-dom": "18.0.11", + "@types/react-slider": "^1.3.1", + "@types/react-syntax-highlighter": "^15.5.6", + "@types/react-window": "^1.8.5", + "@types/react-window-infinite-loader": "^1.0.6", + "ahooks": "^3.7.5", + "classnames": "^2.3.2", + "copy-to-clipboard": "^3.3.3", + "crypto-js": "^4.1.1", + "dayjs": "^1.11.7", + "echarts": "^5.4.1", + "echarts-for-react": "^3.0.2", + "eslint": "8.36.0", + "eslint-config-next": "13.2.4", + "i18next": "^22.4.13", + "i18next-resources-to-backend": "^1.1.3", + "immer": "^9.0.19", + "js-cookie": "^3.0.1", + "lodash-es": "^4.17.21", + "negotiator": "^0.6.3", + "next": "13.2.4", + "qs": "^6.11.1", + "react": "18.2.0", + "react-dom": "18.2.0", + "react-error-boundary": "^4.0.2", + "react-headless-pagination": "^1.1.4", + "react-i18next": "^12.2.0", + "react-infinite-scroll-component": "^6.1.0", + "react-markdown": "^8.0.6", + "react-slider": "^2.0.4", + "react-syntax-highlighter": "^15.5.0", + "react-tooltip": "5.8.3", + "react-window": "^1.8.9", + "react-window-infinite-loader": "^1.0.9", + "rehype-katex": "^6.0.2", + "remark-breaks": "^3.0.2", + "remark-gfm": "^3.0.1", + "remark-math": "^5.1.1", + "sass": "^1.61.0", + "scheduler": "^0.23.0", + "server-only": "^0.0.1", + "swr": "^2.1.0", + "typescript": "4.9.5", + "use-context-selector": "^1.4.1" + }, + "devDependencies": { + "@antfu/eslint-config": "^0.36.0", + "@faker-js/faker": "^7.6.0", + "@tailwindcss/typography": "^0.5.9", + "@types/js-cookie": "^3.0.3", + "@types/negotiator": "^0.6.1", + "@types/qs": "^6.9.7", + "autoprefixer": "^10.4.14", + "eslint-plugin-react-hooks": "^4.6.0", + "miragejs": "^0.1.47", + "postcss": "^8.4.21", + "tailwindcss": "^3.2.7" + } +} diff --git a/web/postcss.config.js b/web/postcss.config.js new file mode 100644 index 0000000000..33ad091d26 --- /dev/null +++ b/web/postcss.config.js @@ -0,0 +1,6 @@ +module.exports = { + plugins: { + tailwindcss: {}, + autoprefixer: {}, + }, +} diff --git a/web/public/favicon.ico b/web/public/favicon.ico new file mode 100644 index 0000000000..00c1f4fc2b Binary files /dev/null and b/web/public/favicon.ico differ diff --git a/web/service/apps.ts b/web/service/apps.ts new file mode 100644 index 0000000000..7bb49d4a28 --- /dev/null +++ b/web/service/apps.ts @@ -0,0 +1,96 @@ +import type { Fetcher } from 'swr' +import { del, get, post } from './base' +import type { ApikeysListResponse, AppDailyConversationsResponse, AppDailyEndUsersResponse, AppDetailResponse, AppListResponse, AppTemplatesResponse, 
AppTokenCostsResponse, CreateApiKeyResponse, GenerationIntroductionResponse, UpdateAppModelConfigResponse, UpdateAppNameResponse, UpdateAppSiteCodeResponse, UpdateOpenAIKeyResponse, ValidateOpenAIKeyResponse } from '@/models/app' +import type { CommonResponse } from '@/models/common' +import type { AppMode, ModelConfig } from '@/types/app' + +export const fetchAppList: Fetcher }> = ({ params }) => { + return get('apps', params) as Promise +} + +export const fetchAppDetail: Fetcher = ({ url, id }) => { + return get(`${url}/${id}`) as Promise +} + +export const fetchAppTemplates: Fetcher = ({ url }) => { + return get(url) as Promise +} + +export const createApp: Fetcher = ({ name, mode, config }) => { + return post('apps', { body: { name, mode, model_config: config } }) as Promise +} + +export const deleteApp: Fetcher = (appID) => { + return del(`apps/${appID}`) as Promise +} + +// path: /apps/{appId}/name +export const updateAppName: Fetcher }> = ({ url, body }) => { + return post(url, { body }) as Promise +} + +export const updateAppSiteStatus: Fetcher }> = ({ url, body }) => { + return post(url, { body }) as Promise +} + +export const updateAppApiStatus: Fetcher }> = ({ url, body }) => { + return post(url, { body }) as Promise +} + +// path: /apps/{appId}/rate-limit +export const updateAppRateLimit: Fetcher }> = ({ url, body }) => { + return post(url, { body }) as Promise +} + +export const updateAppSiteAccessToken: Fetcher = ({ url }) => { + return post(url) as Promise +} + +export const updateAppSiteConfig: Fetcher }> = ({ url, body }) => { + return post(url, { body }) as Promise +} + +export const getAppDailyConversations: Fetcher }> = ({ url, params }) => { + return get(url, { params }) as Promise +} + +export const getAppDailyEndUsers: Fetcher }> = ({ url, params }) => { + return get(url, { params }) as Promise +} + +export const getAppTokenCosts: Fetcher }> = ({ url, params }) => { + return get(url, { params }) as Promise +} + +export const updateAppModelConfig: Fetcher }> = ({ url, body }) => { + return post(url, { body }) as Promise +} + +// For temp testing +export const fetchAppListNoMock: Fetcher }> = ({ url, params }) => { + return get(url, params) as Promise +} + +export const fetchApiKeysList: Fetcher }> = ({ url, params }) => { + return get(url, params) as Promise +} + +export const delApikey: Fetcher }> = ({ url, params }) => { + return del(url, params) as Promise +} + +export const createApikey: Fetcher }> = ({ url, body }) => { + return post(url, body) as Promise +} + +export const validateOpenAIKey: Fetcher = ({ url, body }) => { + return post(url, { body }) as Promise +} + +export const updateOpenAIKey: Fetcher = ({ url, body }) => { + return post(url, { body }) as Promise +} + +export const generationIntroduction: Fetcher = ({ url, body }) => { + return post(url, { body }) as Promise +} diff --git a/web/service/base.ts b/web/service/base.ts new file mode 100644 index 0000000000..737f8a1c6a --- /dev/null +++ b/web/service/base.ts @@ -0,0 +1,353 @@ +import { API_PREFIX, MOCK_API_PREFIX, PUBLIC_API_PREFIX, IS_CE_EDITION } from '@/config' +import Toast from '@/app/components/base/toast' + +const TIME_OUT = 100000 + +const ContentType = { + json: 'application/json', + stream: 'text/event-stream', + form: 'application/x-www-form-urlencoded; charset=UTF-8', + download: 'application/octet-stream', // for download + upload: 'multipart/form-data', // for upload +} + +const baseOptions = { + method: 'GET', + mode: 'cors', + credentials: 'include', // always send cookies、HTTP 
Basic authentication.
+  headers: new Headers({
+    'Content-Type': ContentType.json,
+  }),
+  redirect: 'follow',
+}
+
+export type IOnDataMoreInfo = {
+  conversationId: string | undefined
+  messageId: string
+  errorMessage?: string
+}
+
+export type IOnData = (message: string, isFirstMessage: boolean, moreInfo: IOnDataMoreInfo) => void
+export type IOnCompleted = (hasError?: boolean) => void
+export type IOnError = (msg: string) => void
+
+type IOtherOptions = {
+  isPublicAPI?: boolean
+  isMock?: boolean
+  needAllResponseContent?: boolean
+  onData?: IOnData // for stream
+  onError?: IOnError
+  onCompleted?: IOnCompleted // for stream
+  getAbortController?: (abortController: AbortController) => void
+}
+
+// convert escaped unicode sequences (e.g. "\u4f60") back to characters;
+// the capture group feeds the hex digits to parseInt
+function unicodeToChar(text: string) {
+  return text.replace(/\\u([0-9a-f]{4})/g, (_match, p1) => {
+    return String.fromCharCode(parseInt(p1, 16))
+  })
+}
+
+
+export function format(text: string) {
+  let res = text.trim()
+  if (res.startsWith('\n')) {
+    res = res.replace('\n', '')
+  }
+  return res.replaceAll('\n', '<br/>').replaceAll('```', '')
+}
+
+const handleStream = (response: any, onData: IOnData, onCompleted?: IOnCompleted) => {
+  if (!response.ok)
+    throw new Error('Network response was not ok')
+
+  const reader = response.body.getReader()
+  const decoder = new TextDecoder('utf-8')
+  let buffer = ''
+  let bufferObj: any
+  let isFirstMessage = true
+  function read() {
+    let hasError = false
+    reader.read().then((result: any) => {
+      if (result.done) {
+        onCompleted && onCompleted()
+        return
+      }
+      buffer += decoder.decode(result.value, { stream: true })
+      const lines = buffer.split('\n')
+      try {
+        lines.forEach((message) => {
+          if (message.startsWith('data: ')) { // check if it starts with data:
+            // console.log(message);
+            bufferObj = JSON.parse(message.substring(6)) // remove data: and parse as json
+            if (bufferObj.status === 400) {
+              onData('', false, {
+                conversationId: undefined,
+                messageId: '',
+                errorMessage: bufferObj.message
+              })
+              hasError = true
+              onCompleted && onCompleted(true)
+              return
+            }
+            // can not use format here, because the message arrives split into chunks.
+            onData(unicodeToChar(bufferObj.answer), isFirstMessage, {
+              conversationId: bufferObj.conversation_id,
+              messageId: bufferObj.id,
+            })
+            isFirstMessage = false
+          }
+        })
+        buffer = lines[lines.length - 1]
+      } catch (e) {
+        onData('', false, {
+          conversationId: undefined,
+          messageId: '',
+          errorMessage: e + ''
+        })
+        hasError = true
+        onCompleted && onCompleted(true)
+        return
+      }
+      if (!hasError) {
+        read()
+      }
+    })
+  }
+  read()
+}
+
+const baseFetch = (url: string, fetchOptions: any, { isPublicAPI = false, isMock = false, needAllResponseContent }: IOtherOptions) => {
+  const options = Object.assign({}, baseOptions, fetchOptions)
+  if (isPublicAPI) {
+    const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
+    options.headers.set('Authorization', `bearer ${sharedToken}`)
+  }
+
+  let urlPrefix = isPublicAPI ? PUBLIC_API_PREFIX : API_PREFIX
+  if (isMock)
+    urlPrefix = MOCK_API_PREFIX
+
+  let urlWithPrefix = `${urlPrefix}${url.startsWith('/') ?
url : `/${url}`}` + + const { method, params, body } = options + // handle query + if (method === 'GET' && params) { + const paramsArray: string[] = [] + Object.keys(params).forEach(key => + paramsArray.push(`${key}=${encodeURIComponent(params[key])}`), + ) + if (urlWithPrefix.search(/\?/) === -1) + urlWithPrefix += `?${paramsArray.join('&')}` + + else + urlWithPrefix += `&${paramsArray.join('&')}` + + delete options.params + } + + if (body) + options.body = JSON.stringify(body) + + // Handle timeout + return Promise.race([ + new Promise((resolve, reject) => { + setTimeout(() => { + reject(new Error('request timeout')) + }, TIME_OUT) + }), + new Promise((resolve, reject) => { + globalThis.fetch(urlWithPrefix, options) + .then((res: any) => { + const resClone = res.clone() + // Error handler + if (!/^(2|3)\d{2}$/.test(res.status)) { + const bodyJson = res.json() + switch (res.status) { + case 401: { + if (isPublicAPI) { + Toast.notify({ type: 'error', message: 'Invalid token' }) + return + } + const loginUrl = `${globalThis.location.origin}/signin` + if (IS_CE_EDITION) { + bodyJson.then((data: any) => { + if (data.code === 'not_setup') { + globalThis.location.href = `${globalThis.location.origin}/install` + } else { + if (location.pathname === '/signin') { + bodyJson.then((data: any) => { + Toast.notify({ type: 'error', message: data.message }) + }) + } else { + globalThis.location.href = loginUrl + } + } + }) + return Promise.reject() + } + globalThis.location.href = loginUrl + break + } + case 403: + new Promise(() => { + bodyJson.then((data: any) => { + Toast.notify({ type: 'error', message: data.message }) + if (data.code === 'already_setup') { + globalThis.location.href = `${globalThis.location.origin}/signin` + } + }) + }) + break + // fall through + default: + // eslint-disable-next-line no-new + new Promise(() => { + bodyJson.then((data: any) => { + Toast.notify({ type: 'error', message: data.message }) + }) + }) + } + return Promise.reject(resClone) + } + + // handle delete api. Delete api not return content. + if (res.status === 204) { + resolve({ result: "success" }) + return + } + + // return data + const data = options.headers.get('Content-type') === ContentType.download ? res.blob() : res.json() + + resolve(needAllResponseContent ? 
resClone : data) + }) + .catch((err) => { + Toast.notify({ type: 'error', message: err }) + reject(err) + }) + }), + ]) +} + +export const upload = (options: any): Promise => { + const defaultOptions = { + method: 'POST', + url: `${API_PREFIX}/files/upload`, + headers: {}, + data: {}, + } + options = { + ...defaultOptions, + ...options, + headers: { ...defaultOptions.headers, ...options.headers }, + }; + return new Promise(function (resolve, reject) { + const xhr = options.xhr + xhr.open(options.method, options.url); + for (const key in options.headers) { + xhr.setRequestHeader(key, options.headers[key]); + } + xhr.withCredentials = true + xhr.responseType = 'json' + xhr.onreadystatechange = function () { + if (xhr.readyState === 4) { + if (xhr.status === 201) { + resolve(xhr.response) + } else { + reject(xhr) + } + } + } + xhr.upload.onprogress = options.onprogress + xhr.send(options.data) + }) +} + +export const ssePost = (url: string, fetchOptions: any, { isPublicAPI = false, onData, onCompleted, onError, getAbortController }: IOtherOptions) => { + const abortController = new AbortController() + + const options = Object.assign({}, baseOptions, { + method: 'POST', + signal: abortController.signal, + }, fetchOptions) + + getAbortController?.(abortController) + + const urlPrefix = isPublicAPI ? PUBLIC_API_PREFIX : API_PREFIX + const urlWithPrefix = `${urlPrefix}${url.startsWith('/') ? url : `/${url}`}` + + const { body } = options + if (body) + options.body = JSON.stringify(body) + + globalThis.fetch(urlWithPrefix, options) + .then((res: any) => { + // debugger + if (!/^(2|3)\d{2}$/.test(res.status)) { + // eslint-disable-next-line no-new + new Promise(() => { + res.json().then((data: any) => { + Toast.notify({ type: 'error', message: data.message || 'Server Error' }) + }) + }) + onError?.('Server Error') + return + } + return handleStream(res, (str: string, isFirstMessage: boolean, moreInfo: IOnDataMoreInfo) => { + if (moreInfo.errorMessage) { + Toast.notify({ type: 'error', message: moreInfo.errorMessage }) + return + } + onData?.(str, isFirstMessage, moreInfo) + }, onCompleted) + }).catch((e) => { + // debugger + Toast.notify({ type: 'error', message: e }) + onError?.(e) + }) +} + +export const request = (url: string, options = {}, otherOptions?: IOtherOptions) => { + return baseFetch(url, options, otherOptions || {}) +} + +export const get = (url: string, options = {}, otherOptions?: IOtherOptions) => { + return request(url, Object.assign({}, options, { method: 'GET' }), otherOptions) +} + +// For public API +export const getPublic = (url: string, options = {}, otherOptions?: IOtherOptions) => { + return get(url, options, { ...otherOptions, isPublicAPI: true }) +} + +export const post = (url: string, options = {}, otherOptions?: IOtherOptions) => { + return request(url, Object.assign({}, options, { method: 'POST' }), otherOptions) +} + +export const postPublic = (url: string, options = {}, otherOptions?: IOtherOptions) => { + return post(url, options, { ...otherOptions, isPublicAPI: true }) +} + +export const put = (url: string, options = {}, otherOptions?: IOtherOptions) => { + return request(url, Object.assign({}, options, { method: 'PUT' }), otherOptions) +} + +export const putPublic = (url: string, options = {}, otherOptions?: IOtherOptions) => { + return put(url, options, { ...otherOptions, isPublicAPI: true }) +} + +export const del = (url: string, options = {}, otherOptions?: IOtherOptions) => { + return request(url, Object.assign({}, options, { method: 'DELETE' }), 
otherOptions) +} + +export const delPublic = (url: string, options = {}, otherOptions?: IOtherOptions) => { + return del(url, options, { ...otherOptions, isPublicAPI: true }) +} + +export const patch = (url: string, options = {}, otherOptions?: IOtherOptions) => { + return request(url, Object.assign({}, options, { method: 'PATCH' }), otherOptions) +} + +export const patchPublic = (url: string, options = {}, otherOptions?: IOtherOptions) => { + return patch(url, options, { ...otherOptions, isPublicAPI: true }) +} diff --git a/web/service/common.ts b/web/service/common.ts new file mode 100644 index 0000000000..4f28cc84d3 --- /dev/null +++ b/web/service/common.ts @@ -0,0 +1,91 @@ +import type { Fetcher } from 'swr' +import { get, post, del, put } from './base' +import type { + CommonResponse, LangGeniusVersionResponse, OauthResponse, + TenantInfoResponse, UserProfileOriginResponse, Member, + AccountIntegrate, Provider, ProviderAzureToken, IWorkspace +} from '@/models/common' +import type { + ValidateOpenAIKeyResponse, + UpdateOpenAIKeyResponse +} from '@/models/app' + +export const login: Fetcher }> = ({ url, body }) => { + return post(url, { body }) as Promise +} + +export const setup: Fetcher }> = ({ body }) => { + return post('/setup', { body }) as Promise +} + +export const fetchUserProfile: Fetcher }> = ({ url, params }) => { + return get(url, params, { needAllResponseContent: true }) as Promise +} + +export const updateUserProfile: Fetcher }> = ({ url, body }) => { + return post(url, { body }) as Promise +} + +export const fetchTenantInfo: Fetcher = ({ url }) => { + return get(url) as Promise +} + +export const logout: Fetcher }> = ({ url, params }) => { + return get(url, params) as Promise +} + +export const fetchLanggeniusVersion: Fetcher }> = ({ url, params }) => { + return get(url, { params }) as Promise +} + +export const oauth: Fetcher }> = ({ url, params }) => { + return get(url, { params }) as Promise +} + +export const oneMoreStep: Fetcher }> = ({ url, body }) => { + return post(url, { body }) as Promise +} + +export const fetchMembers: Fetcher<{ accounts: Member[] | null }, { url: string; params: Record }> = ({ url, params }) => { + return get(url, { params }) as Promise<{ accounts: Member[] | null }> +} + +export const fetchProviders: Fetcher }> = ({ url, params }) => { + return get(url, { params }) as Promise +} + +export const validateProviderKey: Fetcher = ({ url, body }) => { + return post(url, { body }) as Promise +} +export const updateProviderAIKey: Fetcher = ({ url, body }) => { + return post(url, { body }) as Promise +} + +export const fetchAccountIntegrates: Fetcher<{ data: AccountIntegrate[] | null }, { url: string; params: Record }> = ({ url, params }) => { + return get(url, { params }) as Promise<{ data: AccountIntegrate[] | null }> +} + +export const inviteMember: Fetcher }> = ({ url, body }) => { + return post(url, { body }) as Promise +} + +export const updateMemberRole: Fetcher }> = ({ url, body }) => { + return put(url, { body }) as Promise +} + +export const deleteMemberOrCancelInvitation: Fetcher = ({ url }) => { + return del(url) as Promise +} + +export const fetchFilePreview: Fetcher<{ content: string }, { fileID: string }> = ({ fileID }) => { + return get(`/files/${fileID}/preview`) as Promise<{ content: string }> +} + +export const fetchWorkspaces: Fetcher<{ workspaces: IWorkspace[] }, { url: string; params: Record }> = ({ url, params }) => { + return get(url, { params }) as Promise<{ workspaces: IWorkspace[] }> +} + +export const switchWorkspace: 
Fetcher }> = ({ url, body }) => { + return post(url, { body }) as Promise +} + diff --git a/web/service/datasets.ts b/web/service/datasets.ts new file mode 100644 index 0000000000..b633dbf886 --- /dev/null +++ b/web/service/datasets.ts @@ -0,0 +1,127 @@ +import type { Fetcher } from 'swr' +import { del, get, post, put, patch } from './base' +import qs from 'qs' +import type { RelatedAppResponse, DataSet, HitTestingResponse, HitTestingRecordsResponse, DataSetListResponse, CreateDocumentReq, InitialDocumentDetail, DocumentDetailResponse, DocumentListResponse, IndexingEstimateResponse, FileIndexingEstimateResponse, IndexingStatusResponse, ProcessRuleResponse, SegmentsQuery, SegmentsResponse, createDocumentResponse } from '@/models/datasets' +import type { CommonResponse } from '@/models/common' + +// apis for documents in a dataset + +type CommonDocReq = { + datasetId: string + documentId: string +} + +export type SortType = 'created_at' | 'hit_count' | '-created_at' | '-hit_count' + +export type MetadataType = 'all' | 'only' | 'without' + +export const fetchDataDetail: Fetcher = (datasetId: string) => { + return get(`/datasets/${datasetId}`) as Promise +} + +export const updateDatasetSetting: Fetcher>}> = ({ datasetId, body }) => { + return patch(`/datasets/${datasetId}`, { body } ) as Promise +} + +export const fetchDatasetRelatedApps: Fetcher = (datasetId: string) => { + return get(`/datasets/${datasetId}/related-apps`) as Promise +} + +export const fetchDatasets: Fetcher = ({ url, params }) => { + const urlParams = qs.stringify(params, { indices: false }) + return get(`${url}?${urlParams}`,) as Promise +} + +export const createEmptyDataset: Fetcher = ({ name }) => { + return post('/datasets', { body: { name } }) as Promise +} + +export const deleteDataset: Fetcher = (datasetID) => { + return del(`/datasets/${datasetID}`) as Promise +} + +export const fetchDefaultProcessRule: Fetcher = ({ url }) => { + return get(url) as Promise +} +export const fetchProcessRule: Fetcher = ({ params: { documentId } }) => { + return get('/datasets/process-rule', { params: { document_id: documentId } }) as Promise +} + +export const fetchDocuments: Fetcher = ({ datasetId, params }) => { + return get(`/datasets/${datasetId}/documents`, { params }) as Promise +} + +export const createFirstDocument: Fetcher = ({ body }) => { + return post(`/datasets/init`, { body }) as Promise +} + +export const createDocument: Fetcher = ({ datasetId, body }) => { + return post(`/datasets/${datasetId}/documents`, { body }) as Promise +} + +export const fetchIndexingEstimate: Fetcher = ({ datasetId, documentId }) => { + return get(`/datasets/${datasetId}/documents/${documentId}/indexing-estimate`, {}) as Promise +} + +export const fetchIndexingStatus: Fetcher = ({ datasetId, documentId }) => { + return get(`/datasets/${datasetId}/documents/${documentId}/indexing-status`, {}) as Promise +} + +export const fetchDocumentDetail: Fetcher = ({ datasetId, documentId, params }) => { + return get(`/datasets/${datasetId}/documents/${documentId}`, { params }) as Promise +} + +export const pauseDocIndexing: Fetcher = ({ datasetId, documentId }) => { + return patch(`/datasets/${datasetId}/documents/${documentId}/processing/pause`) as Promise +} + +export const resumeDocIndexing: Fetcher = ({ datasetId, documentId }) => { + return patch(`/datasets/${datasetId}/documents/${documentId}/processing/resume`) as Promise +} + +export const deleteDocument: Fetcher = ({ datasetId, documentId }) => { + return 
del(`/datasets/${datasetId}/documents/${documentId}`) as Promise +} + +export const archiveDocument: Fetcher = ({ datasetId, documentId }) => { + return patch(`/datasets/${datasetId}/documents/${documentId}/status/archive`) as Promise +} + +export const enableDocument: Fetcher = ({ datasetId, documentId }) => { + return patch(`/datasets/${datasetId}/documents/${documentId}/status/enable`) as Promise +} + +export const disableDocument: Fetcher = ({ datasetId, documentId }) => { + return patch(`/datasets/${datasetId}/documents/${documentId}/status/disable`) as Promise +} + +export const modifyDocMetadata: Fetcher } }> = ({ datasetId, documentId, body }) => { + return put(`/datasets/${datasetId}/documents/${documentId}/metadata`, { body }) as Promise +} + +// apis for segments in a document + +export const fetchSegments: Fetcher = ({ datasetId, documentId, params }) => { + return get(`/datasets/${datasetId}/documents/${documentId}/segments`, { params }) as Promise +} + +export const enableSegment: Fetcher = ({ datasetId, segmentId }) => { + return patch(`/datasets/${datasetId}/segments/${segmentId}/enable`) as Promise +} + +export const disableSegment: Fetcher = ({ datasetId, segmentId }) => { + return patch(`/datasets/${datasetId}/segments/${segmentId}/disable`) as Promise +} + +// hit testing +export const hitTesting: Fetcher = ({ datasetId, queryText }) => { + return post(`/datasets/${datasetId}/hit-testing`, { body: { query: queryText } }) as Promise +} + +export const fetchTestingRecords: Fetcher = ({ datasetId, params }) => { + return get(`/datasets/${datasetId}/queries`, { params }) as Promise +} + +export const fetchFileIndexingEstimate: Fetcher = (body: any) => { + return post(`/datasets/file-indexing-estimate`, { body }) as Promise +} diff --git a/web/service/debug.ts b/web/service/debug.ts new file mode 100644 index 0000000000..b6b922f0e0 --- /dev/null +++ b/web/service/debug.ts @@ -0,0 +1,40 @@ +import { ssePost, get, IOnData, IOnCompleted, IOnError } from './base' + +export const sendChatMessage = async (appId: string, body: Record, { onData, onCompleted, onError, getAbortController }: { + onData: IOnData + onCompleted: IOnCompleted + onError: IOnError, + getAbortController?: (abortController: AbortController) => void +}) => { + return ssePost(`apps/${appId}/chat-messages`, { + body: { + ...body, + response_mode: 'streaming' + } + }, { onData, onCompleted, onError, getAbortController }) +} + +export const sendCompletionMessage = async (appId: string, body: Record, { onData, onCompleted, onError }: { + onData: IOnData + onCompleted: IOnCompleted + onError: IOnError +}) => { + return ssePost(`apps/${appId}/completion-messages`, { + body: { + ...body, + response_mode: 'streaming' + } + }, { onData, onCompleted, onError }) +} + +export const fetchSuggestedQuestions = (appId: string, messageId: string) => { + return get(`apps/${appId}/chat-messages/${messageId}/suggested-questions`) +} + +export const fetchConvesationMessages = (appId: string, conversation_id: string) => { + return get(`apps/${appId}/chat-messages`, { + params: { + conversation_id + } + }) +} diff --git a/web/service/demo/index.tsx b/web/service/demo/index.tsx new file mode 100644 index 0000000000..d95c351f19 --- /dev/null +++ b/web/service/demo/index.tsx @@ -0,0 +1,106 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import useSWR, { useSWRConfig } from 'swr' +import { createApp, fetchAppDetail, fetchAppList, getAppDailyConversations, getAppDailyEndUsers, updateAppApiStatus, 
updateAppModelConfig, updateAppName, updateAppRateLimit, updateAppSiteAccessToken, updateAppSiteConfig, updateAppSiteStatus } from '../apps' +import Loading from '@/app/components/base/loading' +const Service: FC = () => { + const { data: appList, error: appListError } = useSWR({ url: '/apps', params: { page: 1 } }, fetchAppList) + const { data: firstApp, error: appDetailError } = useSWR({ url: '/apps', id: '1' }, fetchAppDetail) + const { data: appName, error: appNameError } = useSWR({ url: '/apps', id: '1', body: { name: 'new name' } }, updateAppName) + const { data: updateAppSiteStatusRes, error: err1 } = useSWR({ url: '/apps', id: '1', body: { enable_site: false } }, updateAppSiteStatus) + const { data: updateAppApiStatusRes, error: err2 } = useSWR({ url: '/apps', id: '1', body: { enable_api: true } }, updateAppApiStatus) + const { data: updateAppRateLimitRes, error: err3 } = useSWR({ url: '/apps', id: '1', body: { api_rpm: 10, api_rph: 20 } }, updateAppRateLimit) + const { data: updateAppSiteCodeRes, error: err4 } = useSWR({ url: '/apps', id: '1', body: {} }, updateAppSiteAccessToken) + const { data: updateAppSiteConfigRes, error: err5 } = useSWR({ url: '/apps', id: '1', body: { title: 'title test', author: 'author test' } }, updateAppSiteConfig) + const { data: getAppDailyConversationsRes, error: err6 } = useSWR({ url: '/apps', id: '1', body: { start: '1', end: '2' } }, getAppDailyConversations) + const { data: getAppDailyEndUsersRes, error: err7 } = useSWR({ url: '/apps', id: '1', body: { start: '1', end: '2' } }, getAppDailyEndUsers) + const { data: updateAppModelConfigRes, error: err8 } = useSWR({ url: '/apps', id: '1', body: { model_id: 'gpt-100' } }, updateAppModelConfig) + + const { mutate } = useSWRConfig() + + const handleCreateApp = async () => { + await createApp({ + name: `new app${Math.round(Math.random() * 100)}`, + mode: 'chat', + }) + // reload app list + mutate({ url: '/apps', params: { page: 1 } }) + } + + if (appListError || appDetailError || appNameError || err1 || err2 || err3 || err4 || err5 || err6 || err7 || err8) + return
<div>{JSON.stringify(appNameError)}</div>
+
+  if (!appList || !firstApp || !appName || !updateAppSiteStatusRes || !updateAppApiStatusRes || !updateAppRateLimitRes || !updateAppSiteCodeRes || !updateAppSiteConfigRes || !getAppDailyConversationsRes || !getAppDailyEndUsersRes || !updateAppModelConfigRes)
+    return <Loading />
+
+  return (
+    <div>
+      <div>
+        <div>1.App list</div>
+        <div>
+          {appList.data.map(item => (
+            <div key={item.id}>{item.id} {item.name}</div>
+          ))}
+        </div>
+      </div>
+
+      <div>
+        <div>2.First app detail</div>
+        <div>{JSON.stringify(firstApp)}</div>
+      </div>
+
+      <div onClick={handleCreateApp} />
+
+      <div>
+        <div>3.updateAppName</div>
+        <div>{JSON.stringify(appName)}</div>
+      </div>
+
+      <div>
+        <div>4.updateAppSiteStatusRes</div>
+        <div>{JSON.stringify(updateAppSiteStatusRes)}</div>
+      </div>
+
+      <div>
+        <div>5.updateAppApiStatusRes</div>
+        <div>{JSON.stringify(updateAppApiStatusRes)}</div>
+      </div>
+
+      <div>
+        <div>6.updateAppRateLimitRes</div>
+        <div>{JSON.stringify(updateAppRateLimitRes)}</div>
+      </div>
+
+      <div>
+        <div>7.updateAppSiteCodeRes</div>
+        <div>{JSON.stringify(updateAppSiteCodeRes)}</div>
+      </div>
+
+      <div>
+        <div>8.updateAppSiteConfigRes</div>
+        <div>{JSON.stringify(updateAppSiteConfigRes)}</div>
+      </div>
+
+      <div>
+        <div>9.getAppDailyConversationsRes</div>
+        <div>{JSON.stringify(getAppDailyConversationsRes)}</div>
+      </div>
+
+      <div>
+        <div>10.getAppDailyEndUsersRes</div>
+        <div>{JSON.stringify(getAppDailyEndUsersRes)}</div>
+      </div>
+
+      <div>
+        <div>11.updateAppModelConfigRes</div>
+        <div>{JSON.stringify(updateAppModelConfigRes)}</div>
+      </div>
+    </div>
+ ) +} +export default React.memo(Service) diff --git a/web/service/log.ts b/web/service/log.ts new file mode 100644 index 0000000000..76ebd3adcd --- /dev/null +++ b/web/service/log.ts @@ -0,0 +1,59 @@ +import type { Fetcher } from 'swr' +import { get, post } from './base' +import type { + AnnotationsCountResponse, + ChatConversationFullDetailResponse, + ChatConversationsRequest, + ChatConversationsResponse, + ChatMessagesRequest, + ChatMessagesResponse, + CompletionConversationFullDetailResponse, + CompletionConversationsRequest, + CompletionConversationsResponse, + ConversationListResponse, + LogMessageAnnotationsRequest, + LogMessageAnnotationsResponse, + LogMessageFeedbacksRequest, + LogMessageFeedbacksResponse, +} from '@/models/log' + +export const fetchConversationList: Fetcher }> = ({ appId, params }) => { + return get(`/console/api/apps/${appId}/messages`, params) as Promise +} + +// (Text Generation Application) Session List +export const fetchCompletionConversations: Fetcher = ({ url, params }) => { + return get(url, { params }) as Promise +} + +// (Text Generation Application) Session Detail +export const fetchCompletionConversationDetail: Fetcher = ({ url }) => { + return get(url, {}) as Promise +} + +// (Chat Application) Session List +export const fetchChatConversations: Fetcher = ({ url, params }) => { + return get(url, { params }) as Promise +} + +// (Chat Application) Session Detail +export const fetchChatConversationDetail: Fetcher = ({ url }) => { + return get(url, {}) as Promise +} + +// (Chat Application) Message list in one session +export const fetchChatMessages: Fetcher = ({ url, params }) => { + return get(url, { params }) as Promise +} + +export const updateLogMessageFeedbacks: Fetcher = ({ url, body }) => { + return post(url, { body }) as Promise +} + +export const updateLogMessageAnnotations: Fetcher = ({ url, body }) => { + return post(url, { body }) as Promise +} + +export const fetchAnnotationsCount: Fetcher = ({ url }) => { + return get(url) as Promise +} diff --git a/web/service/share.ts b/web/service/share.ts new file mode 100644 index 0000000000..7623be5f80 --- /dev/null +++ b/web/service/share.ts @@ -0,0 +1,81 @@ +import type { IOnCompleted, IOnData, IOnError } from './base' +import { getPublic as get, postPublic as post, ssePost, delPublic as del } from './base' +import type { Feedbacktype } from '@/app/components/app/chat' + +export const sendChatMessage = async (body: Record, { onData, onCompleted, onError, getAbortController }: { + onData: IOnData + onCompleted: IOnCompleted + onError: IOnError, + getAbortController?: (abortController: AbortController) => void +}) => { + return ssePost('chat-messages', { + body: { + ...body, + response_mode: 'streaming', + }, + }, { onData, onCompleted, isPublicAPI: true, onError, getAbortController }) +} + +export const sendCompletionMessage = async (body: Record, { onData, onCompleted, onError }: { + onData: IOnData + onCompleted: IOnCompleted + onError: IOnError +}) => { + return ssePost('completion-messages', { + body: { + ...body, + response_mode: 'streaming', + }, + }, { onData, onCompleted, isPublicAPI: true, onError }) +} + +export const fetchAppInfo = async () => { + return get('/site') +} + +export const fetchConversations = async () => { + return get('conversations', { params: { limit: 20, first_id: '' } }) +} + +export const fetchChatList = async (conversationId: string) => { + return get('messages', { params: { conversation_id: conversationId, limit: 20, last_id: '' } }) +} + +// Abandoned API interface 
+// export const fetchAppVariables = async () => { +// return get(`variables`) +// } + +// init value. wait for server update +export const fetchAppParams = async () => { + return get('parameters') +} + +export const updateFeedback = async ({ url, body }: { url: string; body: Feedbacktype }) => { + return post(url, { body }) +} + +export const fetcMoreLikeThis = async (messageId: string) => { + return get(`/messages/${messageId}/more-like-this`, { + params: { + response_mode: 'blocking', + } + }) +} + +export const saveMessage = (messageId: string) => { + return post('/saved-messages', { body: { message_id: messageId } }) +} + +export const fetchSavedMessage = async () => { + return get(`/saved-messages`) +} + + +export const removeMessage = (messageId: string) => { + return del(`/saved-messages/${messageId}`) +} + +export const fetchSuggestedQuestions = (messageId: string) => { + return get(`/messages/${messageId}/suggested-questions`) +} diff --git a/web/tailwind.config.js b/web/tailwind.config.js new file mode 100644 index 0000000000..72bb550360 --- /dev/null +++ b/web/tailwind.config.js @@ -0,0 +1,70 @@ +/** @type {import('tailwindcss').Config} */ +module.exports = { + content: [ + './app/**/*.{js,ts,jsx,tsx}', + './components/**/*.{js,ts,jsx,tsx}', + ], + theme: { + typography: require('./typography'), + extend: { + colors: { + gray: { + 25: '#FCFCFD', + 50: '#F9FAFB', + 100: '#F3F4F6', + 200: '#E5E7EB', + 300: '#D1D5DB', + 400: '#9CA3AF', + 500: '#6B7280', + 700: '#374151', + 800: '#1F2A37', + 900: '#111928', + }, + primary: { + 25: '#F5F8FF', + 50: '#EBF5FF', + 100: '#E1EFFE', + 200: '#C3DDFD', + 300: '#A4CAFE', + 400: '#528BFF', + 600: '#1C64F2', + 700: '#1A56DB', + }, + blue: { + 500: '#E1EFFE', + }, + green: { + 50: '#F3FAF7', + 100: '#DEF7EC', + 800: '#03543F', + + }, + yellow: { + 100: '#FDF6B2', + 800: '#723B13', + }, + purple: { + 50: '#F6F5FF', + 200: '#DCD7FE', + }, + indigo: { + 25: '#F5F8FF', + 100: '#E0EAFF', + 600: '#444CE7' + } + }, + screens: { + 'mobile': '100px', + // => @media (min-width: 100px) { ... } + 'tablet': '640px', // 391 + // => @media (min-width: 600px) { ... } + 'pc': '769px', + // => @media (min-width: 769px) { ... 
} + }, + }, + }, + plugins: [ + require('@tailwindcss/typography'), + require('@tailwindcss/line-clamp'), + ], +} diff --git a/web/test/factories/index.ts b/web/test/factories/index.ts new file mode 100644 index 0000000000..810e05eb49 --- /dev/null +++ b/web/test/factories/index.ts @@ -0,0 +1,66 @@ +import { Factory } from 'miragejs' +import { faker } from '@faker-js/faker' + +import type { History } from '@/models/history' +import type { User } from '@/models/user' +import type { Log } from '@/models/log' + +export const seedHistory = () => { + return Factory.extend>({ + source() { + return faker.address.streetAddress() + }, + target() { + return faker.address.streetAddress() + }, + }) +} + +export const seedUser = () => { + return Factory.extend>({ + firstName() { + return faker.name.firstName() + }, + lastName() { + return faker.name.lastName() + }, + name() { + return faker.address.streetAddress() + }, + phone() { + return faker.phone.number() + }, + email() { + return faker.internet.email() + }, + username() { + return faker.internet.userName() + }, + avatar() { + return faker.internet.avatar() + }, + }) +} + +export const seedLog = () => { + return Factory.extend>({ + get key() { + return faker.datatype.uuid() + }, + get conversationId() { + return faker.datatype.uuid() + }, + get question() { + return faker.lorem.sentence() + }, + get answer() { + return faker.lorem.sentence() + }, + get userRate() { + return faker.datatype.number(5) + }, + get adminRate() { + return faker.datatype.number(5) + } + }) +} \ No newline at end of file diff --git a/web/test/test_util.ts b/web/test/test_util.ts new file mode 100644 index 0000000000..f8e293b0a7 --- /dev/null +++ b/web/test/test_util.ts @@ -0,0 +1,45 @@ +import { Model, createServer } from 'miragejs' +import type { User } from '@/models/user' +import type { History } from '@/models/history' +import type { Log } from '@/models/log' +import { seedUser, seedHistory, seedLog } from '@/test/factories' + + +export function mockAPI() { + if (process.env.NODE_ENV === 'development') { + console.log('in development mode, starting mock server ... ') + const server = createServer({ + environment: process.env.NODE_ENV, + factories: { + user: seedUser(), + history: seedHistory(), + log: seedLog(), + }, + models: { + user: Model.extend>({}), + history: Model.extend>({}), + log: Model.extend>({}), + }, + routes() { + this.namespace = '/api' + this.get('/users', () => { + return this.schema.all('user') + }) + this.get('/histories', () => { + return this.schema.all('history') + }) + this.get('/logs', () => { + return this.schema.all('log') + }) + }, + seeds(server) { + server.createList('user', 20) + server.createList('history', 50) + server.createList('log', 50) + }, + }) + return server + } + console.log('Not in development mode, not starting mock server ... 
diff --git a/web/tsconfig.json b/web/tsconfig.json new file mode 100644 index 0000000000..c3e0bca665 --- /dev/null +++ b/web/tsconfig.json @@ -0,0 +1,42 @@ +{ + "compilerOptions": { + "target": "es2015", + "lib": [ + "dom", + "dom.iterable", + "esnext" + ], + "allowJs": true, + "skipLibCheck": true, + "strict": true, + "forceConsistentCasingInFileNames": true, + "noEmit": true, + "esModuleInterop": true, + "module": "esnext", + "moduleResolution": "node", + "resolveJsonModule": true, + "isolatedModules": true, + "jsx": "preserve", + "incremental": true, + "plugins": [ + { + "name": "next" + } + ], + "paths": { + "@/*": [ + "./*" + ] + } + }, + "include": [ + "next-env.d.ts", + "**/*.ts", + "**/*.tsx", + ".next/types/**/*.ts", + "app/components/develop/Prose.jsx" + ], + "exclude": [ + "node_modules" + ] +} diff --git a/web/types/app.ts b/web/types/app.ts new file mode 100644 index 0000000000..33ff1d5c67 --- /dev/null +++ b/web/types/app.ts @@ -0,0 +1,227 @@ +export enum AppType { + 'chat' = 'chat', + 'completion' = 'completion', +} + +export type VariableInput = { + key: string + name: string + value: string +} + +/** + * App modes + */ +export const AppModes = ['completion', 'chat'] as const +export type AppMode = typeof AppModes[number] + +/** + * Variable type + */ +export const VariableTypes = ['string', 'number', 'select'] as const +export type VariableType = typeof VariableTypes[number] + +/** + * Prompt variable parameter + */ +export type PromptVariable = { + /** Variable key */ + key: string + /** Variable name */ + name: string + /** Type */ + type: VariableType + required: boolean + /** Enumeration of single-selection drop-down values */ + options?: string[] + max_length?: number +} + +export type TextTypeFormItem = { + label: string + variable: string + required: boolean + max_length: number +} + +export type SelectTypeFormItem = { + label: string + variable: string + required: boolean + options: string[] +} + +/** + * User Input Form Item + */ +export type UserInputFormItem = { + 'text-input': TextTypeFormItem +} | { + 'select': SelectTypeFormItem +}
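The `UserInputFormItem` union is keyed by shape rather than by a `type` field, so consumers discriminate it by probing for the `'text-input'` key, as `web/utils/model-config.ts` does later in this diff. A minimal sketch (not part of the original file):

```ts
// Narrow the union with an `in` check; TypeScript infers the matching branch.
const labelOf = (item: UserInputFormItem): string => {
  return 'text-input' in item
    ? item['text-input'].label
    : item.select.label
}
```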
+ +export type ToolItem = { + dataset: { + enabled: boolean + id: string + } +} | { + 'sensitive-word-avoidance': { + enabled: boolean + words: string[] + canned_response: string + } +} + +/** + * Model configuration. The backend type. + */ +export type ModelConfig = { + opening_statement: string + pre_prompt: string + user_input_form: UserInputFormItem[] + more_like_this: { + enabled: boolean + } + suggested_questions_after_answer: { + enabled: boolean + } + agent_mode: { + enabled: boolean + tools: ToolItem[] + } + model: { + /** LLM provider, e.g., OPENAI */ + provider: string + /** Model name, e.g., gpt-3.5-turbo */ + name: string + /** Default Completion call parameters */ + completion_params: { + /** Maximum number of tokens in the answer message returned by Completion */ + max_tokens: number + /** + * A number between 0 and 2. + * Higher values make the result more random; + * lower values make it more deterministic. + * Use either `temperature` or `top_p`, not both. + * Default is 1. + */ + temperature: number + /** + * Nucleus sampling: only the tokens comprising the top `top_p` probability mass are considered, + * e.g., 0.1 means only the top 10% probability mass is sampled. + * Use either `temperature` or `top_p`, not both. + * Default is 1. + */ + top_p: number + /** When enabled, the prompt content is echoed back, concatenated with the completion text. */ + echo: boolean + /** + * Up to 4 sequences; generation stops before any of the texts specified in `stop` would be emitted. + * Suitable for use in chat mode. + * For example, specify "Q" and "A", provide some Q&A examples as context, + * and the model will answer in Q&A format and stop before generating the next "Q". + */ + stop: string[] + /** + * A number between -2.0 and 2.0. + * Higher values make the model less likely to repeat topics it has already mentioned, + * nudging it toward new ones. + */ + presence_penalty: number + /** + * A number between -2.0 and 2.0. + * Higher values penalize tokens in proportion to how often they have already appeared + * in the generated text, reducing verbatim repetition. + * The difference between `frequency_penalty` and `presence_penalty` is that + * `frequency_penalty` scales with how many times a token has occurred so far, + * while `presence_penalty` applies as soon as a token has occurred at all. + */ + frequency_penalty: number + } + } +} + +export const LanguagesSupported = ['zh-Hans', 'en-US'] as const +export type Language = typeof LanguagesSupported[number] + +/** + * Web Application Configuration + */ +export type SiteConfig = { + /** Application URL Identifier: `http://dify.app/{access_token}` */ + access_token: string + /** Public Title */ + title: string + /** Application description, shown in the client */ + description: string + /** Author */ + author: string + /** User Support Email Address */ + support_email: string + /** + * Default Language, e.g. zh-Hans, en-US + * Use standard RFC 4646, see https://www.ruanyifeng.com/blog/2008/02/codes_for_language_names.html + */ + default_language: Language + /** Custom Domain */ + customize_domain: string + /** Theme */ + theme: string + /** Custom token strategy: whether end users can supply their own OpenAI key */ + customize_token_strategy: 'must' | 'allow' | 'not_allow' + /** Is Prompt Public */ + prompt_public: boolean + /** Web API and APP Base Domain Name */ + app_base_url: string + /** Copyright */ + copyright: string + /** Privacy Policy */ + privacy_policy: string +} + +/** + * App + */ +export type App = { + /** App ID */ + id: string + /** Name */ + name: string + /** Mode */ + mode: AppMode + /** Enable web app */ + enable_site: boolean + /** Enable web API */ + enable_api: boolean + /** API requests per minute, default is 60 */ + api_rpm: number + /** API requests per hour, default is 3600 */ + api_rph: number + /** Whether it's a demo app */ + is_demo: boolean + /** Model configuration */ + model_config: ModelConfig + /** Timestamp of creation */ + created_at: number + /** Web Application Configuration */ + site: SiteConfig + /** API base URL */ + api_base_url: string +} + +/** + * App Template + */ +export type AppTemplate = { + /** Name */ + name: string + /** Description */ + description: string + /** Mode */ + mode: AppMode + /** Model */ + model_config: ModelConfig +} diff --git a/web/typography.js b/web/typography.js new file mode 100644 index 0000000000..706e456ddd --- /dev/null +++ b/web/typography.js @@ -0,0 +1,357 @@ +module.exports = ({ theme }) => ({ + DEFAULT: { + css: { + '--tw-prose-body': theme('colors.zinc.700'), + '--tw-prose-headings': theme('colors.zinc.900'), + '--tw-prose-links': theme('colors.emerald.500'), + '--tw-prose-links-hover': theme('colors.emerald.600'), + '--tw-prose-links-underline': theme('colors.emerald.500 / 0.3'), + '--tw-prose-bold': theme('colors.zinc.900'), +
'--tw-prose-counters': theme('colors.zinc.500'), + '--tw-prose-bullets': theme('colors.zinc.300'), + '--tw-prose-hr': theme('colors.zinc.900 / 0.05'), + '--tw-prose-quotes': theme('colors.zinc.900'), + '--tw-prose-quote-borders': theme('colors.zinc.200'), + '--tw-prose-captions': theme('colors.zinc.500'), + '--tw-prose-code': theme('colors.zinc.900'), + '--tw-prose-code-bg': theme('colors.zinc.100'), + '--tw-prose-code-ring': theme('colors.zinc.300'), + '--tw-prose-th-borders': theme('colors.zinc.300'), + '--tw-prose-td-borders': theme('colors.zinc.200'), + + '--tw-prose-invert-body': theme('colors.zinc.400'), + '--tw-prose-invert-headings': theme('colors.white'), + '--tw-prose-invert-links': theme('colors.emerald.400'), + '--tw-prose-invert-links-hover': theme('colors.emerald.500'), + '--tw-prose-invert-links-underline': theme('colors.emerald.500 / 0.3'), + '--tw-prose-invert-bold': theme('colors.white'), + '--tw-prose-invert-counters': theme('colors.zinc.400'), + '--tw-prose-invert-bullets': theme('colors.zinc.600'), + '--tw-prose-invert-hr': theme('colors.white / 0.05'), + '--tw-prose-invert-quotes': theme('colors.zinc.100'), + '--tw-prose-invert-quote-borders': theme('colors.zinc.700'), + '--tw-prose-invert-captions': theme('colors.zinc.400'), + '--tw-prose-invert-code': theme('colors.white'), + '--tw-prose-invert-code-bg': theme('colors.zinc.700 / 0.15'), + '--tw-prose-invert-code-ring': theme('colors.white / 0.1'), + '--tw-prose-invert-th-borders': theme('colors.zinc.600'), + '--tw-prose-invert-td-borders': theme('colors.zinc.700'), + + // Base + color: 'var(--tw-prose-body)', + fontSize: theme('fontSize.sm')[0], + lineHeight: theme('lineHeight.7'), + + // Layout + '> *': { + maxWidth: theme('maxWidth.2xl'), + marginLeft: 'auto', + marginRight: 'auto', + '@screen lg': { + maxWidth: theme('maxWidth.3xl'), + marginLeft: `calc(50% - min(50%, ${theme('maxWidth.lg')}))`, + marginRight: `calc(50% - min(50%, ${theme('maxWidth.lg')}))`, + }, + }, + + // Text + p: { + marginTop: theme('spacing.6'), + marginBottom: theme('spacing.6'), + }, + '[class~="lead"]': { + fontSize: theme('fontSize.base')[0], + ...theme('fontSize.base')[1], + }, + + // Lists + ol: { + listStyleType: 'decimal', + marginTop: theme('spacing.5'), + marginBottom: theme('spacing.5'), + paddingLeft: '1.625rem', + }, + 'ol[type="A"]': { + listStyleType: 'upper-alpha', + }, + 'ol[type="a"]': { + listStyleType: 'lower-alpha', + }, + 'ol[type="A" s]': { + listStyleType: 'upper-alpha', + }, + 'ol[type="a" s]': { + listStyleType: 'lower-alpha', + }, + 'ol[type="I"]': { + listStyleType: 'upper-roman', + }, + 'ol[type="i"]': { + listStyleType: 'lower-roman', + }, + 'ol[type="I" s]': { + listStyleType: 'upper-roman', + }, + 'ol[type="i" s]': { + listStyleType: 'lower-roman', + }, + 'ol[type="1"]': { + listStyleType: 'decimal', + }, + ul: { + listStyleType: 'disc', + marginTop: theme('spacing.5'), + marginBottom: theme('spacing.5'), + paddingLeft: '1.625rem', + }, + li: { + marginTop: theme('spacing.2'), + marginBottom: theme('spacing.2'), + }, + ':is(ol, ul) > li': { + paddingLeft: theme('spacing[1.5]'), + }, + 'ol > li::marker': { + fontWeight: '400', + color: 'var(--tw-prose-counters)', + }, + 'ul > li::marker': { + color: 'var(--tw-prose-bullets)', + }, + '> ul > li p': { + marginTop: theme('spacing.3'), + marginBottom: theme('spacing.3'), + }, + '> ul > li > *:first-child': { + marginTop: theme('spacing.5'), + }, + '> ul > li > *:last-child': { + marginBottom: theme('spacing.5'), + }, + '> ol > li > *:first-child': { + marginTop: 
theme('spacing.5'), + }, + '> ol > li > *:last-child': { + marginBottom: theme('spacing.5'), + }, + 'ul ul, ul ol, ol ul, ol ol': { + marginTop: theme('spacing.3'), + marginBottom: theme('spacing.3'), + }, + + // Horizontal rules + hr: { + borderColor: 'var(--tw-prose-hr)', + borderTopWidth: 1, + marginTop: theme('spacing.16'), + marginBottom: theme('spacing.16'), + maxWidth: 'none', + marginLeft: `calc(-1 * ${theme('spacing.4')})`, + marginRight: `calc(-1 * ${theme('spacing.4')})`, + '@screen sm': { + marginLeft: `calc(-1 * ${theme('spacing.6')})`, + marginRight: `calc(-1 * ${theme('spacing.6')})`, + }, + '@screen lg': { + marginLeft: `calc(-1 * ${theme('spacing.8')})`, + marginRight: `calc(-1 * ${theme('spacing.8')})`, + }, + }, + + // Quotes + blockquote: { + fontWeight: '500', + fontStyle: 'italic', + color: 'var(--tw-prose-quotes)', + borderLeftWidth: '0.25rem', + borderLeftColor: 'var(--tw-prose-quote-borders)', + quotes: '"\\201C""\\201D""\\2018""\\2019"', + marginTop: theme('spacing.8'), + marginBottom: theme('spacing.8'), + paddingLeft: theme('spacing.5'), + }, + 'blockquote p:first-of-type::before': { + content: 'open-quote', + }, + 'blockquote p:last-of-type::after': { + content: 'close-quote', + }, + + // Headings + h1: { + color: 'var(--tw-prose-headings)', + fontWeight: '700', + fontSize: theme('fontSize.2xl')[0], + ...theme('fontSize.2xl')[1], + marginBottom: theme('spacing.2'), + }, + h2: { + color: 'var(--tw-prose-headings)', + fontWeight: '600', + fontSize: theme('fontSize.lg')[0], + ...theme('fontSize.lg')[1], + marginTop: theme('spacing.16'), + marginBottom: theme('spacing.2'), + }, + h3: { + color: 'var(--tw-prose-headings)', + fontSize: theme('fontSize.base')[0], + ...theme('fontSize.base')[1], + fontWeight: '600', + marginTop: theme('spacing.10'), + marginBottom: theme('spacing.2'), + }, + + // Media + 'img, video, figure': { + marginTop: theme('spacing.8'), + marginBottom: theme('spacing.8'), + }, + 'figure > *': { + marginTop: '0', + marginBottom: '0', + }, + figcaption: { + color: 'var(--tw-prose-captions)', + fontSize: theme('fontSize.xs')[0], + ...theme('fontSize.xs')[1], + marginTop: theme('spacing.2'), + }, + + // Tables + table: { + width: '100%', + tableLayout: 'auto', + textAlign: 'left', + marginTop: theme('spacing.8'), + marginBottom: theme('spacing.8'), + lineHeight: theme('lineHeight.6'), + }, + thead: { + borderBottomWidth: '1px', + borderBottomColor: 'var(--tw-prose-th-borders)', + }, + 'thead th': { + color: 'var(--tw-prose-headings)', + fontWeight: '600', + verticalAlign: 'bottom', + paddingRight: theme('spacing.2'), + paddingBottom: theme('spacing.2'), + paddingLeft: theme('spacing.2'), + }, + 'thead th:first-child': { + paddingLeft: '0', + }, + 'thead th:last-child': { + paddingRight: '0', + }, + 'tbody tr': { + borderBottomWidth: '1px', + borderBottomColor: 'var(--tw-prose-td-borders)', + }, + 'tbody tr:last-child': { + borderBottomWidth: '0', + }, + 'tbody td': { + verticalAlign: 'baseline', + }, + tfoot: { + borderTopWidth: '1px', + borderTopColor: 'var(--tw-prose-th-borders)', + }, + 'tfoot td': { + verticalAlign: 'top', + }, + ':is(tbody, tfoot) td': { + paddingTop: theme('spacing.2'), + paddingRight: theme('spacing.2'), + paddingBottom: theme('spacing.2'), + paddingLeft: theme('spacing.2'), + }, + ':is(tbody, tfoot) td:first-child': { + paddingLeft: '0', + }, + ':is(tbody, tfoot) td:last-child': { + paddingRight: '0', + }, + + // Inline elements + a: { + color: 'var(--tw-prose-links)', + textDecoration: 'underline transparent', + fontWeight: 
'500', + transitionProperty: 'color, text-decoration-color', + transitionDuration: theme('transitionDuration.DEFAULT'), + transitionTimingFunction: theme('transitionTimingFunction.DEFAULT'), + '&:hover': { + color: 'var(--tw-prose-links-hover)', + textDecorationColor: 'var(--tw-prose-links-underline)', + }, + }, + ':is(h1, h2, h3) a': { + fontWeight: 'inherit', + }, + strong: { + color: 'var(--tw-prose-bold)', + fontWeight: '600', + }, + ':is(a, blockquote, thead th) strong': { + color: 'inherit', + }, + code: { + color: 'var(--tw-prose-code)', + borderRadius: theme('borderRadius.lg'), + paddingTop: theme('padding.1'), + paddingRight: theme('padding[1.5]'), + paddingBottom: theme('padding.1'), + paddingLeft: theme('padding[1.5]'), + boxShadow: 'inset 0 0 0 1px var(--tw-prose-code-ring)', + backgroundColor: 'var(--tw-prose-code-bg)', + fontSize: theme('fontSize.2xs'), + }, + ':is(a, h1, h2, h3, blockquote, thead th) code': { + color: 'inherit', + }, + 'h2 code': { + fontSize: theme('fontSize.base')[0], + fontWeight: 'inherit', + }, + 'h3 code': { + fontSize: theme('fontSize.sm')[0], + fontWeight: 'inherit', + }, + + // Overrides + ':is(h1, h2, h3) + *': { + marginTop: '0', + }, + '> :first-child': { + marginTop: '0 !important', + }, + '> :last-child': { + marginBottom: '0 !important', + }, + }, + }, + invert: { + css: { + '--tw-prose-body': 'var(--tw-prose-invert-body)', + '--tw-prose-headings': 'var(--tw-prose-invert-headings)', + '--tw-prose-links': 'var(--tw-prose-invert-links)', + '--tw-prose-links-hover': 'var(--tw-prose-invert-links-hover)', + '--tw-prose-links-underline': 'var(--tw-prose-invert-links-underline)', + '--tw-prose-bold': 'var(--tw-prose-invert-bold)', + '--tw-prose-counters': 'var(--tw-prose-invert-counters)', + '--tw-prose-bullets': 'var(--tw-prose-invert-bullets)', + '--tw-prose-hr': 'var(--tw-prose-invert-hr)', + '--tw-prose-quotes': 'var(--tw-prose-invert-quotes)', + '--tw-prose-quote-borders': 'var(--tw-prose-invert-quote-borders)', + '--tw-prose-captions': 'var(--tw-prose-invert-captions)', + '--tw-prose-code': 'var(--tw-prose-invert-code)', + '--tw-prose-code-bg': 'var(--tw-prose-invert-code-bg)', + '--tw-prose-code-ring': 'var(--tw-prose-invert-code-ring)', + '--tw-prose-th-borders': 'var(--tw-prose-invert-th-borders)', + '--tw-prose-td-borders': 'var(--tw-prose-invert-td-borders)', + }, + }, +}) diff --git a/web/utils/format.ts b/web/utils/format.ts new file mode 100644 index 0000000000..fc30c62369 --- /dev/null +++ b/web/utils/format.ts @@ -0,0 +1,33 @@ +/* +* Formats a number with comma separators. 
+ formatNumber(1234567) will return '1,234,567' + formatNumber(1234567.89) will return '1,234,567.89' +*/ +export const formatNumber = (num: number | string) => { + if (!num) return num; + const parts = num.toString().split("."); + parts[0] = parts[0].replace(/\B(?=(\d{3})+(?!\d))/g, ","); + return parts.join("."); +} + +export const formatFileSize = (num: number) => { + if (!num) return num; + const units = ['', 'K', 'M', 'G', 'T', 'P']; + let index = 0; + // Stop at the last unit so we never read past the end of the array. + while (num >= 1024 && index < units.length - 1) { + num = num / 1024; + index++; + } + return num.toFixed(2) + `${units[index]}B`; +} + +export const formatTime = (num: number) => { + if (!num) return num; + const units = ['sec', 'min', 'h']; + let index = 0; + // Stop at the last unit so we never read past the end of the array. + while (num >= 60 && index < units.length - 1) { + num = num / 60; + index++; + } + return `${num.toFixed(2)} ${units[index]}`; +} diff --git a/web/utils/index.ts b/web/utils/index.ts new file mode 100644 index 0000000000..9965165dcf --- /dev/null +++ b/web/utils/index.ts @@ -0,0 +1,24 @@ +export const sleep = (ms: number) => { + return new Promise(resolve => setTimeout(resolve, ms)) +} + +export async function asyncRunSafe<T>(fn: Promise<T>): Promise<[Error] | [null, T]> { + try { + return [null, await fn] + } + catch (e) { + if (e instanceof Error) + return [e] + return [new Error('unknown error')] + } +} + +export const getTextWidthWithCanvas = (text: string, font?: string) => { + const canvas = document.createElement('canvas'); + const ctx = canvas.getContext('2d'); + if (ctx) { + ctx.font = font ?? '12px Inter, ui-sans-serif, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, "Noto Sans", sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol", "Noto Color Emoji"'; + return Number(ctx.measureText(text).width.toFixed(2)); + } + return 0; +}
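A usage sketch for `asyncRunSafe`: the Go-style `[error, value]` tuple lets callers branch on failure without nesting `try`/`catch`. The wrapped promise below is a stand-in for any request helper:

```ts
import { asyncRunSafe, sleep } from '@/utils'

async function demo() {
  // Check the error slot before touching the value slot.
  const [err, value] = await asyncRunSafe(sleep(100).then(() => 42))
  if (err)
    console.error('failed:', err.message)
  else
    console.log('value:', value)
}
```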
diff --git a/web/utils/language.ts b/web/utils/language.ts new file mode 100644 index 0000000000..c24c1267c4 --- /dev/null +++ b/web/utils/language.ts @@ -0,0 +1,19 @@ +type Item = { + value: number | string + name: string +} +export const languages: Item[] = [ + { + value: 'en-US', + name: 'English (United States)', + }, + { + value: 'zh-Hans', + name: '简体中文', + }, +] + +export const languageMaps = { + 'en': 'en-US', + 'zh-Hans': 'zh-Hans', +} diff --git a/web/utils/model-config.ts b/web/utils/model-config.ts new file mode 100644 index 0000000000..1c7594b9fe --- /dev/null +++ b/web/utils/model-config.ts @@ -0,0 +1,63 @@ +import { UserInputFormItem } from '@/types/app' +import { PromptVariable } from '@/models/debug' + +export const userInputsFormToPromptVariables = (userInputs: UserInputFormItem[] | null) => { + if (!userInputs) return [] + const promptVariables: PromptVariable[] = [] + userInputs.forEach((item: any) => { + const type = item['text-input'] ? 'string' : 'select' + const content = type === 'string' ? item['text-input'] : item['select'] + if (type === 'string') { + promptVariables.push({ + key: content.variable, + name: content.label, + required: content.required, + type: 'string', + max_length: content.max_length, + options: [], + }) + } else { + promptVariables.push({ + key: content.variable, + name: content.label, + required: content.required, + type: 'select', + options: content.options, + }) + } + }) + return promptVariables +} + +export const promptVariablesToUserInputsForm = (promptVariables: PromptVariable[]) => { + const userInputs: UserInputFormItem[] = [] + promptVariables.filter(({ key, name }) => !!(key && key.trim() && name && name.trim())).forEach((item: any) => { + if (item.type === 'string') { + userInputs.push({ + 'text-input': { + label: item.name, + variable: item.key, + required: item.required !== false, // default true + max_length: item.max_length, + default: '' + }, + } as any) + } else { + userInputs.push({ + 'select': { + label: item.name, + variable: item.key, + required: item.required !== false, // default true + options: item.options, + default: '' + }, + } as any) + } + }) + return userInputs +} diff --git a/web/utils/timezone.ts b/web/utils/timezone.ts new file mode 100644 index 0000000000..b408008db3 --- /dev/null +++ b/web/utils/timezone.ts @@ -0,0 +1,330 @@ +type Item = { + value: number | string + name: string +} +export const timezones: Item[] = [ + { + value: 'Pacific/Midway', + name: '(GMT-11:00) Midway Island, Samoa', + }, + { + value: 'Pacific/Honolulu', + name: '(GMT-10:00) Hawaii', + }, + { + value: 'America/Juneau', + name: '(GMT-8:00) Alaska', + }, + { + value: 'America/Dawson', + name: '(GMT-7:00) Dawson, Yukon', + }, + { + value: 'America/Chihuahua', + name: '(GMT-7:00) Chihuahua, La Paz, Mazatlan', + }, + { + value: 'America/Phoenix', + name: '(GMT-7:00) Arizona', + }, + { + value: 'America/Tijuana', + name: '(GMT-7:00) Tijuana', + }, + { + value: 'America/Los_Angeles', + name: '(GMT-7:00) Pacific Time', + }, + { + value: 'America/Boise', + name: '(GMT-6:00) Mountain Time', + }, + { + value: 'America/Regina', + name: '(GMT-6:00) Saskatchewan', + }, + { + value: 'America/Mexico_City', + name: '(GMT-6:00) Guadalajara, Mexico City, Monterrey', + }, + { + value: 'America/Belize', + name: '(GMT-6:00) Central America', + }, + { + value: 'America/Chicago', + name: '(GMT-5:00) Central Time', + }, + { + value: 'America/Bogota', + name: '(GMT-5:00) Bogota, Lima, Quito', + }, + { + value: 'America/Lima', + name: '(GMT-5:00) Pittsburgh', + }, + { + value: 'America/Detroit', + name: '(GMT-4:00) Eastern Time', + }, + { + value: 'America/Caracas', + name: '(GMT-4:00) Caracas, La Paz', + }, + { + value: 'America/Santiago', + name: '(GMT-3:00) Santiago', + }, + { + value: 'America/Sao_Paulo', + name: '(GMT-3:00) Brasilia', + }, + { + value: 'America/Montevideo', + name: '(GMT-3:00) Montevideo', + }, + { + value: 'America/Argentina/Buenos_Aires', + name: '(GMT-3:00) Buenos Aires, Georgetown', + }, + { + value: 'America/St_Johns', + name: '(GMT-2:30) Newfoundland and Labrador', + }, + { + value: 'America/Godthab', + name: '(GMT-2:00) Greenland', + }, + { + value: 'Atlantic/Cape_Verde', + name: '(GMT-1:00) Cape Verde Islands', + }, + { + value: 'Atlantic/Azores', + name: '(GMT+0:00) Azores', + }, + { + value: 'Etc/GMT', + name: '(GMT+0:00) UTC', + }, + { + value: 'Africa/Casablanca', + name: '(GMT+0:00) Casablanca, Monrovia', + }, + { + value: 'Europe/London', + name: '(GMT+1:00) 
Edinburgh, London', + }, + { + value: 'Europe/Dublin', + name: '(GMT+1:00) Dublin', + }, + { + value: 'Europe/Lisbon', + name: '(GMT+1:00) Lisbon', + }, + { + value: 'Atlantic/Canary', + name: '(GMT+1:00) Canary Islands', + }, + { + value: 'Africa/Algiers', + name: '(GMT+1:00) West Central Africa', + }, + { + value: 'Europe/Belgrade', + name: '(GMT+2:00) Belgrade, Bratislava, Budapest, Ljubljana, Prague', + }, + { + value: 'Europe/Sarajevo', + name: '(GMT+2:00) Sarajevo, Skopje, Warsaw, Zagreb', + }, + { + value: 'Europe/Brussels', + name: '(GMT+2:00) Brussels, Copenhagen, Madrid, Paris', + }, + { + value: 'Europe/Amsterdam', + name: '(GMT+2:00) Amsterdam, Berlin, Bern, Rome, Stockholm, Vienna', + }, + { + value: 'Africa/Cairo', + name: '(GMT+2:00) Cairo', + }, + { + value: 'Africa/Harare', + name: '(GMT+2:00) Harare, Pretoria', + }, + { + value: 'Europe/Berlin', + name: '(GMT+2:00) Frankfurt', + }, + { + value: 'Europe/Bucharest', + name: '(GMT+3:00) Bucharest', + }, + { + value: 'Europe/Helsinki', + name: '(GMT+3:00) Helsinki, Kyiv, Riga, Sofia, Tallinn, Vilnius', + }, + { + value: 'Europe/Athens', + name: '(GMT+3:00) Athens, Minsk', + }, + { + value: 'Asia/Jerusalem', + name: '(GMT+3:00) Jerusalem', + }, + { + value: 'Europe/Moscow', + name: '(GMT+3:00) Istanbul, Moscow, St. Petersburg, Volgograd', + }, + { + value: 'Asia/Kuwait', + name: '(GMT+3:00) Kuwait, Riyadh', + }, + { + value: 'Africa/Nairobi', + name: '(GMT+3:00) Nairobi', + }, + { + value: 'Asia/Baghdad', + name: '(GMT+3:00) Baghdad', + }, + { + value: 'Asia/Dubai', + name: '(GMT+4:00) Abu Dhabi, Muscat', + }, + { + value: 'Asia/Tehran', + name: '(GMT+4:30) Tehran', + }, + { + value: 'Asia/Kabul', + name: '(GMT+4:30) Kabul', + }, + { + value: 'Asia/Baku', + name: '(GMT+5:00) Baku, Tbilisi, Yerevan', + }, + { + value: 'Asia/Yekaterinburg', + name: '(GMT+5:00) Ekaterinburg', + }, + { + value: 'Asia/Karachi', + name: '(GMT+5:00) Islamabad, Karachi, Tashkent', + }, + { + value: 'Asia/Kolkata', + name: '(GMT+5:30) Chennai, Kolkata, Mumbai, New Delhi', + }, + { + value: 'Asia/Colombo', + name: '(GMT+5:30) Sri Jayawardenepura', + }, + { + value: 'Asia/Kathmandu', + name: '(GMT+5:45) Kathmandu', + }, + { + value: 'Asia/Dhaka', + name: '(GMT+6:00) Astana, Dhaka', + }, + { + value: 'Asia/Almaty', + name: '(GMT+6:00) Almaty, Novosibirsk', + }, + { + value: 'Asia/Rangoon', + name: '(GMT+6:30) Yangon Rangoon', + }, + { + value: 'Asia/Bangkok', + name: '(GMT+7:00) Bangkok, Hanoi, Jakarta', + }, + { + value: 'Asia/Krasnoyarsk', + name: '(GMT+7:00) Krasnoyarsk', + }, + { + value: 'Asia/Shanghai', + name: '(GMT+8:00) Beijing, Chongqing, Hong Kong SAR, Urumqi', + }, + { + value: 'Asia/Kuala_Lumpur', + name: '(GMT+8:00) Kuala Lumpur, Singapore', + }, + { + value: 'Asia/Taipei', + name: '(GMT+8:00) Taipei', + }, + { + value: 'Australia/Perth', + name: '(GMT+8:00) Perth', + }, + { + value: 'Asia/Irkutsk', + name: '(GMT+8:00) Irkutsk, Ulaanbaatar', + }, + { + value: 'Asia/Seoul', + name: '(GMT+9:00) Seoul', + }, + { + value: 'Asia/Tokyo', + name: '(GMT+9:00) Osaka, Sapporo, Tokyo', + }, + { + value: 'Australia/Darwin', + name: '(GMT+9:30) Darwin', + }, + { + value: 'Asia/Yakutsk', + name: '(GMT+10:00) Yakutsk', + }, + { + value: 'Australia/Brisbane', + name: '(GMT+10:00) Brisbane', + }, + { + value: 'Asia/Vladivostok', + name: '(GMT+10:00) Vladivostok', + }, + { + value: 'Pacific/Guam', + name: '(GMT+10:00) Guam, Port Moresby', + }, + { + value: 'Australia/Adelaide', + name: '(GMT+10:30) Adelaide', + }, + { + value: 'Australia/Sydney', + name: 
'(GMT+11:00) Canberra, Melbourne, Sydney', + }, + { + value: 'Australia/Hobart', + name: '(GMT+11:00) Hobart', + }, + { + value: 'Asia/Magadan', + name: '(GMT+11:00) Magadan, Solomon Islands, New Caledonia', + }, + { + value: 'Asia/Kamchatka', + name: '(GMT+12:00) Kamchatka, Marshall Islands', + }, + { + value: 'Pacific/Fiji', + name: '(GMT+12:00) Fiji Islands', + }, + { + value: 'Pacific/Auckland', + name: '(GMT+13:00) Auckland, Wellington', + }, + { + value: 'Pacific/Tongatapu', + name: '(GMT+13:00) Nuku\'alofa', + }, +] diff --git a/web/utils/var.ts b/web/utils/var.ts new file mode 100644 index 0000000000..44e2a7b93e --- /dev/null +++ b/web/utils/var.ts @@ -0,0 +1,47 @@ +import { VAR_ITEM_TEMPLATE, getMaxVarNameLength, MAX_VAR_KEY_LENGHT } from "@/config" + +const otherAllowedRegex = /^[a-zA-Z0-9_]+$/ + +export const getNewVar = (key: string) => { + return { + ...VAR_ITEM_TEMPLATE, + key, + name: key.slice(0, getMaxVarNameLength(key)), + } +} + +const checkKey = (key: string, canBeEmpty?: boolean) => { + if (key.length === 0 && !canBeEmpty) { + return 'canNoBeEmpty' + } + if (canBeEmpty && key === '') { + return true + } + if (key.length > MAX_VAR_KEY_LENGHT) { + return 'tooLong' + } + if (otherAllowedRegex.test(key)) { + if (/[0-9]/.test(key[0])) { + return 'notStartWithNumber' + } + return true + } + return 'notValid' +} + +export const checkKeys = (keys: string[], canBeEmpty?: boolean) => { + let isValid = true + let errorKey = '' + let errorMessageKey = '' + keys.forEach((key) => { + if (!isValid) { + return + } + const res = checkKey(key, canBeEmpty) + if (res !== true) { + isValid = false + errorKey = key + errorMessageKey = res + } + }) + return { isValid, errorKey, errorMessageKey } +} \ No newline at end of file
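A usage sketch for the validator above: `checkKeys` stops at the first offending key and returns a message key (`canNoBeEmpty`, `tooLong`, `notStartWithNumber`, `notValid`) that callers can feed into i18n lookups:

```ts
import { checkKeys } from '@/utils/var'

// 'query' passes; '9lives' trips the leading-digit rule.
const { isValid, errorKey, errorMessageKey } = checkKeys(['query', '9lives'])
if (!isValid)
  console.warn(`invalid variable key "${errorKey}": ${errorMessageKey}`) // notStartWithNumber
```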