add segment function billing check for SAAS env (#3082)

Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
Jyong 2024-04-02 17:55:49 +08:00 committed by GitHub
parent 9c7e99e829
commit e12a0c154c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 62 additions and 9 deletions

View File

@ -12,7 +12,11 @@ from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_knowledge_limit_check,
cloud_edition_billing_resource_check,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
@ -207,6 +211,7 @@ class DatasetDocumentSegmentAddApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check('vector_space') @cloud_edition_billing_resource_check('vector_space')
@cloud_edition_billing_knowledge_limit_check('add_segment')
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -357,6 +362,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check('vector_space') @cloud_edition_billing_resource_check('vector_space')
@cloud_edition_billing_knowledge_limit_check('add_segment')
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)

View File

@ -51,14 +51,12 @@ def cloud_edition_billing_resource_check(resource: str,
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
features = FeatureService.get_features(current_user.current_tenant_id) features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled: if features.billing.enabled:
members = features.members members = features.members
apps = features.apps apps = features.apps
vector_space = features.vector_space vector_space = features.vector_space
documents_upload_quota = features.documents_upload_quota documents_upload_quota = features.documents_upload_quota
annotation_quota_limit = features.annotation_quota_limit annotation_quota_limit = features.annotation_quota_limit
if resource == 'members' and 0 < members.limit <= members.size: if resource == 'members' and 0 < members.limit <= members.size:
abort(403, error_msg) abort(403, error_msg)
elif resource == 'apps' and 0 < apps.limit <= apps.size: elif resource == 'apps' and 0 < apps.limit <= apps.size:
@ -80,7 +78,29 @@ def cloud_edition_billing_resource_check(resource: str,
return view(*args, **kwargs) return view(*args, **kwargs)
return view(*args, **kwargs) return view(*args, **kwargs)
return decorated return decorated
return interceptor
def cloud_edition_billing_knowledge_limit_check(resource: str,
error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
if features.billing.enabled:
if resource == 'add_segment':
if features.billing.subscription.plan == 'sandbox':
abort(403, error_msg)
else:
return view(*args, **kwargs)
return view(*args, **kwargs)
return decorated
return interceptor return interceptor
@ -99,4 +119,5 @@ def cloud_utm_record(view):
except Exception as e: except Exception as e:
pass pass
return view(*args, **kwargs) return view(*args, **kwargs)
return decorated return decorated

View File

@ -4,7 +4,11 @@ from werkzeug.exceptions import NotFound
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check from controllers.service_api.wraps import (
DatasetApiResource,
cloud_edition_billing_knowledge_limit_check,
cloud_edition_billing_resource_check,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
@ -18,6 +22,7 @@ class SegmentApi(DatasetApiResource):
"""Resource for segments.""" """Resource for segments."""
@cloud_edition_billing_resource_check('vector_space', 'dataset') @cloud_edition_billing_resource_check('vector_space', 'dataset')
@cloud_edition_billing_knowledge_limit_check('add_segment', 'dataset')
def post(self, tenant_id, dataset_id, document_id): def post(self, tenant_id, dataset_id, document_id):
"""Create single segment.""" """Create single segment."""
# check dataset # check dataset

View File

@ -8,7 +8,7 @@ from flask import current_app, request
from flask_login import user_logged_in from flask_login import user_logged_in
from flask_restful import Resource from flask_restful import Resource
from pydantic import BaseModel from pydantic import BaseModel
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import _get_user from libs.login import _get_user
@ -92,13 +92,13 @@ def cloud_edition_billing_resource_check(resource: str,
documents_upload_quota = features.documents_upload_quota documents_upload_quota = features.documents_upload_quota
if resource == 'members' and 0 < members.limit <= members.size: if resource == 'members' and 0 < members.limit <= members.size:
raise Unauthorized(error_msg) raise Forbidden(error_msg)
elif resource == 'apps' and 0 < apps.limit <= apps.size: elif resource == 'apps' and 0 < apps.limit <= apps.size:
raise Unauthorized(error_msg) raise Forbidden(error_msg)
elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size: elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size:
raise Unauthorized(error_msg) raise Forbidden(error_msg)
elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size: elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
raise Unauthorized(error_msg) raise Forbidden(error_msg)
else: else:
return view(*args, **kwargs) return view(*args, **kwargs)
@ -107,6 +107,27 @@ def cloud_edition_billing_resource_check(resource: str,
return interceptor return interceptor
def cloud_edition_billing_knowledge_limit_check(resource: str,
api_token_type: str,
error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
api_token = validate_and_get_api_token(api_token_type)
features = FeatureService.get_features(api_token.tenant_id)
if features.billing.enabled:
if resource == 'add_segment':
if features.billing.subscription.plan == 'sandbox':
raise Forbidden(error_msg)
else:
return view(*args, **kwargs)
return view(*args, **kwargs)
return decorated
return interceptor
def validate_dataset_token(view=None): def validate_dataset_token(view=None):
def decorator(view): def decorator(view):
@wraps(view) @wraps(view)