Feat/new login (#8120)

Co-authored-by: douxc <douxc512@gmail.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
This commit is contained in:
Joe 2024-10-21 10:03:40 +08:00 committed by GitHub
parent 2c0eaaec3d
commit 4fd2743efa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 1027 additions and 292 deletions

View File

@ -326,4 +326,4 @@ POSITION_TOOL_EXCLUDES=
POSITION_PROVIDER_PINS= POSITION_PROVIDER_PINS=
POSITION_PROVIDER_INCLUDES= POSITION_PROVIDER_INCLUDES=
POSITION_PROVIDER_EXCLUDES= POSITION_PROVIDER_EXCLUDES=

View File

@ -1,6 +1,15 @@
from typing import Annotated, Optional from typing import Annotated, Literal, Optional
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field from pydantic import (
AliasChoices,
Field,
HttpUrl,
NegativeInt,
NonNegativeInt,
PositiveFloat,
PositiveInt,
computed_field,
)
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
from configs.feature.hosted_service import HostedServiceConfig from configs.feature.hosted_service import HostedServiceConfig
@ -473,6 +482,11 @@ class MailConfig(BaseSettings):
default=False, default=False,
) )
EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field(
description="Maximum number of emails allowed to be sent from the same IP address in a minute",
default=50,
)
class RagEtlConfig(BaseSettings): class RagEtlConfig(BaseSettings):
""" """
@ -614,6 +628,33 @@ class PositionConfig(BaseSettings):
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
class LoginConfig(BaseSettings):
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
description="whether to enable email code login",
default=False,
)
ENABLE_EMAIL_PASSWORD_LOGIN: bool = Field(
description="whether to enable email password login",
default=True,
)
ENABLE_SOCIAL_OAUTH_LOGIN: bool = Field(
description="whether to enable github/google oauth login",
default=False,
)
EMAIL_CODE_LOGIN_TOKEN_EXPIRY_HOURS: PositiveFloat = Field(
description="expiry time in hours for email code login token",
default=1 / 12,
)
ALLOW_REGISTER: bool = Field(
description="whether to enable register",
default=False,
)
ALLOW_CREATE_WORKSPACE: bool = Field(
description="whether to enable create workspace",
default=False,
)
class FeatureConfig( class FeatureConfig(
# place the configs in alphabet order # place the configs in alphabet order
AppExecutionConfig, AppExecutionConfig,
@ -639,6 +680,7 @@ class FeatureConfig(
UpdateConfig, UpdateConfig,
WorkflowConfig, WorkflowConfig,
WorkspaceConfig, WorkspaceConfig,
LoginConfig,
# hosted services config # hosted services config
HostedServiceConfig, HostedServiceConfig,
CeleryBeatConfig, CeleryBeatConfig,

View File

@ -1,17 +1,15 @@
import base64
import datetime import datetime
import secrets
from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import api from controllers.console import api
from controllers.console.error import AlreadyActivateError from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import StrLen, email, timezone from libs.helper import StrLen, email, extract_remote_ip, timezone
from libs.password import hash_password, valid_password from models.account import AccountStatus, Tenant
from models.account import AccountStatus from services.account_service import AccountService, RegisterService
from services.account_service import RegisterService
class ActivateCheckApi(Resource): class ActivateCheckApi(Resource):
@ -27,8 +25,18 @@ class ActivateCheckApi(Resource):
token = args["token"] token = args["token"]
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
if invitation:
return {"is_valid": invitation is not None, "workspace_name": invitation["tenant"].name if invitation else None} data = invitation.get("data", {})
tenant: Tenant = invitation.get("tenant", None)
workspace_name = tenant.name if tenant else None
workspace_id = tenant.id if tenant else None
invitee_email = data.get("email") if data else None
return {
"is_valid": invitation is not None,
"data": {"workspace_name": workspace_name, "workspace_id": workspace_id, "email": invitee_email},
}
else:
return {"is_valid": False}
class ActivateApi(Resource): class ActivateApi(Resource):
@ -38,7 +46,6 @@ class ActivateApi(Resource):
parser.add_argument("email", type=email, required=False, nullable=True, location="json") parser.add_argument("email", type=email, required=False, nullable=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json") parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument( parser.add_argument(
"interface_language", type=supported_language, required=True, nullable=False, location="json" "interface_language", type=supported_language, required=True, nullable=False, location="json"
) )
@ -54,15 +61,6 @@ class ActivateApi(Resource):
account = invitation["account"] account = invitation["account"]
account.name = args["name"] account.name = args["name"]
# generate password salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
# encrypt password with salt
password_hashed = hash_password(args["password"], salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
account.interface_language = args["interface_language"] account.interface_language = args["interface_language"]
account.timezone = args["timezone"] account.timezone = args["timezone"]
account.interface_theme = "light" account.interface_theme = "light"
@ -70,7 +68,9 @@ class ActivateApi(Resource):
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.commit() db.session.commit()
return {"result": "success"} token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
return {"result": "success", "data": token_pair.model_dump()}
api.add_resource(ActivateCheckApi, "/activate/check") api.add_resource(ActivateCheckApi, "/activate/check")

View File

@ -27,5 +27,29 @@ class InvalidTokenError(BaseHTTPException):
class PasswordResetRateLimitExceededError(BaseHTTPException): class PasswordResetRateLimitExceededError(BaseHTTPException):
error_code = "password_reset_rate_limit_exceeded" error_code = "password_reset_rate_limit_exceeded"
description = "Password reset rate limit exceeded. Try again later." description = "Too many password reset emails have been sent. Please try again in 1 minutes."
code = 429
class EmailCodeError(BaseHTTPException):
error_code = "email_code_error"
description = "Email code is invalid or expired."
code = 400
class EmailOrPasswordMismatchError(BaseHTTPException):
error_code = "email_or_password_mismatch"
description = "The email or password is mismatched."
code = 400
class EmailPasswordLoginLimitError(BaseHTTPException):
error_code = "email_code_login_limit"
description = "Too many incorrect password attempts. Please try again later."
code = 429
class EmailCodeLoginRateLimitExceededError(BaseHTTPException):
error_code = "email_code_login_rate_limit_exceeded"
description = "Too many login emails have been sent. Please try again in 5 minutes."
code = 429 code = 429

View File

@ -1,65 +1,82 @@
import base64 import base64
import logging
import secrets import secrets
from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from constants.languages import languages
from controllers.console import api from controllers.console import api
from controllers.console.auth.error import ( from controllers.console.auth.error import (
EmailCodeError,
InvalidEmailError, InvalidEmailError,
InvalidTokenError, InvalidTokenError,
PasswordMismatchError, PasswordMismatchError,
PasswordResetRateLimitExceededError,
) )
from controllers.console.error import EmailSendIpLimitError, NotAllowedRegister
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email as email_validate from libs.helper import email, extract_remote_ip
from libs.password import hash_password, valid_password from libs.password import hash_password, valid_password
from models.account import Account from models.account import Account
from services.account_service import AccountService from services.account_service import AccountService, TenantService
from services.errors.account import RateLimitExceededError from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.feature_service import FeatureService
class ForgotPasswordSendEmailApi(Resource): class ForgotPasswordSendEmailApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json") parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
email = args["email"] ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
if not email_validate(email): if args["language"] is not None and args["language"] == "zh-Hans":
raise InvalidEmailError() language = "zh-Hans"
account = Account.query.filter_by(email=email).first()
if account:
try:
AccountService.send_reset_password_email(account=account)
except RateLimitExceededError:
logging.warning(f"Rate limit exceeded for email: {account.email}")
raise PasswordResetRateLimitExceededError()
else: else:
# Return success to avoid revealing email registration status language = "en-US"
logging.warning(f"Attempt to reset password for unregistered email: {email}")
return {"result": "success"} account = Account.query.filter_by(email=args["email"]).first()
token = None
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_reset_password_email(email=args["email"], language=language)
return {"result": "fail", "data": token, "code": "account_not_found"}
else:
raise NotAllowedRegister()
else:
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
return {"result": "success", "data": token}
class ForgotPasswordCheckApi(Resource): class ForgotPasswordCheckApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
token = args["token"]
reset_data = AccountService.get_reset_password_data(token) user_email = args["email"]
if reset_data is None: token_data = AccountService.get_reset_password_data(args["token"])
return {"is_valid": False, "email": None} if token_data is None:
return {"is_valid": True, "email": reset_data.get("email")} raise InvalidTokenError()
if user_email != token_data.get("email"):
raise InvalidEmailError()
if args["code"] != token_data.get("code"):
raise EmailCodeError()
return {"is_valid": True, "email": token_data.get("email")}
class ForgotPasswordResetApi(Resource): class ForgotPasswordResetApi(Resource):
@ -92,9 +109,26 @@ class ForgotPasswordResetApi(Resource):
base64_password_hashed = base64.b64encode(password_hashed).decode() base64_password_hashed = base64.b64encode(password_hashed).decode()
account = Account.query.filter_by(email=reset_data.get("email")).first() account = Account.query.filter_by(email=reset_data.get("email")).first()
account.password = base64_password_hashed if account:
account.password_salt = base64_salt account.password = base64_password_hashed
db.session.commit() account.password_salt = base64_salt
db.session.commit()
tenant = TenantService.get_join_tenants(account)
if not tenant and not FeatureService.get_system_features().is_allow_create_workspace:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
else:
try:
account = AccountService.create_account_and_tenant(
email=reset_data.get("email"),
name=reset_data.get("email"),
password=password_confirm,
interface_language=languages[0],
)
except WorkSpaceNotAllowedCreateError:
pass
return {"result": "success"} return {"result": "success"}

View File

@ -1,16 +1,34 @@
from typing import cast from typing import cast
import flask_login import flask_login
from flask import request from flask import redirect, request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
import services import services
from configs import dify_config
from constants.languages import languages
from controllers.console import api from controllers.console import api
from controllers.console.auth.error import (
EmailCodeError,
EmailOrPasswordMismatchError,
EmailPasswordLoginLimitError,
InvalidEmailError,
InvalidTokenError,
)
from controllers.console.error import (
AccountBannedError,
EmailSendIpLimitError,
NotAllowedCreateWorkspace,
NotAllowedRegister,
)
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip from libs.helper import email, extract_remote_ip
from libs.password import valid_password from libs.password import valid_password
from models.account import Account from models.account import Account
from services.account_service import AccountService, TenantService from services.account_service import AccountService, RegisterService, TenantService
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.feature_service import FeatureService
class LoginApi(Resource): class LoginApi(Resource):
@ -23,15 +41,43 @@ class LoginApi(Resource):
parser.add_argument("email", type=email, required=True, location="json") parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json") parser.add_argument("password", type=valid_password, required=True, location="json")
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
parser.add_argument("invite_token", type=str, required=False, default=None, location="json")
parser.add_argument("language", type=str, required=False, default="en-US", location="json")
args = parser.parse_args() args = parser.parse_args()
# todo: Verify the recaptcha is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"])
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
invitation = args["invite_token"]
if invitation:
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try: try:
account = AccountService.authenticate(args["email"], args["password"]) if invitation:
except services.errors.account.AccountLoginError as e: data = invitation.get("data", {})
return {"code": "unauthorized", "message": str(e)}, 401 invitee_email = data.get("email") if data else None
if invitee_email != args["email"]:
raise InvalidEmailError()
account = AccountService.authenticate(args["email"], args["password"], args["invite_token"])
else:
account = AccountService.authenticate(args["email"], args["password"])
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
AccountService.add_login_error_rate_limit(args["email"])
raise EmailOrPasswordMismatchError()
except services.errors.account.AccountNotFoundError:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_reset_password_email(email=args["email"], language=language)
return {"result": "fail", "data": token, "code": "account_not_found"}
else:
raise NotAllowedRegister()
# SELF_HOSTED only have one workspace # SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account) tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0: if len(tenants) == 0:
@ -41,7 +87,7 @@ class LoginApi(Resource):
} }
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()} return {"result": "success", "data": token_pair.model_dump()}
@ -49,60 +95,114 @@ class LogoutApi(Resource):
@setup_required @setup_required
def get(self): def get(self):
account = cast(Account, flask_login.current_user) account = cast(Account, flask_login.current_user)
if isinstance(account, flask_login.AnonymousUserMixin):
return {"result": "success"}
AccountService.logout(account=account) AccountService.logout(account=account)
flask_login.logout_user() flask_login.logout_user()
return {"result": "success"} return {"result": "success"}
class ResetPasswordApi(Resource): class ResetPasswordSendEmailApi(Resource):
@setup_required @setup_required
def get(self): def post(self):
# parser = reqparse.RequestParser() parser = reqparse.RequestParser()
# parser.add_argument('email', type=email, required=True, location='json') parser.add_argument("email", type=email, required=True, location="json")
# args = parser.parse_args() parser.add_argument("language", type=str, required=False, location="json")
args = parser.parse_args()
# import mailchimp_transactional as MailchimpTransactional if args["language"] is not None and args["language"] == "zh-Hans":
# from mailchimp_transactional.api_client import ApiClientError language = "zh-Hans"
else:
language = "en-US"
# account = {'email': args['email']} account = AccountService.get_user_through_email(args["email"])
# account = AccountService.get_by_email(args['email']) if account is None:
# if account is None: if FeatureService.get_system_features().is_allow_register:
# raise ValueError('Email not found') token = AccountService.send_reset_password_email(email=args["email"], language=language)
# new_password = AccountService.generate_password() else:
# AccountService.update_password(account, new_password) raise NotAllowedRegister()
else:
token = AccountService.send_reset_password_email(account=account, language=language)
# todo: Send email return {"result": "success", "data": token}
# MAILCHIMP_API_KEY = dify_config.MAILCHIMP_TRANSACTIONAL_API_KEY
# mailchimp = MailchimpTransactional(MAILCHIMP_API_KEY)
# message = {
# 'from_email': 'noreply@example.com',
# 'to': [{'email': account['email']}],
# 'subject': 'Reset your Dify password',
# 'html': """
# <p>Dear User,</p>
# <p>The Dify team has generated a new password for you, details as follows:</p>
# <p><strong>{new_password}</strong></p>
# <p>Please change your password to log in as soon as possible.</p>
# <p>Regards,</p>
# <p>The Dify Team</p>
# """
# }
# response = mailchimp.messages.send({ class EmailCodeLoginSendEmailApi(Resource):
# 'message': message, @setup_required
# # required for transactional email def post(self):
# ' settings': { parser = reqparse.RequestParser()
# 'sandbox_mode': dify_config.MAILCHIMP_SANDBOX_MODE, parser.add_argument("email", type=email, required=True, location="json")
# }, parser.add_argument("language", type=str, required=False, location="json")
# }) args = parser.parse_args()
# Check if MSG was sent ip_address = extract_remote_ip(request)
# if response.status_code != 200: if AccountService.is_email_send_ip_limit(ip_address):
# # handle error raise EmailSendIpLimitError()
# pass
return {"result": "success"} if args["language"] is not None and args["language"] == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
account = AccountService.get_user_through_email(args["email"])
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_email_code_login_email(email=args["email"], language=language)
else:
raise NotAllowedRegister()
else:
token = AccountService.send_email_code_login_email(account=account, language=language)
return {"result": "success", "data": token}
class EmailCodeLoginApi(Resource):
@setup_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("code", type=str, required=True, location="json")
parser.add_argument("token", type=str, required=True, location="json")
args = parser.parse_args()
user_email = args["email"]
token_data = AccountService.get_email_code_login_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args["email"]:
raise InvalidEmailError()
if token_data["code"] != args["code"]:
raise EmailCodeError()
AccountService.revoke_email_code_login_token(args["token"])
account = AccountService.get_user_through_email(user_email)
if account:
tenant = TenantService.get_join_tenants(account)
if not tenant:
if not FeatureService.get_system_features().is_allow_create_workspace:
raise NotAllowedCreateWorkspace()
else:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
if account is None:
try:
account = AccountService.create_account_and_tenant(
email=user_email, name=user_email, interface_language=languages[0]
)
except WorkSpaceNotAllowedCreateError:
return redirect(
f"{dify_config.CONSOLE_WEB_URL}/signin"
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
)
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}
class RefreshTokenApi(Resource): class RefreshTokenApi(Resource):
@ -120,4 +220,7 @@ class RefreshTokenApi(Resource):
api.add_resource(LoginApi, "/login") api.add_resource(LoginApi, "/login")
api.add_resource(LogoutApi, "/logout") api.add_resource(LogoutApi, "/logout")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
api.add_resource(ResetPasswordSendEmailApi, "/reset-password")
api.add_resource(RefreshTokenApi, "/refresh-token") api.add_resource(RefreshTokenApi, "/refresh-token")

View File

@ -5,14 +5,19 @@ from typing import Optional
import requests import requests
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_restful import Resource from flask_restful import Resource
from werkzeug.exceptions import Unauthorized
from configs import dify_config from configs import dify_config
from constants.languages import languages from constants.languages import languages
from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import extract_remote_ip from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models.account import Account, AccountStatus from models.account import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountNotFoundError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
from services.feature_service import FeatureService
from .. import api from .. import api
@ -42,6 +47,7 @@ def get_oauth_providers():
class OAuthLogin(Resource): class OAuthLogin(Resource):
def get(self, provider: str): def get(self, provider: str):
invite_token = request.args.get("invite_token") or None
OAUTH_PROVIDERS = get_oauth_providers() OAUTH_PROVIDERS = get_oauth_providers()
with current_app.app_context(): with current_app.app_context():
oauth_provider = OAUTH_PROVIDERS.get(provider) oauth_provider = OAUTH_PROVIDERS.get(provider)
@ -49,7 +55,7 @@ class OAuthLogin(Resource):
if not oauth_provider: if not oauth_provider:
return {"error": "Invalid provider"}, 400 return {"error": "Invalid provider"}, 400
auth_url = oauth_provider.get_authorization_url() auth_url = oauth_provider.get_authorization_url(invite_token=invite_token)
return redirect(auth_url) return redirect(auth_url)
@ -62,6 +68,11 @@ class OAuthCallback(Resource):
return {"error": "Invalid provider"}, 400 return {"error": "Invalid provider"}, 400
code = request.args.get("code") code = request.args.get("code")
state = request.args.get("state")
invite_token = None
if state:
invite_token = state
try: try:
token = oauth_provider.get_access_token(code) token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token) user_info = oauth_provider.get_user_info(token)
@ -69,7 +80,27 @@ class OAuthCallback(Resource):
logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
return {"error": "OAuth process failed"}, 400 return {"error": "OAuth process failed"}, 400
account = _generate_account(provider, user_info) if invite_token and RegisterService.is_valid_invite_token(invite_token):
invitation = RegisterService._get_invitation_by_token(token=invite_token)
if invitation:
invitation_email = invitation.get("email", None)
if invitation_email != user_info.email:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
try:
account = _generate_account(provider, user_info)
except AccountNotFoundError:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
except WorkSpaceNotFoundError:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
except WorkSpaceNotAllowedCreateError:
return redirect(
f"{dify_config.CONSOLE_WEB_URL}/signin"
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
)
# Check account status # Check account status
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
return {"error": "Account is banned or closed."}, 403 return {"error": "Account is banned or closed."}, 403
@ -79,7 +110,15 @@ class OAuthCallback(Resource):
account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit() db.session.commit()
TenantService.create_owner_tenant_if_not_exist(account) try:
TenantService.create_owner_tenant_if_not_exist(account)
except Unauthorized:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
except WorkSpaceNotAllowedCreateError:
return redirect(
f"{dify_config.CONSOLE_WEB_URL}/signin"
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
)
token_pair = AccountService.login( token_pair = AccountService.login(
account=account, account=account,
@ -104,8 +143,20 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
# Get account by openid or email. # Get account by openid or email.
account = _get_account_by_openid_or_email(provider, user_info) account = _get_account_by_openid_or_email(provider, user_info)
if account:
tenant = TenantService.get_join_tenants(account)
if not tenant:
if not FeatureService.get_system_features().is_allow_create_workspace:
raise WorkSpaceNotAllowedCreateError()
else:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
if not account: if not account:
# Create account if not FeatureService.get_system_features().is_allow_register:
raise AccountNotFoundError()
account_name = user_info.name or "Dify" account_name = user_info.name or "Dify"
account = RegisterService.register( account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider

View File

@ -38,3 +38,27 @@ class AlreadyActivateError(BaseHTTPException):
error_code = "already_activate" error_code = "already_activate"
description = "Auth Token is invalid or account already activated, please check again." description = "Auth Token is invalid or account already activated, please check again."
code = 403 code = 403
class NotAllowedCreateWorkspace(BaseHTTPException):
error_code = "unauthorized"
description = "Workspace not found, please contact system admin to invite you to join in a workspace."
code = 400
class AccountBannedError(BaseHTTPException):
error_code = "account_banned"
description = "Account is banned."
code = 400
class NotAllowedRegister(BaseHTTPException):
error_code = "unauthorized"
description = "Account not found."
code = 400
class EmailSendIpLimitError(BaseHTTPException):
error_code = "email_send_ip_limit"
description = "Too many emails have been sent from this IP address recently. Please try again later."
code = 429

View File

@ -189,23 +189,39 @@ def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Resp
class TokenManager: class TokenManager:
@classmethod @classmethod
def generate_token(cls, account: Account, token_type: str, additional_data: Optional[dict] = None) -> str: def generate_token(
old_token = cls._get_current_token_for_account(account.id, token_type) cls,
if old_token: token_type: str,
if isinstance(old_token, bytes): account: Optional[Account] = None,
old_token = old_token.decode("utf-8") email: Optional[str] = None,
cls.revoke_token(old_token, token_type) additional_data: Optional[dict] = None,
) -> str:
if account is None and email is None:
raise ValueError("Account or email must be provided")
account_id = account.id if account else None
account_email = account.email if account else email
if account_id:
old_token = cls._get_current_token_for_account(account_id, token_type)
if old_token:
if isinstance(old_token, bytes):
old_token = old_token.decode("utf-8")
cls.revoke_token(old_token, token_type)
token = str(uuid.uuid4()) token = str(uuid.uuid4())
token_data = {"account_id": account.id, "email": account.email, "token_type": token_type} token_data = {"account_id": account_id, "email": account_email, "token_type": token_type}
if additional_data: if additional_data:
token_data.update(additional_data) token_data.update(additional_data)
expiry_hours = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_HOURS"] expiry_hours = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_HOURS"]
token_key = cls._get_token_key(token, token_type) token_key = cls._get_token_key(token, token_type)
redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data)) expiry_time = int(expiry_hours * 60 * 60)
redis_client.setex(token_key, expiry_time, json.dumps(token_data))
if account_id:
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
return token return token
@classmethod @classmethod
@ -234,9 +250,12 @@ class TokenManager:
return current_token return current_token
@classmethod @classmethod
def _set_current_token_for_account(cls, account_id: str, token: str, token_type: str, expiry_hours: int): def _set_current_token_for_account(
cls, account_id: str, token: str, token_type: str, expiry_hours: Union[int, float]
):
key = cls._get_account_token_key(account_id, token_type) key = cls._get_account_token_key(account_id, token_type)
redis_client.setex(key, expiry_hours * 60 * 60, token) expiry_time = int(expiry_hours * 60 * 60)
redis_client.setex(key, expiry_time, token)
@classmethod @classmethod
def _get_account_token_key(cls, account_id: str, token_type: str) -> str: def _get_account_token_key(cls, account_id: str, token_type: str) -> str:

View File

@ -1,5 +1,6 @@
import urllib.parse import urllib.parse
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
import requests import requests
@ -40,12 +41,14 @@ class GitHubOAuth(OAuth):
_USER_INFO_URL = "https://api.github.com/user" _USER_INFO_URL = "https://api.github.com/user"
_EMAIL_INFO_URL = "https://api.github.com/user/emails" _EMAIL_INFO_URL = "https://api.github.com/user/emails"
def get_authorization_url(self): def get_authorization_url(self, invite_token: Optional[str] = None):
params = { params = {
"client_id": self.client_id, "client_id": self.client_id,
"redirect_uri": self.redirect_uri, "redirect_uri": self.redirect_uri,
"scope": "user:email", # Request only basic user information "scope": "user:email", # Request only basic user information
} }
if invite_token:
params["state"] = invite_token
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str): def get_access_token(self, code: str):
@ -90,13 +93,15 @@ class GoogleOAuth(OAuth):
_TOKEN_URL = "https://oauth2.googleapis.com/token" _TOKEN_URL = "https://oauth2.googleapis.com/token"
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
def get_authorization_url(self): def get_authorization_url(self, invite_token: Optional[str] = None):
params = { params = {
"client_id": self.client_id, "client_id": self.client_id,
"response_type": "code", "response_type": "code",
"redirect_uri": self.redirect_uri, "redirect_uri": self.redirect_uri,
"scope": "openid email", "scope": "openid email",
} }
if invite_token:
params["state"] = invite_token
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str): def get_access_token(self, code: str):

View File

@ -13,7 +13,7 @@ def valid_password(password):
if re.match(pattern, password) is not None: if re.match(pattern, password) is not None:
return password return password
raise ValueError("Not a valid password.") raise ValueError("Password must contain letters and numbers, and the length must be greater than 8.")
def hash_password(password_str, salt_byte): def hash_password(password_str, salt_byte):

View File

@ -1,6 +1,7 @@
import base64 import base64
import json import json
import logging import logging
import random
import secrets import secrets
import uuid import uuid
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
@ -34,7 +35,9 @@ from models.model import DifySetup
from services.errors.account import ( from services.errors.account import (
AccountAlreadyInTenantError, AccountAlreadyInTenantError,
AccountLoginError, AccountLoginError,
AccountNotFoundError,
AccountNotLinkTenantError, AccountNotLinkTenantError,
AccountPasswordError,
AccountRegisterError, AccountRegisterError,
CannotOperateSelfError, CannotOperateSelfError,
CurrentPasswordIncorrectError, CurrentPasswordIncorrectError,
@ -42,10 +45,12 @@ from services.errors.account import (
LinkAccountIntegrateError, LinkAccountIntegrateError,
MemberNotInTenantError, MemberNotInTenantError,
NoPermissionError, NoPermissionError,
RateLimitExceededError,
RoleAlreadyAssignedError, RoleAlreadyAssignedError,
TenantNotFoundError, TenantNotFoundError,
) )
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.feature_service import FeatureService
from tasks.mail_email_code_login import send_email_code_login_mail_task
from tasks.mail_invite_member_task import send_invite_member_mail_task from tasks.mail_invite_member_task import send_invite_member_mail_task
from tasks.mail_reset_password_task import send_reset_password_mail_task from tasks.mail_reset_password_task import send_reset_password_mail_task
@ -61,7 +66,11 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=30)
class AccountService: class AccountService:
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60) reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
email_code_login_rate_limiter = RateLimiter(
prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
)
LOGIN_MAX_ERROR_LIMITS = 5
@staticmethod @staticmethod
def _get_refresh_token_key(refresh_token: str) -> str: def _get_refresh_token_key(refresh_token: str) -> str:
@ -127,23 +136,34 @@ class AccountService:
return token return token
@staticmethod @staticmethod
def authenticate(email: str, password: str) -> Account: def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account:
"""authenticate account with email and password""" """authenticate account with email and password"""
account = Account.query.filter_by(email=email).first() account = Account.query.filter_by(email=email).first()
if not account: if not account:
raise AccountLoginError("Invalid email or password.") raise AccountNotFoundError()
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
raise AccountLoginError("Account is banned or closed.") raise AccountLoginError("Account is banned or closed.")
if password and invite_token and account.password is None:
# if invite_token is valid, set password and password_salt
salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode()
password_hashed = hash_password(password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
if account.password is None or not compare_password(password, account.password, account.password_salt):
raise AccountPasswordError("Invalid email or password.")
if account.status == AccountStatus.PENDING.value: if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value account.status = AccountStatus.ACTIVE.value
account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
if account.password is None or not compare_password(password, account.password, account.password_salt): db.session.commit()
raise AccountLoginError("Invalid email or password.")
return account return account
@staticmethod @staticmethod
@ -169,9 +189,18 @@ class AccountService:
@staticmethod @staticmethod
def create_account( def create_account(
email: str, name: str, interface_language: str, password: Optional[str] = None, interface_theme: str = "light" email: str,
name: str,
interface_language: str,
password: Optional[str] = None,
interface_theme: str = "light",
is_setup: Optional[bool] = False,
) -> Account: ) -> Account:
"""create account""" """create account"""
if not FeatureService.get_system_features().is_allow_register and not is_setup:
from controllers.console.error import NotAllowedRegister
raise NotAllowedRegister()
account = Account() account = Account()
account.email = email account.email = email
account.name = name account.name = name
@ -198,6 +227,19 @@ class AccountService:
db.session.commit() db.session.commit()
return account return account
@staticmethod
def create_account_and_tenant(
email: str, name: str, interface_language: str, password: Optional[str] = None
) -> Account:
"""create account"""
account = AccountService.create_account(
email=email, name=name, interface_language=interface_language, password=password
)
TenantService.create_owner_tenant_if_not_exist(account=account)
return account
@staticmethod @staticmethod
def link_account_integrate(provider: str, open_id: str, account: Account) -> None: def link_account_integrate(provider: str, open_id: str, account: Account) -> None:
"""Link account integrate""" """Link account integrate"""
@ -256,6 +298,10 @@ class AccountService:
if ip_address: if ip_address:
AccountService.update_login_info(account=account, ip_address=ip_address) AccountService.update_login_info(account=account, ip_address=ip_address)
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
db.session.commit()
access_token = AccountService.get_account_jwt_token(account=account) access_token = AccountService.get_account_jwt_token(account=account)
refresh_token = _generate_refresh_token() refresh_token = _generate_refresh_token()
@ -294,13 +340,29 @@ class AccountService:
return AccountService.load_user(account_id) return AccountService.load_user(account_id)
@classmethod @classmethod
def send_reset_password_email(cls, account): def send_reset_password_email(
if cls.reset_password_rate_limiter.is_rate_limited(account.email): cls,
raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. Please try again later.") account: Optional[Account] = None,
email: Optional[str] = None,
language: Optional[str] = "en-US",
):
account_email = account.email if account else email
token = TokenManager.generate_token(account, "reset_password") if cls.reset_password_rate_limiter.is_rate_limited(account_email):
send_reset_password_mail_task.delay(language=account.interface_language, to=account.email, token=token) from controllers.console.auth.error import PasswordResetRateLimitExceededError
cls.reset_password_rate_limiter.increment_rate_limit(account.email)
raise PasswordResetRateLimitExceededError()
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
token = TokenManager.generate_token(
account=account, email=email, token_type="reset_password", additional_data={"code": code}
)
send_reset_password_mail_task.delay(
language=language,
to=account_email,
code=code,
)
cls.reset_password_rate_limiter.increment_rate_limit(account_email)
return token return token
@classmethod @classmethod
@ -311,11 +373,125 @@ class AccountService:
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "reset_password") return TokenManager.get_token_data(token, "reset_password")
@classmethod
def send_email_code_login_email(
cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
):
if cls.email_code_login_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import EmailCodeLoginRateLimitExceededError
raise EmailCodeLoginRateLimitExceededError()
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
token = TokenManager.generate_token(
account=account, email=email, token_type="email_code_login", additional_data={"code": code}
)
send_email_code_login_mail_task.delay(
language=language,
to=account.email if account else email,
code=code,
)
cls.email_code_login_rate_limiter.increment_rate_limit(email)
return token
@classmethod
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "email_code_login")
@classmethod
def revoke_email_code_login_token(cls, token: str):
TokenManager.revoke_token(token, "email_code_login")
@classmethod
def get_user_through_email(cls, email: str):
account = db.session.query(Account).filter(Account.email == email).first()
if not account:
return None
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
raise Unauthorized("Account is banned or closed.")
return account
@staticmethod
def add_login_error_rate_limit(email: str) -> None:
key = f"login_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
count = 0
count = int(count) + 1
redis_client.setex(key, 60 * 60 * 24, count)
@staticmethod
def is_login_error_rate_limit(email: str) -> bool:
key = f"login_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
return False
count = int(count)
if count > AccountService.LOGIN_MAX_ERROR_LIMITS:
return True
return False
@staticmethod
def reset_login_error_rate_limit(email: str):
key = f"login_error_rate_limit:{email}"
redis_client.delete(key)
@staticmethod
def is_email_send_ip_limit(ip_address: str):
minute_key = f"email_send_ip_limit_minute:{ip_address}"
freeze_key = f"email_send_ip_limit_freeze:{ip_address}"
hour_limit_key = f"email_send_ip_limit_hour:{ip_address}"
# check ip is frozen
if redis_client.get(freeze_key):
return True
# check current minute count
current_minute_count = redis_client.get(minute_key)
if current_minute_count is None:
current_minute_count = 0
current_minute_count = int(current_minute_count)
# check current hour count
if current_minute_count > dify_config.EMAIL_SEND_IP_LIMIT_PER_MINUTE:
hour_limit_count = redis_client.get(hour_limit_key)
if hour_limit_count is None:
hour_limit_count = 0
hour_limit_count = int(hour_limit_count)
if hour_limit_count >= 1:
redis_client.setex(freeze_key, 60 * 60, 1)
return True
else:
redis_client.setex(hour_limit_key, 60 * 10, hour_limit_count + 1) # first time limit 10 minutes
# add hour limit count
redis_client.incr(hour_limit_key)
redis_client.expire(hour_limit_key, 60 * 60)
return True
redis_client.setex(minute_key, 60, current_minute_count + 1)
redis_client.expire(minute_key, 60)
return False
def _get_login_cache_key(*, account_id: str, token: str):
return f"account_login:{account_id}:{token}"
class TenantService: class TenantService:
@staticmethod @staticmethod
def create_tenant(name: str) -> Tenant: def create_tenant(name: str, is_setup: Optional[bool] = False) -> Tenant:
"""Create tenant""" """Create tenant"""
if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup:
from controllers.console.error import NotAllowedCreateWorkspace
raise NotAllowedCreateWorkspace()
tenant = Tenant(name=name) tenant = Tenant(name=name)
db.session.add(tenant) db.session.add(tenant)
@ -326,8 +502,12 @@ class TenantService:
return tenant return tenant
@staticmethod @staticmethod
def create_owner_tenant_if_not_exist(account: Account, name: Optional[str] = None): def create_owner_tenant_if_not_exist(
account: Account, name: Optional[str] = None, is_setup: Optional[bool] = False
):
"""Create owner tenant if not exist""" """Create owner tenant if not exist"""
if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup:
raise WorkSpaceNotAllowedCreateError()
available_ta = ( available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
) )
@ -336,9 +516,9 @@ class TenantService:
return return
if name: if name:
tenant = TenantService.create_tenant(name) tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
else: else:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace") tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup)
TenantService.create_tenant_member(tenant, account, role="owner") TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant account.current_tenant = tenant
db.session.commit() db.session.commit()
@ -352,8 +532,13 @@ class TenantService:
logging.error(f"Tenant {tenant.id} has already an owner.") logging.error(f"Tenant {tenant.id} has already an owner.")
raise Exception("Tenant already has an owner.") raise Exception("Tenant already has an owner.")
ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role) ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
db.session.add(ta) if ta:
ta.role = role
else:
ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role)
db.session.add(ta)
db.session.commit() db.session.commit()
return ta return ta
@ -570,12 +755,13 @@ class RegisterService:
name=name, name=name,
interface_language=languages[0], interface_language=languages[0],
password=password, password=password,
is_setup=True,
) )
account.last_login_ip = ip_address account.last_login_ip = ip_address
account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
TenantService.create_owner_tenant_if_not_exist(account) TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True)
dify_setup = DifySetup(version=dify_config.CURRENT_VERSION) dify_setup = DifySetup(version=dify_config.CURRENT_VERSION)
db.session.add(dify_setup) db.session.add(dify_setup)
@ -600,27 +786,33 @@ class RegisterService:
provider: Optional[str] = None, provider: Optional[str] = None,
language: Optional[str] = None, language: Optional[str] = None,
status: Optional[AccountStatus] = None, status: Optional[AccountStatus] = None,
is_setup: Optional[bool] = False,
) -> Account: ) -> Account:
db.session.begin_nested() db.session.begin_nested()
"""Register account""" """Register account"""
try: try:
account = AccountService.create_account( account = AccountService.create_account(
email=email, name=name, interface_language=language or languages[0], password=password email=email,
name=name,
interface_language=language or languages[0],
password=password,
is_setup=is_setup,
) )
account.status = AccountStatus.ACTIVE.value if not status else status.value account.status = AccountStatus.ACTIVE.value if not status else status.value
account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
if open_id is not None or provider is not None: if open_id is not None or provider is not None:
AccountService.link_account_integrate(provider, open_id, account) AccountService.link_account_integrate(provider, open_id, account)
if dify_config.EDITION != "SELF_HOSTED":
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
if FeatureService.get_system_features().is_allow_create_workspace:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner") TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant account.current_tenant = tenant
tenant_was_created.send(tenant) tenant_was_created.send(tenant)
db.session.commit() db.session.commit()
except WorkSpaceNotAllowedCreateError:
db.session.rollback()
except Exception as e: except Exception as e:
db.session.rollback() db.session.rollback()
logging.error(f"Register failed: {e}") logging.error(f"Register failed: {e}")
@ -639,7 +831,9 @@ class RegisterService:
TenantService.check_member_permission(tenant, inviter, None, "add") TenantService.check_member_permission(tenant, inviter, None, "add")
name = email.split("@")[0] name = email.split("@")[0]
account = cls.register(email=email, name=name, language=language, status=AccountStatus.PENDING) account = cls.register(
email=email, name=name, language=language, status=AccountStatus.PENDING, is_setup=True
)
# Create new tenant member for invited tenant # Create new tenant member for invited tenant
TenantService.create_tenant_member(tenant, account, role) TenantService.create_tenant_member(tenant, account, role)
TenantService.switch_tenant(account, tenant.id) TenantService.switch_tenant(account, tenant.id)
@ -679,6 +873,11 @@ class RegisterService:
redis_client.setex(cls._get_invitation_token_key(token), expiry_hours * 60 * 60, json.dumps(invitation_data)) redis_client.setex(cls._get_invitation_token_key(token), expiry_hours * 60 * 60, json.dumps(invitation_data))
return token return token
@classmethod
def is_valid_invite_token(cls, token: str) -> bool:
data = redis_client.get(cls._get_invitation_token_key(token))
return data is not None
@classmethod @classmethod
def revoke_token(cls, workspace_id: str, email: str, token: str): def revoke_token(cls, workspace_id: str, email: str, token: str):
if workspace_id and email: if workspace_id and email:
@ -727,7 +926,9 @@ class RegisterService:
} }
@classmethod @classmethod
def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) -> Optional[dict[str, str]]: def _get_invitation_by_token(
cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None
) -> Optional[dict[str, str]]:
if workspace_id is not None and email is not None: if workspace_id is not None and email is not None:
email_hash = sha256(email.encode()).hexdigest() email_hash = sha256(email.encode()).hexdigest()
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"

View File

@ -13,6 +13,10 @@ class AccountLoginError(BaseServiceError):
pass pass
class AccountPasswordError(BaseServiceError):
pass
class AccountNotLinkTenantError(BaseServiceError): class AccountNotLinkTenantError(BaseServiceError):
pass pass

View File

@ -0,0 +1,9 @@
from services.errors.base import BaseServiceError
class WorkSpaceNotAllowedCreateError(BaseServiceError):
pass
class WorkSpaceNotFoundError(BaseServiceError):
pass

View File

@ -42,6 +42,11 @@ class SystemFeatureModel(BaseModel):
sso_enforced_for_web: bool = False sso_enforced_for_web: bool = False
sso_enforced_for_web_protocol: str = "" sso_enforced_for_web_protocol: str = ""
enable_web_sso_switch_component: bool = False enable_web_sso_switch_component: bool = False
enable_email_code_login: bool = False
enable_email_password_login: bool = True
enable_social_oauth_login: bool = False
is_allow_register: bool = False
is_allow_create_workspace: bool = False
class FeatureService: class FeatureService:
@ -60,12 +65,23 @@ class FeatureService:
def get_system_features(cls) -> SystemFeatureModel: def get_system_features(cls) -> SystemFeatureModel:
system_features = SystemFeatureModel() system_features = SystemFeatureModel()
cls._fulfill_system_params_from_env(system_features)
if dify_config.ENTERPRISE_ENABLED: if dify_config.ENTERPRISE_ENABLED:
system_features.enable_web_sso_switch_component = True system_features.enable_web_sso_switch_component = True
cls._fulfill_params_from_enterprise(system_features) cls._fulfill_params_from_enterprise(system_features)
return system_features return system_features
@classmethod
def _fulfill_system_params_from_env(cls, system_features: SystemFeatureModel):
system_features.enable_email_code_login = dify_config.ENABLE_EMAIL_CODE_LOGIN
system_features.enable_email_password_login = dify_config.ENABLE_EMAIL_PASSWORD_LOGIN
system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN
system_features.is_allow_register = dify_config.ALLOW_REGISTER
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
@classmethod @classmethod
def _fulfill_params_from_env(cls, features: FeatureModel): def _fulfill_params_from_env(cls, features: FeatureModel):
features.can_replace_logo = dify_config.CAN_REPLACE_LOGO features.can_replace_logo = dify_config.CAN_REPLACE_LOGO
@ -113,7 +129,19 @@ class FeatureService:
def _fulfill_params_from_enterprise(cls, features): def _fulfill_params_from_enterprise(cls, features):
enterprise_info = EnterpriseService.get_info() enterprise_info = EnterpriseService.get_info()
features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"] if "sso_enforced_for_signin" in enterprise_info:
features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"] features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"]
features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"] if "sso_enforced_for_signin_protocol" in enterprise_info:
features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"] features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"]
if "sso_enforced_for_web" in enterprise_info:
features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"]
if "sso_enforced_for_web_protocol" in enterprise_info:
features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"]
if "enable_email_code_login" in enterprise_info:
features.enable_email_code_login = enterprise_info["enable_email_code_login"]
if "enable_email_password_login" in enterprise_info:
features.enable_email_password_login = enterprise_info["enable_email_password_login"]
if "is_allow_register" in enterprise_info:
features.is_allow_register = enterprise_info["is_allow_register"]
if "is_allow_create_workspace" in enterprise_info:
features.is_allow_create_workspace = enterprise_info["is_allow_create_workspace"]

View File

@ -0,0 +1,41 @@
import logging
import time
import click
from celery import shared_task
from flask import render_template
from extensions.ext_mail import mail
@shared_task(queue="mail")
def send_email_code_login_mail_task(language: str, to: str, code: str):
"""
Async Send email code login mail
:param language: Language in which the email should be sent (e.g., 'en', 'zh')
:param to: Recipient email address
:param code: Email code to be included in the email
"""
if not mail.is_inited():
return
logging.info(click.style("Start email code login mail to {}".format(to), fg="green"))
start_at = time.perf_counter()
# send email code login mail using different languages
try:
if language == "zh-Hans":
html_content = render_template("email_code_login_mail_template_zh-CN.html", to=to, code=code)
mail.send(to=to, subject="邮箱验证码", html=html_content)
else:
html_content = render_template("email_code_login_mail_template_en-US.html", to=to, code=code)
mail.send(to=to, subject="Email Code", html=html_content)
end_at = time.perf_counter()
logging.info(
click.style(
"Send email code login mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green"
)
)
except Exception:
logging.exception("Send email code login mail to {} failed".format(to))

View File

@ -5,17 +5,16 @@ import click
from celery import shared_task from celery import shared_task
from flask import render_template from flask import render_template
from configs import dify_config
from extensions.ext_mail import mail from extensions.ext_mail import mail
@shared_task(queue="mail") @shared_task(queue="mail")
def send_reset_password_mail_task(language: str, to: str, token: str): def send_reset_password_mail_task(language: str, to: str, code: str):
""" """
Async Send reset password mail Async Send reset password mail
:param language: Language in which the email should be sent (e.g., 'en', 'zh') :param language: Language in which the email should be sent (e.g., 'en', 'zh')
:param to: Recipient email address :param to: Recipient email address
:param token: Reset password token to be included in the email :param code: Reset password code
""" """
if not mail.is_inited(): if not mail.is_inited():
return return
@ -25,13 +24,12 @@ def send_reset_password_mail_task(language: str, to: str, token: str):
# send reset password mail using different languages # send reset password mail using different languages
try: try:
url = f"{dify_config.CONSOLE_WEB_URL}/forgot-password?token={token}"
if language == "zh-Hans": if language == "zh-Hans":
html_content = render_template("reset_password_mail_template_zh-CN.html", to=to, url=url) html_content = render_template("reset_password_mail_template_zh-CN.html", to=to, code=code)
mail.send(to=to, subject="置您的 Dify 密码", html=html_content) mail.send(to=to, subject="置您的 Dify 密码", html=html_content)
else: else:
html_content = render_template("reset_password_mail_template_en-US.html", to=to, url=url) html_content = render_template("reset_password_mail_template_en-US.html", to=to, code=code)
mail.send(to=to, subject="Reset Your Dify Password", html=html_content) mail.send(to=to, subject="Set Your Dify Password", html=html_content)
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info( logging.info(

View File

@ -0,0 +1,74 @@
<!DOCTYPE html>
<html>
<head>
<style>
body {
font-family: 'Arial', sans-serif;
line-height: 16pt;
color: #101828;
background-color: #e9ebf0;
margin: 0;
padding: 0;
}
.container {
width: 600px;
height: 360px;
margin: 40px auto;
padding: 36px 48px;
background-color: #fcfcfd;
border-radius: 16px;
border: 1px solid #ffffff;
box-shadow: 0 2px 4px -2px rgba(9, 9, 11, 0.08);
}
.header {
margin-bottom: 24px;
}
.header img {
max-width: 100px;
height: auto;
}
.title {
font-weight: 600;
font-size: 24px;
line-height: 28.8px;
}
.description {
font-size: 13px;
line-height: 16px;
color: #676f83;
margin-top: 12px;
}
.code-content {
padding: 16px 32px;
text-align: center;
border-radius: 16px;
background-color: #f2f4f7;
margin: 16px auto;
}
.code {
line-height: 36px;
font-weight: 700;
font-size: 30px;
}
.tips {
line-height: 16px;
color: #676f83;
font-size: 13px;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<!-- Optional: Add a logo or a header image here -->
<img src="https://cloud.dify.ai/logo/logo-site.png" alt="Dify Logo" />
</div>
<p class="title">Your login code for Dify</p>
<p class="description">Copy and paste this code, this code will only be valid for the next 5 minutes.</p>
<div class="code-content">
<span class="code">{{code}}</span>
</div>
<p class="tips">If you didn't request a login, don't worry. You can safely ignore this email.</p>
</div>
</body>
</html>

View File

@ -0,0 +1,74 @@
<!DOCTYPE html>
<html>
<head>
<style>
body {
font-family: 'Arial', sans-serif;
line-height: 16pt;
color: #101828;
background-color: #e9ebf0;
margin: 0;
padding: 0;
}
.container {
width: 600px;
height: 360px;
margin: 40px auto;
padding: 36px 48px;
background-color: #fcfcfd;
border-radius: 16px;
border: 1px solid #ffffff;
box-shadow: 0 2px 4px -2px rgba(9, 9, 11, 0.08);
}
.header {
margin-bottom: 24px;
}
.header img {
max-width: 100px;
height: auto;
}
.title {
font-weight: 600;
font-size: 24px;
line-height: 28.8px;
}
.description {
font-size: 13px;
line-height: 16px;
color: #676f83;
margin-top: 12px;
}
.code-content {
padding: 16px 32px;
text-align: center;
border-radius: 16px;
background-color: #f2f4f7;
margin: 16px auto;
}
.code {
line-height: 36px;
font-weight: 700;
font-size: 30px;
}
.tips {
line-height: 16px;
color: #676f83;
font-size: 13px;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<!-- Optional: Add a logo or a header image here -->
<img src="https://cloud.dify.ai/logo/logo-site.png" alt="Dify Logo" />
</div>
<p class="title">Dify 的登录验证码</p>
<p class="description">复制并粘贴此验证码,注意验证码仅在接下来的 5 分钟内有效。</p>
<div class="code-content">
<span class="code">{{code}}</span>
</div>
<p class="tips">如果您没有请求登录,请不要担心。您可以安全地忽略此电子邮件。</p>
</div>
</body>
</html>

View File

@ -59,7 +59,7 @@
<div class="content"> <div class="content">
<p>Dear {{ to }},</p> <p>Dear {{ to }},</p>
<p>{{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.</p> <p>{{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.</p>
<p>You can now log in to Dify using the GitHub or Google account associated with this email.</p> <p>Click the button below to log in to Dify and join the workspace.</p>
<p style="text-align: center;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">Login Here</a></p> <p style="text-align: center;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">Login Here</a></p>
</div> </div>
<div class="footer"> <div class="footer">

View File

@ -59,7 +59,7 @@
<div class="content"> <div class="content">
<p>尊敬的 {{ to }}</p> <p>尊敬的 {{ to }}</p>
<p>{{ inviter_name }} 现邀请您加入我们在 Dify 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 Dify 上,您可以探索、创造和合作,构建和运营 AI 应用。</p> <p>{{ inviter_name }} 现邀请您加入我们在 Dify 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 Dify 上,您可以探索、创造和合作,构建和运营 AI 应用。</p>
<p>您现在可以使用与此邮件相对应的 GitHub 或 Google 账号登录 Dify</p> <p>点击下方按钮即可登录 Dify 并且加入空间</p>
<p style="text-align: center;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">在此登录</a></p> <p style="text-align: center;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">在此登录</a></p>
</div> </div>
<div class="footer"> <div class="footer">

View File

@ -1,72 +1,74 @@
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
<head> <head>
<style> <style>
body { body {
font-family: 'Arial', sans-serif; font-family: 'Arial', sans-serif;
line-height: 16pt; line-height: 16pt;
color: #374151; color: #101828;
background-color: #E5E7EB; background-color: #e9ebf0;
margin: 0; margin: 0;
padding: 0; padding: 0;
} }
.container { .container {
width: 100%; width: 600px;
max-width: 560px; height: 360px;
margin: 40px auto; margin: 40px auto;
padding: 20px; padding: 36px 48px;
background-color: #F3F4F6; background-color: #fcfcfd;
border-radius: 8px; border-radius: 16px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border: 1px solid #ffffff;
} box-shadow: 0 2px 4px -2px rgba(9, 9, 11, 0.08);
.header { }
text-align: center; .header {
margin-bottom: 20px; margin-bottom: 24px;
} }
.header img { .header img {
max-width: 100px; max-width: 100px;
height: auto; height: auto;
} }
.button { .title {
display: inline-block; font-weight: 600;
padding: 12px 24px; font-size: 24px;
background-color: #2970FF; line-height: 28.8px;
color: white; }
text-decoration: none; .description {
border-radius: 4px; font-size: 13px;
text-align: center; line-height: 16px;
transition: background-color 0.3s ease; color: #676f83;
} margin-top: 12px;
.button:hover { }
background-color: #265DD4; .code-content {
} padding: 16px 32px;
.footer { text-align: center;
font-size: 0.9em; border-radius: 16px;
color: #777777; background-color: #f2f4f7;
margin-top: 30px; margin: 16px auto;
} }
.content { .code {
margin-top: 20px; line-height: 36px;
} font-weight: 700;
font-size: 30px;
}
.tips {
line-height: 16px;
color: #676f83;
font-size: 13px;
}
</style> </style>
</head> </head>
<body>
<body>
<div class="container"> <div class="container">
<div class="header"> <div class="header">
<img src="https://cloud.dify.ai/logo/logo-site.png" alt="Dify Logo"> <!-- Optional: Add a logo or a header image here -->
</div> <img src="https://cloud.dify.ai/logo/logo-site.png" alt="Dify Logo" />
<div class="content"> </div>
<p>Dear {{ to }},</p> <p class="title">Set your Dify password</p>
<p>We have received a request to reset your password. If you initiated this request, please click the button below to reset your password:</p> <p class="description">Copy and paste this code, this code will only be valid for the next 5 minutes.</p>
<p style="text-align: center;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">Reset Password</a></p> <div class="code-content">
<p>If you did not request a password reset, please ignore this email and your account will remain secure.</p> <span class="code">{{code}}</span>
</div> </div>
<div class="footer"> <p class="tips">If you didn't request, don't worry. You can safely ignore this email.</p>
<p>Best regards,</p>
<p>Dify Team</p>
<p>Please do not reply directly to this email; it is automatically sent by the system.</p>
</div>
</div> </div>
</body> </body>
</html> </html>

View File

@ -1,72 +1,74 @@
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
<head> <head>
<style> <style>
body { body {
font-family: 'Arial', sans-serif; font-family: 'Arial', sans-serif;
line-height: 16pt; line-height: 16pt;
color: #374151; color: #101828;
background-color: #E5E7EB; background-color: #e9ebf0;
margin: 0; margin: 0;
padding: 0; padding: 0;
} }
.container { .container {
width: 100%; width: 600px;
max-width: 560px; height: 360px;
margin: 40px auto; margin: 40px auto;
padding: 20px; padding: 36px 48px;
background-color: #F3F4F6; background-color: #fcfcfd;
border-radius: 8px; border-radius: 16px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); border: 1px solid #ffffff;
} box-shadow: 0 2px 4px -2px rgba(9, 9, 11, 0.08);
.header { }
text-align: center; .header {
margin-bottom: 20px; margin-bottom: 24px;
} }
.header img { .header img {
max-width: 100px; max-width: 100px;
height: auto; height: auto;
} }
.button { .title {
display: inline-block; font-weight: 600;
padding: 12px 24px; font-size: 24px;
background-color: #2970FF; line-height: 28.8px;
color: white; }
text-decoration: none; .description {
border-radius: 4px; font-size: 13px;
text-align: center; line-height: 16px;
transition: background-color 0.3s ease; color: #676f83;
} margin-top: 12px;
.button:hover { }
background-color: #265DD4; .code-content {
} padding: 16px 32px;
.footer { text-align: center;
font-size: 0.9em; border-radius: 16px;
color: #777777; background-color: #f2f4f7;
margin-top: 30px; margin: 16px auto;
} }
.content { .code {
margin-top: 20px; line-height: 36px;
} font-weight: 700;
font-size: 30px;
}
.tips {
line-height: 16px;
color: #676f83;
font-size: 13px;
}
</style> </style>
</head> </head>
<body>
<body>
<div class="container"> <div class="container">
<div class="header"> <div class="header">
<img src="https://cloud.dify.ai/logo/logo-site.png" alt="Dify Logo"> <!-- Optional: Add a logo or a header image here -->
</div> <img src="https://cloud.dify.ai/logo/logo-site.png" alt="Dify Logo" />
<div class="content"> </div>
<p>尊敬的 {{ to }}</p> <p class="title">设置您的 Dify 账户密码</p>
<p>我们收到了您关于重置密码的请求。如果是您本人操作,请点击以下按钮重置您的密码:</p> <p class="description">复制并粘贴此验证码,注意验证码仅在接下来的 5 分钟内有效。</p>
<p style="text-align: center;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">重置密码</a></p> <div class="code-content">
<p>如果您没有请求重置密码,请忽略此邮件,您的账户信息将保持安全。</p> <span class="code">{{code}}</span>
</div> </div>
<div class="footer"> <p class="tips">如果您没有请求,请不要担心。您可以安全地忽略此电子邮件。</p>
<p>此致,</p>
<p>Dify 团队</p>
<p>请不要直接回复此电子邮件;由系统自动发送。</p>
</div>
</div> </div>
</body> </body>
</html> </html>

View File

@ -606,8 +606,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000
INVITE_EXPIRY_HOURS=72 INVITE_EXPIRY_HOURS=72
# Reset password token valid time (hours), # Reset password token valid time (hours),
# Default: 24. RESET_PASSWORD_TOKEN_EXPIRY_HOURS=1/12
RESET_PASSWORD_TOKEN_EXPIRY_HOURS=24
# The sandbox service endpoint. # The sandbox service endpoint.
CODE_EXECUTION_ENDPOINT=http://sandbox:8194 CODE_EXECUTION_ENDPOINT=http://sandbox:8194
@ -837,5 +836,6 @@ POSITION_TOOL_EXCLUDES=
POSITION_PROVIDER_PINS= POSITION_PROVIDER_PINS=
POSITION_PROVIDER_INCLUDES= POSITION_PROVIDER_INCLUDES=
POSITION_PROVIDER_EXCLUDES= POSITION_PROVIDER_EXCLUDES=
# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP # CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP
CSP_WHITELIST= CSP_WHITELIST=