# Mirror of https://github.com/langgenius/dify.git
# (synced 2024-11-16 11:42:29 +08:00; 274 lines, 9.1 KiB, Python)
# -*- coding:utf-8 -*-
|
|
from datetime import datetime
|
|
|
|
import pytz
|
|
from flask import current_app, request
|
|
from flask_login import current_user
|
|
from flask_restful import Resource, fields, marshal_with, reqparse
|
|
|
|
from constants.languages import supported_language
|
|
from controllers.console import api
|
|
from controllers.console.setup import setup_required
|
|
from controllers.console.workspace.error import (
|
|
AccountAlreadyInitedError,
|
|
CurrentPasswordIncorrectError,
|
|
InvalidInvitationCodeError,
|
|
RepeatPasswordNotMatchError,
|
|
)
|
|
from controllers.console.wraps import account_initialization_required
|
|
from extensions.ext_database import db
|
|
from libs.helper import TimestampField, timezone
|
|
from libs.login import login_required
|
|
from models.account import AccountIntegrate, InvitationCode
|
|
from services.account_service import AccountService
|
|
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
|
|
|
# Marshalling schema for the account payload returned by the profile/update
# endpoints in this module (used via @marshal_with(account_fields)).
account_fields = dict(
    id=fields.String,
    name=fields.String,
    avatar=fields.String,
    email=fields.String,
    is_password_set=fields.Boolean,
    interface_language=fields.String,
    interface_theme=fields.String,
    timezone=fields.String,
    last_login_at=TimestampField,
    last_login_ip=fields.String,
    created_at=TimestampField,
)
|
|
|
|
|
|
class AccountInitApi(Resource):
    """Complete first-time initialisation of the logged-in account.

    Stores the chosen interface language and timezone, sets the theme to
    ``light`` and marks the account ``active``. On the CLOUD edition an
    unused invitation code is additionally required and is marked as used
    by this account and its current tenant.
    """

    @setup_required
    @login_required
    def post(self):
        """Initialise the current account.

        Returns:
            ``{'result': 'success'}`` once the changes are committed.

        Raises:
            AccountAlreadyInitedError: the account is already ``active``.
            ValueError: CLOUD edition and no invitation code was sent.
            InvalidInvitationCodeError: the code is unknown or already used.
        """
        account = current_user

        # An account that is already active must not be initialised again.
        if account.status == 'active':
            raise AccountAlreadyInitedError()

        is_cloud_edition = current_app.config['EDITION'] == 'CLOUD'

        parser = reqparse.RequestParser()
        if is_cloud_edition:
            parser.add_argument('invitation_code', type=str, location='json')
        parser.add_argument('interface_language', type=supported_language, required=True, location='json')
        parser.add_argument('timezone', type=timezone, required=True, location='json')
        args = parser.parse_args()

        if is_cloud_edition:
            submitted_code = args['invitation_code']
            if not submitted_code:
                raise ValueError('invitation_code is required')

            # Only a code that is still 'unused' may be redeemed.
            invitation = db.session.query(InvitationCode).filter(
                InvitationCode.code == submitted_code,
                InvitationCode.status == 'unused',
            ).first()
            if not invitation:
                raise InvalidInvitationCodeError()

            # Consume the code on behalf of this account / tenant.
            invitation.status = 'used'
            invitation.used_at = datetime.utcnow()
            invitation.used_by_tenant_id = account.current_tenant_id
            invitation.used_by_account_id = account.id

        account.interface_language = args['interface_language']
        account.timezone = args['timezone']
        account.interface_theme = 'light'
        account.status = 'active'
        account.initialized_at = datetime.utcnow()

        # A single commit persists both the invitation and account changes.
        db.session.commit()

        return {'result': 'success'}
|
|
|
|
|
|
class AccountProfileApi(Resource):
    """Expose the profile of the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def get(self):
        """Return the current account, serialised through ``account_fields``."""
        profile = current_user
        return profile
|
|
|
|
|
|
class AccountNameApi(Resource):
    """Rename the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Set a new account name and return the updated account.

        Raises:
            ValueError: the name is shorter than 3 or longer than 30 characters.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('name', type=str, required=True, location='json')
        new_name = parser.parse_args()['name']

        # Accept only names whose length lies in the inclusive range [3, 30].
        if not 3 <= len(new_name) <= 30:
            raise ValueError(
                "Account name must be between 3 and 30 characters.")

        return AccountService.update_account(current_user, name=new_name)
|
|
|
|
|
|
class AccountAvatarApi(Resource):
    """Change the avatar of the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Store the submitted ``avatar`` string and return the updated account."""
        parser = reqparse.RequestParser()
        parser.add_argument('avatar', type=str, required=True, location='json')
        new_avatar = parser.parse_args()['avatar']

        return AccountService.update_account(current_user, avatar=new_avatar)
|
|
|
|
|
|
class AccountInterfaceLanguageApi(Resource):
    """Change the UI language of the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Save the submitted interface language and return the updated account.

        The value is validated by ``supported_language`` during parsing.
        """
        parser = reqparse.RequestParser()
        parser.add_argument(
            'interface_language', type=supported_language, required=True, location='json')
        chosen_language = parser.parse_args()['interface_language']

        return AccountService.update_account(current_user, interface_language=chosen_language)
|
|
|
|
|
|
class AccountInterfaceThemeApi(Resource):
    """Switch the UI theme of the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Save the chosen theme (``light`` or ``dark``) and return the updated account."""
        parser = reqparse.RequestParser()
        parser.add_argument('interface_theme', type=str, choices=['light', 'dark'],
                            required=True, location='json')
        chosen_theme = parser.parse_args()['interface_theme']

        return AccountService.update_account(current_user, interface_theme=chosen_theme)
|
|
|
|
|
|
class AccountTimezoneApi(Resource):
    """Update the timezone preference of the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Validate a timezone name (e.g. America/New_York) and save it.

        Returns:
            The updated account, serialised through ``account_fields``.

        Raises:
            ValueError: the submitted name is not a timezone known to pytz.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('timezone', type=str,
                            required=True, location='json')
        args = parser.parse_args()

        # pytz.all_timezones is list-like, so ``in`` scans every entry;
        # pytz.all_timezones_set holds the same names with O(1) membership.
        if args['timezone'] not in pytz.all_timezones_set:
            raise ValueError("Invalid timezone string.")

        return AccountService.update_account(current_user, timezone=args['timezone'])
|
|
|
|
|
|
class AccountPasswordApi(Resource):
    """Change the password of the logged-in account."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Replace the account password.

        The handler is intentionally NOT decorated with
        ``@marshal_with(account_fields)``. That schema has no ``result`` key,
        so marshalling would drop the payload below and answer with an
        account-shaped object whose fields are all null. The plain dict is
        returned as-is, matching ``AccountInitApi``.

        Returns:
            ``{"result": "success"}`` when the password was changed.

        Raises:
            RepeatPasswordNotMatchError: the two new-password fields differ.
            CurrentPasswordIncorrectError: the service rejected the current
                password.
        """
        parser = reqparse.RequestParser()
        # The current password is optional here; the service decides whether
        # it is needed.
        parser.add_argument('password', type=str,
                            required=False, location='json')
        parser.add_argument('new_password', type=str,
                            required=True, location='json')
        parser.add_argument('repeat_new_password', type=str,
                            required=True, location='json')
        args = parser.parse_args()

        if args['new_password'] != args['repeat_new_password']:
            raise RepeatPasswordNotMatchError()

        try:
            AccountService.update_account_password(
                current_user, args['password'], args['new_password'])
        except ServiceCurrentPasswordIncorrectError:
            # Translate the service-layer error into the console API error type.
            raise CurrentPasswordIncorrectError()

        return {"result": "success"}
|
|
|
|
|
|
class AccountIntegrateApi(Resource):
    """List the OAuth providers (GitHub, Google) and whether the account is bound to each."""

    # Serialisation schema for one provider entry.
    # NOTE(review): get() also builds an 'id' key that is not declared here,
    # so it is likely omitted from the marshalled response. Confirm this is
    # intended.
    integrate_fields = {
        'provider': fields.String,
        'created_at': TimestampField,
        'is_bound': fields.Boolean,
        'link': fields.String
    }

    integrate_list_fields = {
        'data': fields.List(fields.Nested(integrate_fields)),
    }

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(integrate_list_fields)
    def get(self):
        """Return one entry per supported provider.

        Bound providers carry the binding's id and creation time. Unbound
        providers carry an OAuth login ``link`` built from the request root
        URL.
        """
        account = current_user

        account_integrates = db.session.query(AccountIntegrate).filter(
            AccountIntegrate.account_id == account.id).all()

        root_url = request.url_root.rstrip('/')
        oauth_login_path = "/console/api/oauth/login"

        # Index bindings by provider, keeping the first one seen per provider
        # (same result as a first-match linear search).
        bound_by_provider = {}
        for integrate in account_integrates:
            bound_by_provider.setdefault(integrate.provider, integrate)

        integrate_data = []
        for provider in ("github", "google"):
            binding = bound_by_provider.get(provider)
            if binding:
                entry = {
                    'id': binding.id,
                    'provider': provider,
                    'created_at': binding.created_at,
                    'is_bound': True,
                    'link': None
                }
            else:
                entry = {
                    'id': None,
                    'provider': provider,
                    'created_at': None,
                    'is_bound': False,
                    'link': f'{root_url}{oauth_login_path}/{provider}'
                }
            integrate_data.append(entry)

        return {'data': integrate_data}
|
|
|
|
|
|
# Register every account resource with the console API, paired with its URL path.
for _resource, _path in (
    (AccountInitApi, '/account/init'),
    (AccountProfileApi, '/account/profile'),
    (AccountNameApi, '/account/name'),
    (AccountAvatarApi, '/account/avatar'),
    (AccountInterfaceLanguageApi, '/account/interface-language'),
    (AccountInterfaceThemeApi, '/account/interface-theme'),
    (AccountTimezoneApi, '/account/timezone'),
    (AccountPasswordApi, '/account/password'),
    (AccountIntegrateApi, '/account/integrates'),
):
    api.add_resource(_resource, _path)
# Drop the loop variables so they do not linger as module attributes.
del _resource, _path

# Email-change endpoints are currently disabled (their resources are not
# registered):
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')
|