"""Console API resources for the logged-in user's account settings.

Every resource operates on ``current_user`` and is registered on the console
``api`` under an ``/account/...`` route at the bottom of this module.
"""

import datetime

import pytz
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, marshal_with, reqparse

from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.workspace.error import (
    AccountAlreadyInitedError,
    CurrentPasswordIncorrectError,
    InvalidInvitationCodeError,
    RepeatPasswordNotMatchError,
)
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from fields.member_fields import account_fields
from libs.helper import TimestampField, timezone
from libs.login import login_required
from models.account import AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError


def _naive_utc_now() -> datetime.datetime:
    """Return the current UTC time as a naive ``datetime`` (tzinfo stripped).

    The timestamp attributes set in this module are assigned naive values;
    this keeps that ``now(utc).replace(tzinfo=None)`` convention in one place.
    """
    return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)


class AccountInitApi(Resource):
    """Complete first-time setup of an account that is not yet active."""

    @setup_required
    @login_required
    def post(self):
        """Activate ``current_user`` with the submitted language and timezone.

        On the CLOUD edition an unused invitation code is also required and is
        marked as used by this account and its current tenant.

        Raises:
            AccountAlreadyInitedError: the account status is already ``"active"``.
            ValueError: CLOUD edition and no ``invitation_code`` was supplied.
            InvalidInvitationCodeError: the code is unknown or already used.
        """
        account = current_user

        if account.status == "active":
            raise AccountAlreadyInitedError()

        is_cloud = dify_config.EDITION == "CLOUD"

        parser = reqparse.RequestParser()
        if is_cloud:
            parser.add_argument("invitation_code", type=str, location="json")
        parser.add_argument("interface_language", type=supported_language, required=True, location="json")
        parser.add_argument("timezone", type=timezone, required=True, location="json")
        args = parser.parse_args()

        now = _naive_utc_now()

        if is_cloud:
            if not args["invitation_code"]:
                raise ValueError("invitation_code is required")

            # Only an existing code that is still "unused" may be redeemed.
            invitation_code = (
                db.session.query(InvitationCode)
                .filter(
                    InvitationCode.code == args["invitation_code"],
                    InvitationCode.status == "unused",
                )
                .first()
            )

            if not invitation_code:
                raise InvalidInvitationCodeError()

            invitation_code.status = "used"
            invitation_code.used_at = now
            invitation_code.used_by_tenant_id = account.current_tenant_id
            invitation_code.used_by_account_id = account.id

        account.interface_language = args["interface_language"]
        account.timezone = args["timezone"]
        account.interface_theme = "light"
        account.status = "active"
        account.initialized_at = now
        # A single commit persists the invitation-code update (if any) together
        # with the account activation.
        db.session.commit()

        return {"result": "success"}


class AccountProfileApi(Resource):
    """Read the current user's profile."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def get(self):
        """Return ``current_user`` marshalled with ``account_fields``."""
        return current_user


class AccountNameApi(Resource):
    """Rename the current user's account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set the account name; it must be 3 to 30 characters long.

        Raises:
            ValueError: the name is missing/null or its length is out of range.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, required=True, location="json")
        args = parser.parse_args()

        name = args["name"]
        # reqparse arguments are nullable by default, so a JSON null arrives as
        # None; reject it with the same validation error instead of crashing in len().
        if name is None or not 3 <= len(name) <= 30:
            raise ValueError("Account name must be between 3 and 30 characters.")

        updated_account = AccountService.update_account(current_user, name=name)

        return updated_account


class AccountAvatarApi(Resource):
    """Change the current user's avatar."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Store the submitted ``avatar`` string on the account."""
        parser = reqparse.RequestParser()
        parser.add_argument("avatar", type=str, required=True, location="json")
        args = parser.parse_args()

        updated_account = AccountService.update_account(current_user, avatar=args["avatar"])

        return updated_account


class AccountInterfaceLanguageApi(Resource):
    """Change the current user's interface language."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set ``interface_language``, validated by ``supported_language``."""
        parser = reqparse.RequestParser()
        parser.add_argument("interface_language", type=supported_language, required=True, location="json")
        args = parser.parse_args()

        updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])

        return updated_account


class AccountInterfaceThemeApi(Resource):
    """Change the current user's interface theme."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set ``interface_theme`` to either ``"light"`` or ``"dark"``."""
        parser = reqparse.RequestParser()
        parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
        args = parser.parse_args()

        updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])

        return updated_account


class AccountTimezoneApi(Resource):
    """Change the current user's timezone."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set the account timezone to a pytz zone name.

        Raises:
            ValueError: the string is not a known pytz timezone.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("timezone", type=str, required=True, location="json")
        args = parser.parse_args()

        # Validate timezone string, e.g. America/New_York, Asia/Shanghai.
        # all_timezones_set has the same members as all_timezones but gives
        # O(1) membership instead of a linear scan of the list.
        if args["timezone"] not in pytz.all_timezones_set:
            raise ValueError("Invalid timezone string.")

        updated_account = AccountService.update_account(current_user, timezone=args["timezone"])

        return updated_account


class AccountPasswordApi(Resource):
    """Change the current user's password."""

    # No @marshal_with(account_fields) here: this endpoint returns a status
    # payload, not an account, and marshalling it against account_fields would
    # discard "result" and emit a dict of null account fields instead.
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Replace the password after checking the repeated new password.

        ``password`` (the current password) is optional at the parser level;
        verifying it is delegated to ``AccountService.update_account_password``.

        Raises:
            RepeatPasswordNotMatchError: the two new-password fields differ.
            CurrentPasswordIncorrectError: the service rejected the current password.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("password", type=str, required=False, location="json")
        parser.add_argument("new_password", type=str, required=True, location="json")
        parser.add_argument("repeat_new_password", type=str, required=True, location="json")
        args = parser.parse_args()

        if args["new_password"] != args["repeat_new_password"]:
            raise RepeatPasswordNotMatchError()

        try:
            AccountService.update_account_password(current_user, args["password"], args["new_password"])
        except ServiceCurrentPasswordIncorrectError:
            # Translate the service-layer error into the console API error.
            raise CurrentPasswordIncorrectError()

        return {"result": "success"}


class AccountIntegrateApi(Resource):
    """List the OAuth providers and whether the current user has bound each."""

    integrate_fields = {
        "provider": fields.String,
        "created_at": TimestampField,
        "is_bound": fields.Boolean,
        "link": fields.String,
    }

    integrate_list_fields = {
        "data": fields.List(fields.Nested(integrate_fields)),
    }

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(integrate_list_fields)
    def get(self):
        """Return one entry per supported provider (github, google).

        A bound provider reports its binding's ``created_at`` and no link; an
        unbound provider reports the console OAuth login URL as ``link``.
        """
        account = current_user

        account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all()

        # Index bindings by provider, keeping the first row per provider so the
        # lookup matches a first-match scan over the query result.
        integrates_by_provider = {}
        for ai in account_integrates:
            integrates_by_provider.setdefault(ai.provider, ai)

        base_url = request.url_root.rstrip("/")
        oauth_base_path = "/console/api/oauth/login"
        providers = ["github", "google"]

        integrate_data = []
        for provider in providers:
            existing_integrate = integrates_by_provider.get(provider)
            # NOTE: "id" is not declared in integrate_fields, so marshal_with
            # omits it from the response.
            if existing_integrate:
                integrate_data.append(
                    {
                        "id": existing_integrate.id,
                        "provider": provider,
                        "created_at": existing_integrate.created_at,
                        "is_bound": True,
                        "link": None,
                    }
                )
            else:
                integrate_data.append(
                    {
                        "id": None,
                        "provider": provider,
                        "created_at": None,
                        "is_bound": False,
                        "link": f"{base_url}{oauth_base_path}/{provider}",
                    }
                )

        return {"data": integrate_data}


# Register API resources
api.add_resource(AccountInitApi, "/account/init")
api.add_resource(AccountProfileApi, "/account/profile")
api.add_resource(AccountNameApi, "/account/name")
api.add_resource(AccountAvatarApi, "/account/avatar")
api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language")
api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme")
api.add_resource(AccountTimezoneApi, "/account/timezone")
api.add_resource(AccountPasswordApi, "/account/password")
api.add_resource(AccountIntegrateApi, "/account/integrates")
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')