From c2ce2f88c7cb35b07de9678471a77933eb991fb3 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Fri, 15 Nov 2024 17:59:36 +0800 Subject: [PATCH] feat: add license. (#10403) --- api/controllers/console/app/app.py | 3 ++ api/controllers/console/datasets/datasets.py | 3 +- api/controllers/console/error.py | 6 ++++ api/controllers/console/workspace/account.py | 3 +- .../console/workspace/tool_providers.py | 3 +- api/controllers/console/wraps.py | 16 ++++++++-- api/services/feature_service.py | 31 +++++++++++++++++++ 7 files changed, 60 insertions(+), 5 deletions(-) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 36338cbd8a..5a4cd7684f 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -9,6 +9,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, + enterprise_license_required, setup_required, ) from core.ops.ops_trace_manager import OpsTraceManager @@ -28,6 +29,7 @@ class AppListApi(Resource): @setup_required @login_required @account_initialization_required + @enterprise_license_required def get(self): """Get app list""" @@ -149,6 +151,7 @@ class AppApi(Resource): @setup_required @login_required @account_initialization_required + @enterprise_license_required @get_app_model @marshal_with(app_detail_fields_with_site) def get(self, app_model): diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 82163a32ee..95d4013e3a 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -10,7 +10,7 @@ from controllers.console import api from controllers.console.apikey import api_key_fields, api_key_list from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.indexing_runner import IndexingRunner from core.model_runtime.entities.model_entities import ModelType @@ -44,6 +44,7 @@ class DatasetListApi(Resource): @setup_required @login_required @account_initialization_required + @enterprise_license_required def get(self): page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) diff --git a/api/controllers/console/error.py b/api/controllers/console/error.py index e0630ca66c..61561d56c8 100644 --- a/api/controllers/console/error.py +++ b/api/controllers/console/error.py @@ -86,3 +86,9 @@ class NoFileUploadedError(BaseHTTPException): error_code = "no_file_uploaded" description = "Please upload your file." code = 400 + + +class UnauthorizedAndForceLogout(BaseHTTPException): + error_code = "unauthorized_and_force_logout" + description = "Unauthorized and force logout." + code = 401 diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index aabc417759..750f65168f 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -14,7 +14,7 @@ from controllers.console.workspace.error import ( InvalidInvitationCodeError, RepeatPasswordNotMatchError, ) -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from extensions.ext_database import db from fields.member_fields import account_fields from libs.helper import TimestampField, timezone @@ -79,6 +79,7 @@ class AccountProfileApi(Resource): @login_required @account_initialization_required @marshal_with(account_fields) + @enterprise_license_required def get(self): return current_user diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index daadb85d84..9ecda2126d 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -7,7 +7,7 @@ from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api -from controllers.console.wraps import account_initialization_required, setup_required +from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder from libs.helper import alphanumeric, uuid_value from libs.login import login_required @@ -549,6 +549,7 @@ class ToolLabelsApi(Resource): @setup_required @login_required @account_initialization_required + @enterprise_license_required def get(self): return jsonable_encoder(ToolLabelsService.list_tool_labels()) diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 9f294cb93c..d0df296c24 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -8,10 +8,10 @@ from flask_login import current_user from configs import dify_config from controllers.console.workspace.error import AccountNotInitializedError from models.model import DifySetup -from services.feature_service import FeatureService +from services.feature_service import FeatureService, LicenseStatus from services.operation_service import OperationService -from .error import NotInitValidateError, NotSetupError +from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout def account_initialization_required(view): @@ -142,3 +142,15 @@ def setup_required(view): return view(*args, **kwargs) return decorated + + +def enterprise_license_required(view): + @wraps(view) + def decorated(*args, **kwargs): + settings = FeatureService.get_system_features() + if settings.license.status in [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]: + raise UnauthorizedAndForceLogout("Your license is invalid. Please contact your administrator.") + + return view(*args, **kwargs) + + return decorated diff --git a/api/services/feature_service.py b/api/services/feature_service.py index c321393bc5..d0b04628cf 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -1,3 +1,5 @@ +from enum import Enum + from pydantic import BaseModel, ConfigDict from configs import dify_config @@ -20,6 +22,20 @@ class LimitationModel(BaseModel): limit: int = 0 +class LicenseStatus(str, Enum): + NONE = "none" + INACTIVE = "inactive" + ACTIVE = "active" + EXPIRING = "expiring" + EXPIRED = "expired" + LOST = "lost" + + +class LicenseModel(BaseModel): + status: LicenseStatus = LicenseStatus.NONE + expired_at: str = "" + + class FeatureModel(BaseModel): billing: BillingModel = BillingModel() members: LimitationModel = LimitationModel(size=0, limit=1) @@ -47,6 +63,7 @@ class SystemFeatureModel(BaseModel): enable_social_oauth_login: bool = False is_allow_register: bool = False is_allow_create_workspace: bool = False + license: LicenseModel = LicenseModel() class FeatureService: @@ -131,17 +148,31 @@ class FeatureService: if "sso_enforced_for_signin" in enterprise_info: features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"] + if "sso_enforced_for_signin_protocol" in enterprise_info: features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"] + if "sso_enforced_for_web" in enterprise_info: features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"] + if "sso_enforced_for_web_protocol" in enterprise_info: features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"] + if "enable_email_code_login" in enterprise_info: features.enable_email_code_login = enterprise_info["enable_email_code_login"] + if "enable_email_password_login" in enterprise_info: features.enable_email_password_login = enterprise_info["enable_email_password_login"] + if "is_allow_register" in enterprise_info: features.is_allow_register = enterprise_info["is_allow_register"] + if "is_allow_create_workspace" in enterprise_info: features.is_allow_create_workspace = enterprise_info["is_allow_create_workspace"] + + if "license" in enterprise_info: + if "status" in enterprise_info["license"]: + features.license.status = enterprise_info["license"]["status"] + + if "expired_at" in enterprise_info["license"]: + features.license.expired_at = enterprise_info["license"]["expired_at"]