import hmac
import logging
from functools import wraps
from typing import Any, Optional

from flask import current_app, g, has_request_context, request
from flask_login import user_logged_in  # type: ignore
from flask_login.config import EXEMPT_METHODS  # type: ignore
from werkzeug.exceptions import Unauthorized
from werkzeug.local import LocalProxy

from constants.languages import languages
from events.tenant_event import tenant_was_created
from services.errors.account import AccountNotFoundError
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.feature_service import FeatureService
from services.account_service import AccountService, RegisterService, TenantService
from libs.oauth import OAuthUserInfo
from configs import dify_config
from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin

#: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user (``None`` outside of a request context, see ``_get_user``).
current_user: Any = LocalProxy(lambda: _get_user())

# Message used for every malformed ``Authorization`` header.
_INVALID_AUTH_HEADER_MESSAGE = "Invalid Authorization header format. Expected 'Bearer <api-key>' format."


def _extract_bearer_token(auth_header: str) -> str:
    """Return the token part of a ``Bearer <token>`` Authorization header value.

    The scheme is matched case-insensitively.

    :param auth_header: raw value of the ``Authorization`` request header.
    :returns: everything after the scheme and the separating whitespace.
    :raises Unauthorized: if the value is not of the form ``Bearer <token>``.
    """
    if " " not in auth_header:
        raise Unauthorized(_INVALID_AUTH_HEADER_MESSAGE)
    parts = auth_header.split(None, 1)
    # Values such as "Bearer " or "   " contain a space but split into fewer
    # than two parts; unpacking them directly would raise ValueError (a 500)
    # instead of rejecting the request with a 401.
    if len(parts) != 2 or parts[0].lower() != "bearer":
        raise Unauthorized(_INVALID_AUTH_HEADER_MESSAGE)
    return parts[1]


def _admin_api_key_login_failed(auth_token: str) -> bool:
    """Try to log the current request in with the admin API key.

    Nothing happens (returns ``False``) unless ``dify_config.ADMIN_API_KEY`` is
    set, equals ``auth_token`` and an ``X-WORKSPACE-ID`` header is present.
    Otherwise the workspace's owner membership is looked up:

    * Owner found: the account identified by ``X-USER-OPENID`` under
      ``X-PROVIDER`` (default ``"console"``) is fetched; if it does not exist
      and ``X-USERNAME`` and ``X-EMAIL`` are supplied, it is created through
      ``_generate_account``. A non-owner account is added to the workspace with
      the ``X-ROLE`` role (default ``"admin"``). The account is then logged in
      with the workspace as its current tenant and ``False`` is returned. If no
      account can be resolved, ``True`` is returned.
    * No owner found (or no such workspace): if both the workspace and the
      account identified by ``X-USER-OPENID`` exist, the account is made the
      owner. ``True`` is returned either way; nobody is logged in.

    :param auth_token: token extracted from the ``Authorization`` header.
    :returns: ``True`` when the request must be rejected as a failed login,
        ``False`` when the caller should continue with the regular login check.
    """
    admin_api_key = dify_config.ADMIN_API_KEY
    # Constant-time comparison: the token is attacker-controlled, and ``==``
    # can leak through response timing how much of the key matched.
    if not admin_api_key or not hmac.compare_digest(
        admin_api_key.encode("utf-8"), auth_token.encode("utf-8")
    ):
        return False

    workspace_id = request.headers.get("X-WORKSPACE-ID")
    if not workspace_id:
        return False

    # The same provider interprets the openid in both branches below, so an
    # openid is never looked up under a different provider than it belongs to.
    openid = request.headers.get("X-USER-OPENID")
    provider = request.headers.get("X-PROVIDER", "console")

    tenant_owner = (
        db.session.query(Tenant, TenantAccountJoin)
        .filter(Tenant.id == workspace_id)
        .filter(TenantAccountJoin.tenant_id == Tenant.id)
        .filter(TenantAccountJoin.role == "owner")
        .one_or_none()
    )

    if not tenant_owner:
        tenant = db.session.query(Tenant).filter(Tenant.id == workspace_id).one_or_none()
        # Skip the lookup entirely when no openid header was sent.
        account = Account.get_by_openid(provider, openid) if openid else None
        if tenant and account:
            TenantService.create_tenant_member(tenant, account, "owner")
        # NOTE(review): ownership may have just been granted, but this request
        # is still rejected; presumably the client retries. Confirm intent.
        return True

    tenant, owner_join = tenant_owner
    user = request.headers.get("X-USERNAME")
    role = request.headers.get("X-ROLE", "admin")
    email = request.headers.get("X-EMAIL")

    account = None
    if openid:
        account = Account.get_by_openid(provider, openid)
        if not account and user and email and provider:
            account = _generate_account(provider, OAuthUserInfo(openid, user, email), workspace_id)
    if not account:
        return True

    # Ownership is decided by the workspace's owner row, however the account
    # was resolved; otherwise an owner found via _generate_account (e.g. by
    # email) would be written back with a non-owner role.
    if owner_join.account_id == account.id:
        role = "owner"
    # NOTE(review): a header value of X-ROLE "owner" from a non-owner also
    # skips the membership write while still logging the account in; confirm.
    if role != "owner":
        TenantService.create_tenant_member(tenant, account, role)

    account.current_tenant = tenant
    current_app.login_manager._update_request_context_with_user(account)
    user_logged_in.send(current_app._get_current_object(), user=_get_user())
    return False


def auth_wrapper(func, *args, **kwargs):
    """Authenticate the current request, then invoke ``func(*args, **kwargs)``.

    When ``dify_config.ADMIN_API_KEY_ENABLE`` is set and the request carries an
    ``Authorization`` header, the header must be ``Bearer <token>`` and an
    admin-API-key login is attempted with the token. Afterwards the regular
    check runs: unless the method is in ``EXEMPT_METHODS`` or
    ``dify_config.LOGIN_DISABLED`` is set, an unauthenticated ``current_user``
    is rejected.

    :param func: the view function to call once authentication passes.
    :returns: the view's result, or the result of
        ``current_app.login_manager.unauthorized()`` when authentication fails.
    :raises Unauthorized: if the ``Authorization`` header is malformed.
    """
    auth_header = request.headers.get("Authorization")
    if dify_config.ADMIN_API_KEY_ENABLE and auth_header:
        auth_token = _extract_bearer_token(auth_header)
        if _admin_api_key_login_failed(auth_token):
            # Login failed
            return current_app.login_manager.unauthorized()

    if (
        request.method not in EXEMPT_METHODS
        and not dify_config.LOGIN_DISABLED
        and not current_user.is_authenticated
    ):
        return current_app.login_manager.unauthorized()

    # flask 1.x compatibility
    # current_app.ensure_sync is only available in Flask >= 2.0
    if callable(getattr(current_app, "ensure_sync", None)):
        return current_app.ensure_sync(func)(*args, **kwargs)
    return func(*args, **kwargs)


def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
    """Return the account linked to ``user_info.id`` under ``provider``.

    Falls back to the first account whose email equals ``user_info.email``;
    returns ``None`` if neither lookup finds one.
    """
    account = Account.get_by_openid(provider, user_info.id)
    if not account:
        account = Account.query.filter_by(email=user_info.email).first()
    return account


def _generate_account(provider: str, user_info: OAuthUserInfo, workspace_id: str | None = None):
    """Return the account for ``user_info``, registering it if needed, and link it to ``provider``.

    An existing account (matched by openid, then email) is switched to
    ``workspace_id`` when one is given. If it belongs to no workspace, a new
    workspace is created with the account as its owner. A missing account is
    registered (without a password), with its interface language chosen from
    the request's ``Accept-Language`` header.

    :raises WorkSpaceNotAllowedCreateError: if an existing account has no
        workspace and workspace creation is disabled.
    :raises AccountNotFoundError: if no account exists and registration is
        disabled.
    """
    account = _get_account_by_openid_or_email(provider, user_info)

    if account:
        logging.info('account %s', account.name)
        if workspace_id:
            TenantService.switch_tenant(account, workspace_id)
        tenants = TenantService.get_join_tenants(account)
        if not tenants:
            if not FeatureService.get_system_features().is_allow_create_workspace:
                raise WorkSpaceNotAllowedCreateError()
            tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
            TenantService.create_tenant_member(tenant, account, role="owner")
            account.current_tenant = tenant
            tenant_was_created.send(tenant)
    else:
        if not FeatureService.get_system_features().is_allow_register:
            raise AccountNotFoundError()
        account_name = user_info.name or "Dify"
        account = RegisterService.register(
            email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
        )

        # Set interface language: best_match only returns an entry of
        # `languages` (or None), so fall back to the first supported language.
        account.interface_language = request.accept_languages.best_match(languages) or languages[0]
        db.session.commit()

    # Link account
    AccountService.link_account_integrate(provider, user_info.id, account)

    return account


def admin_api_key_required(func):
    """Decorate a view so it is authenticated by ``auth_wrapper``, starting from no user.

    The request's user context is reset (``None``) before ``auth_wrapper``
    runs, so presumably only the admin-API-key login can authenticate the
    request (unless login is disabled or the method is exempt).
    """

    @wraps(func)
    def decorated_view(*args, **kwargs):
        current_app.login_manager._update_request_context_with_user(None)
        return auth_wrapper(func, *args, **kwargs)

    return decorated_view


def login_required(func):
    """
    If you decorate a view with this, it will ensure that the current user is
    logged in and authenticated before calling the actual view. (If they are
    not, it calls the :attr:`LoginManager.unauthorized` callback.) For
    example::

        @app.route('/post')
        @login_required
        def post():
            pass

    If there are only certain times you need to require that your user is
    logged in, you can do so with::

        if not current_user.is_authenticated:
            return current_app.login_manager.unauthorized()

    ...which is essentially the code that this function adds to your views.

    When ``dify_config.ADMIN_API_KEY_ENABLE`` is set, a request bearing the
    admin API key may additionally be logged in through ``auth_wrapper``'s
    admin-API-key path.

    It can be convenient to globally turn off authentication when unit testing.
    To enable this, if ``dify_config.LOGIN_DISABLED`` is set to `True`, this
    decorator will be ignored.

    .. Note ::

        Per `W3 guidelines for CORS preflight requests
        <http://www.w3.org/TR/cors/#cross-origin-request-with-preflight-0>`_,
        HTTP ``OPTIONS`` requests are exempt from login checks.

    :param func: The view function to decorate.
    :type func: function
    """

    @wraps(func)
    def decorated_view(*args, **kwargs):
        return auth_wrapper(func, *args, **kwargs)

    return decorated_view


def _get_user():
    """Return the user for the current request, or ``None`` outside a request context.

    The user is loaded through the login manager on first access and cached
    on ``g._login_user``.
    """
    if has_request_context():
        if "_login_user" not in g:
            current_app.login_manager._load_user()  # type: ignore

        return g._login_user

    return None