mirror of
https://github.com/langgenius/dify.git
synced 2024-11-15 19:22:36 +08:00
feat: reconstruct verification logic
This commit is contained in:
parent
21b8f26cd1
commit
e5c9f821ab
|
@ -599,9 +599,9 @@ class PositionConfig(BaseSettings):
|
|||
|
||||
|
||||
class VerificationConfig(BaseSettings):
|
||||
VERIFICATION_CODE_EXPIRY: PositiveInt = Field(
|
||||
description="Duration in seconds for which a verification code remains valid",
|
||||
default=300,
|
||||
VERIFICATION_CODE_EXPIRY_MINUTES: PositiveInt = Field(
|
||||
description="Duration in minutes for which a verification code remains valid",
|
||||
default=5,
|
||||
)
|
||||
|
||||
VERIFICATION_CODE_LENGTH: PositiveInt = Field(
|
||||
|
@ -609,9 +609,9 @@ class VerificationConfig(BaseSettings):
|
|||
default=6,
|
||||
)
|
||||
|
||||
VERIFICATION_CODE_COOLDOWN: PositiveInt = Field(
|
||||
description="Cooldown time in seconds between verification code generation",
|
||||
default=60,
|
||||
VERIFICATION_CODE_COOLDOWN_MINUTES: PositiveInt = Field(
|
||||
description="Cooldown time in minutes between verification code generation",
|
||||
default=1,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -248,17 +248,20 @@ class AccountDeleteVerifyApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
def post(self):
|
||||
account = current_user
|
||||
|
||||
try:
|
||||
code = VerificationService.generate_account_deletion_verification_code(account.email)
|
||||
code = VerificationService.generate_account_deletion_verification_code(account)
|
||||
AccountService.send_account_delete_verification_email(account, code)
|
||||
except RateLimitExceededError:
|
||||
return {"result": "fail", "error": "Rate limit exceeded."}, 429
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class AccountDeleteApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
@ -267,13 +270,29 @@ class AccountDeleteVerifyApi(Resource):
|
|||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("code", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not VerificationService.verify_account_deletion_verification_code(account, args["code"]):
|
||||
return {"result": "fail", "error": "Verification code is invalid."}, 400
|
||||
|
||||
AccountService.delete_account(account)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("reason", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not VerificationService.verify_account_deletion_verification_code(account.email, args["code"]):
|
||||
return {"result": "fail", "error": "Verification code is invalid."}, 400
|
||||
|
||||
AccountService.delete_account(account, args["reason"], args["code"])
|
||||
try:
|
||||
AccountService.update_deletion_reason(account, args["reason"])
|
||||
except ValueError as e:
|
||||
return {"result": "fail", "error": str(e)}, 400
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
@ -288,6 +307,7 @@ api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme")
|
|||
api.add_resource(AccountTimezoneApi, "/account/timezone")
|
||||
api.add_resource(AccountPasswordApi, "/account/password")
|
||||
api.add_resource(AccountIntegrateApi, "/account/integrates")
|
||||
api.add_resource(AccountDeleteVerifyApi, "/account/delete-verify")
|
||||
api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify")
|
||||
api.add_resource(AccountDeleteApi, "/account/delete")
|
||||
# api.add_resource(AccountEmailApi, '/account/email')
|
||||
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')
|
||||
|
|
|
@ -8,6 +8,7 @@ import time
|
|||
import uuid
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone as tz
|
||||
from hashlib import sha256
|
||||
from typing import Any, Optional, Union
|
||||
from zoneinfo import available_timezones
|
||||
|
@ -19,6 +20,9 @@ from flask import Response, current_app, stream_with_context
|
|||
from flask_restful import fields
|
||||
from models.account import Account
|
||||
|
||||
from api.configs import dify_config
|
||||
from api.services.errors.account import RateLimitExceededError
|
||||
|
||||
|
||||
def run(script):
|
||||
return subprocess.getstatusoutput("source /root/.bashrc && " + script)
|
||||
|
@ -269,3 +273,66 @@ class RateLimiter:
|
|||
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
redis_client.expire(key, self.time_window * 2)
|
||||
|
||||
|
||||
def get_current_datetime():
|
||||
return datetime.now(tz.utc).replace(tzinfo=None)
|
||||
|
||||
|
||||
class VerificationCodeManager:
|
||||
@classmethod
|
||||
def generate_verification_code(cls, account: Account, code_type: str) -> str:
|
||||
# Check if this key is still in cooldown period
|
||||
now = int(time.time())
|
||||
created_at = cls._get_verification_code_created_at(code_type, account.id)
|
||||
if created_at is not None and now - created_at < dify_config.VERIFICATION_CODE_COOLDOWN_MINUTES * 60:
|
||||
raise RateLimitExceededError()
|
||||
if created_at is not None:
|
||||
cls._revoke_verification_code(code_type, account.id)
|
||||
|
||||
verification_code = generate_string(dify_config.VERIFICATION_CODE_LENGTH)
|
||||
cls._set_verification_code(code_type, account.id, verification_code, dify_config.VERIFICATION_CODE_EXPIRY_MINUTES)
|
||||
|
||||
return verification_code
|
||||
|
||||
@classmethod
|
||||
def verify_verification_code(cls, account: Account, code_type: str, verification_code: str) -> bool:
|
||||
key, _ = cls._get_key(code_type, account_id=account.id)
|
||||
stored_verification_code = redis_client.get(key)
|
||||
|
||||
if stored_verification_code is None:
|
||||
return False
|
||||
return stored_verification_code == verification_code
|
||||
|
||||
### Helper methods ###
|
||||
@classmethod
|
||||
def _set_verification_code(cls, code_type: str, account_id: str, verification_code: str, expire_minutes: int) -> None:
|
||||
key, time_key = cls._get_key(code_type, account_id)
|
||||
now = int(time.time())
|
||||
|
||||
redis_client.setex(key, expire_minutes * 60, verification_code)
|
||||
redis_client.setex(time_key, expire_minutes * 60, now)
|
||||
|
||||
@classmethod
|
||||
def _get_verification_code(cls, code_type: str, account_id: str) -> Optional[str]:
|
||||
key, _ = cls._get_key(code_type, account_id)
|
||||
verification_code = redis_client.get(key)
|
||||
|
||||
return verification_code
|
||||
|
||||
@classmethod
|
||||
def _get_verification_code_created_at(cls, code_type: str, account_id: str) -> Optional[int]:
|
||||
_, time_key = cls._get_key(code_type, account_id)
|
||||
created_at = redis_client.get(time_key)
|
||||
|
||||
return int(created_at) if created_at is not None else None
|
||||
|
||||
@classmethod
|
||||
def _revoke_verification_code(cls, code_type: str, account_id: str) -> None:
|
||||
key, time_key = cls._get_key(code_type, account_id)
|
||||
redis_client.delete(key)
|
||||
redis_client.delete(time_key)
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, code_type: str, account_id: str) -> tuple[str, str]:
|
||||
return f"verification:{code_type}:{account_id}", f"verification:{code_type}:{account_id}:time"
|
||||
|
|
|
@ -278,7 +278,7 @@ class AccountDeletionLog(db.Model):
|
|||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
account_id = db.Column(StringUUID, nullable=False)
|
||||
status = db.Column(Enum(AccountDeletionLogStatus), nullable=False, default=AccountDeletionLogStatus.PENDING)
|
||||
reason = db.Column(db.Text)
|
||||
reason = db.Column(db.Text, nullable=True)
|
||||
email = db.Column(db.String(255), nullable=False)
|
||||
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
|
|
|
@ -7,9 +7,11 @@ from extensions.ext_database import db
|
|||
from sqlalchemy import or_
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from api.libs.helper import get_current_datetime
|
||||
from api.models.account import (Account, AccountDeletionLog,
|
||||
AccountDeletionLogStatus, AccountIntegrate,
|
||||
TenantAccountJoin, TenantAccountJoinRole)
|
||||
Tenant, TenantAccountJoin,
|
||||
TenantAccountJoinRole)
|
||||
from api.models.api_based_extension import APIBasedExtension
|
||||
from api.models.dataset import (AppDatasetJoin, Dataset, DatasetPermission,
|
||||
Document, DocumentSegment)
|
||||
|
@ -18,11 +20,15 @@ from api.models.model import (ApiToken, App, AppAnnotationSetting,
|
|||
DatasetRetrieverResource, EndUser, InstalledApp,
|
||||
Message, MessageAgentThought, MessageAnnotation,
|
||||
MessageChain, MessageFeedback, MessageFile,
|
||||
RecommendedApp)
|
||||
RecommendedApp, Site, Tag, TagBinding)
|
||||
from api.models.provider import (LoadBalancingModelConfig, Provider,
|
||||
ProviderModel, ProviderModelSetting)
|
||||
ProviderModel, ProviderModelSetting,
|
||||
TenantDefaultModel,
|
||||
TenantPreferredModelProvider)
|
||||
from api.models.source import (DataSourceApiKeyAuthBinding,
|
||||
DataSourceOauthBinding)
|
||||
from api.models.tools import (ApiToolProvider, BuiltinToolProvider,
|
||||
ToolConversationVariables)
|
||||
from api.models.web import PinnedConversation, SavedMessage
|
||||
from api.tasks.mail_account_deletion_done_task import \
|
||||
send_deletion_success_task
|
||||
|
@ -82,6 +88,9 @@ def _delete_app(app: App, account_id):
|
|||
# saved_messages
|
||||
db.session.query(SavedMessage).filter(SavedMessage.app_id == app.id).delete()
|
||||
|
||||
# sites
|
||||
db.session.query(Site).filter(Site.app_id == app.id).delete()
|
||||
|
||||
db.session.delete(app)
|
||||
|
||||
|
||||
|
@ -93,6 +102,8 @@ def _delete_tenant_as_owner(tenant_account_join: TenantAccountJoin):
|
|||
"""
|
||||
tenant_id, account_id = tenant_account_join.tenant_id, tenant_account_join.account_id
|
||||
|
||||
member_ids = db.session.query(TenantAccountJoin.account_id).filter(TenantAccountJoin.tenant_id == tenant_id).all()
|
||||
|
||||
# api_based_extensions
|
||||
db.session.query(APIBasedExtension).filter(APIBasedExtension.tenant_id == tenant_id).delete()
|
||||
|
||||
|
@ -105,17 +116,19 @@ def _delete_tenant_as_owner(tenant_account_join: TenantAccountJoin):
|
|||
db.session.query(DatasetPermission).filter(DatasetPermission.tenant_id == tenant_id).delete()
|
||||
|
||||
# datasets
|
||||
dataset_ids = db.session.query(Dataset.id).filter(Dataset.tenant_id == tenant_id).all()
|
||||
db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id).delete()
|
||||
|
||||
# documents
|
||||
document_ids = db.session.query(Document.id).filter(Document.tenant_id == tenant_id).all()
|
||||
db.session.query(Document).filter(Document.tenant_id == tenant_id).delete()
|
||||
|
||||
# data_source_api_key_auth_bindings
|
||||
db.session.query(DataSourceApiKeyAuthBinding).filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id).delete()
|
||||
|
||||
# data_source_oauth_bindings
|
||||
db.session.query(DataSourceOauthBinding).filter(DataSourceOauthBinding.tenant_id == tenant_id).delete()
|
||||
|
||||
# documents
|
||||
db.session.query(Document).filter(Document.tenant_id == tenant_id).delete()
|
||||
|
||||
# document_segments
|
||||
db.session.query(DocumentSegment).filter(DocumentSegment.tenant_id == tenant_id).delete()
|
||||
|
||||
|
@ -128,14 +141,38 @@ def _delete_tenant_as_owner(tenant_account_join: TenantAccountJoin):
|
|||
# provder_model_settings
|
||||
db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).delete()
|
||||
|
||||
# skip provider_orders
|
||||
# skip provider_orders (TODO: confirm this)
|
||||
|
||||
# providers
|
||||
db.session.query(Provider).filter(Provider.tenant_id == tenant_id).delete()
|
||||
|
||||
# tag_bindings
|
||||
db.session.query(TagBinding).filter(TagBinding.tenant_id == tenant_id).delete()
|
||||
|
||||
# tags
|
||||
db.session.query(Tag).filter(Tag.tenant_id == tenant_id).delete()
|
||||
|
||||
# tenant_default_models
|
||||
db.session.query(TenantDefaultModel).filter(TenantDefaultModel.tenant_id == tenant_id).delete()
|
||||
|
||||
# tenant_preferred_model_providers
|
||||
db.session.query(TenantPreferredModelProvider).filter(TenantPreferredModelProvider.tenant_id == tenant_id).delete()
|
||||
|
||||
# tool_api_providers
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).delete()
|
||||
|
||||
# tool_built_in_providers
|
||||
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).delete()
|
||||
|
||||
# tool_conversation_variables
|
||||
db.session.query(ToolConversationVariables).filter(ToolConversationVariables.tenant_id == tenant_id).delete()
|
||||
|
||||
# Delete all tenant_account_joins of this tenant
|
||||
db.session.query(TenantAccountJoin).filter(TenantAccountJoin.tenant_id == tenant_id).delete()
|
||||
|
||||
# Delete tenant
|
||||
db.session.query(Tenant).filter(Tenant.id == tenant_id).delete()
|
||||
|
||||
|
||||
def _delete_tenant_as_non_owner(tenant_account_join: TenantAccountJoin):
|
||||
"""Actual deletion of tenant as non-owner. Related tables will also be deleted.
|
||||
|
@ -182,10 +219,15 @@ def _delete_user(log: AccountDeletionLog, account: Account) -> bool:
|
|||
# delete account
|
||||
db.session.delete(account)
|
||||
|
||||
# update log status
|
||||
log.status = AccountDeletionLogStatus.COMPLETED
|
||||
log.updated_at = get_current_datetime()
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
db.session.rollback()
|
||||
logger.exception(click.style(f"Failed to delete account {log.account_id}, error: {e}", fg="red"))
|
||||
log.status = AccountDeletionLogStatus.FAILED
|
||||
log.updated_at = get_current_datetime()
|
||||
success = False
|
||||
finally:
|
||||
db.session.commit()
|
||||
|
@ -231,6 +273,7 @@ def delete_account_task():
|
|||
if not account:
|
||||
logger.exception(click.style(f"Account {log.account_id} not found.", fg="red"))
|
||||
log.status = AccountDeletionLogStatus.FAILED
|
||||
log.updated_at = get_current_datetime()
|
||||
db.session.commit()
|
||||
continue
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from configs import dify_config
|
|||
from constants.languages import language_timezone_mapping, languages
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.helper import RateLimiter, TokenManager
|
||||
from libs.helper import RateLimiter, TokenManager, get_current_datetime
|
||||
from libs.passport import PassportService
|
||||
from libs.password import compare_password, hash_password, valid_password
|
||||
from libs.rsa import generate_key_pair
|
||||
|
@ -155,7 +155,18 @@ class AccountService:
|
|||
return account
|
||||
|
||||
@staticmethod
|
||||
def delete_account(account: Account, reason: str) -> None:
|
||||
def update_deletion_reason(account: Account, reason: str) -> None:
|
||||
"""Update deletion log reason"""
|
||||
account_deletion_log = AccountDeletionLog.query.filter_by(account_id=account.id).first()
|
||||
if not account_deletion_log:
|
||||
raise Exception("Account deletion log not found.")
|
||||
|
||||
account_deletion_log.reason = reason
|
||||
account_deletion_log.updated_at = get_current_datetime()
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def delete_account(account: Account) -> None:
|
||||
"""Delete account. Actual deletion is done by the background scheduler."""
|
||||
logging.info(f"Start deletion of account {account.id}.")
|
||||
|
||||
|
@ -163,7 +174,6 @@ class AccountService:
|
|||
account_deletion_log = AccountDeletionLog(
|
||||
account_id=account.id,
|
||||
status=AccountDeletionLogStatus.PENDING,
|
||||
reason=reason
|
||||
)
|
||||
db.session.add(account_deletion_log)
|
||||
db.session.commit()
|
||||
|
|
|
@ -1,69 +1,20 @@
|
|||
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from api.configs import dify_config
|
||||
from api.extensions.ext_redis import redis_client
|
||||
from api.libs.helper import generate_string, generate_text_hash
|
||||
from api.services.errors.account import RateLimitExceededError
|
||||
from api.libs.helper import VerificationCodeManager
|
||||
from api.models.account import Account
|
||||
|
||||
|
||||
class VerificationService:
|
||||
|
||||
@classmethod
|
||||
def generate_account_deletion_verification_code(cls, email: str) -> str:
|
||||
return cls._generate_verification_code(
|
||||
email=email,
|
||||
prefix="account_deletion",
|
||||
expire=dify_config.VERIFICATION_CODE_EXPIRY,
|
||||
code_length=dify_config.VERIFICATION_CODE_LENGTH
|
||||
def generate_account_deletion_verification_code(cls, account: Account) -> str:
|
||||
return VerificationCodeManager.generate_verification_code(
|
||||
account=account,
|
||||
code_type="account_deletion",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def verify_account_deletion_verification_code(cls, email: str, verification_code: str) -> bool:
|
||||
return cls._verify_verification_code(
|
||||
email=email,
|
||||
prefix="account_deletion",
|
||||
verification_code=verification_code
|
||||
def verify_account_deletion_verification_code(cls, account: Account, verification_code: str) -> bool:
|
||||
return VerificationCodeManager.verify_verification_code(
|
||||
account=account,
|
||||
code_type="account_deletion",
|
||||
verification_code=verification_code,
|
||||
)
|
||||
|
||||
### Helper methods ###
|
||||
|
||||
@classmethod
|
||||
def _generate_verification_code(cls, key_name: str, prefix: str, expire: int = 300, code_length: int = 6) -> str:
|
||||
hashed_key = generate_text_hash(key_name)
|
||||
key, time_key = cls._get_key(f"{prefix}:{hashed_key}")
|
||||
now = int(time.time())
|
||||
|
||||
# Check if there is already a verification code for this key within 1 minute
|
||||
created_at = redis_client.get(time_key)
|
||||
if created_at is not None and now - created_at < dify_config.VERIFICATION_CODE_COOLDOWN:
|
||||
raise RateLimitExceededError()
|
||||
|
||||
created_at = now
|
||||
verification_code = generate_string(code_length)
|
||||
|
||||
redis_client.setex(key, expire, verification_code)
|
||||
redis_client.setex(time_key, expire, created_at)
|
||||
return verification_code
|
||||
|
||||
@classmethod
|
||||
def _get_verification_code(cls, prefix: str, key_name: str) -> Optional[str]:
|
||||
hashed_key = generate_text_hash(key_name)
|
||||
key, _ = cls._get_key(f"{prefix}:{hashed_key}")
|
||||
verification_code = redis_client.get(key)
|
||||
|
||||
return verification_code
|
||||
|
||||
@classmethod
|
||||
def _verify_verification_code(cls, key_name: str, prefix: str, verification_code: str) -> bool:
|
||||
hashed_key = generate_text_hash(key_name)
|
||||
key, _ = cls._get_key(f"{prefix}:{hashed_key}")
|
||||
stored_verification_code = redis_client.get(key)
|
||||
|
||||
if stored_verification_code is None:
|
||||
return False
|
||||
return stored_verification_code == verification_code
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, key_name: str) -> str:
|
||||
return f"verification:{key_name}", f"verification:{key_name}:time"
|
||||
|
|
Loading…
Reference in New Issue
Block a user