login.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. from functools import wraps
  2. import logging
  3. from typing import Optional
  4. from flask import current_app, g, has_request_context, request
  5. from flask_login import user_logged_in
  6. from flask_login.config import EXEMPT_METHODS
  7. from werkzeug.exceptions import Unauthorized
  8. from werkzeug.local import LocalProxy
  9. from constants.languages import languages
  10. from events.tenant_event import tenant_was_created
  11. from services.errors.account import AccountNotFoundError
  12. from services.errors.workspace import WorkSpaceNotAllowedCreateError
  13. from services.feature_service import FeatureService
  14. from services.account_service import AccountService, RegisterService, TenantService
  15. from libs.oauth import OAuthUserInfo
  16. from configs import dify_config
  17. from extensions.ext_database import db
  18. from models.account import Account, Tenant, TenantAccountJoin
#: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user
# The wrapped callable is invoked each time the proxy is used, so the value
# always reflects the current request: outside a request context _get_user()
# returns None; inside one it returns ``flask.g._login_user``.
# The lambda (instead of passing _get_user directly) defers the name lookup,
# because _get_user is defined further down in this module.
current_user = LocalProxy(lambda: _get_user())
  22. def auth_wrapper(func, *args, **kwargs):
  23. auth_header = request.headers.get("Authorization")
  24. if dify_config.ADMIN_API_KEY_ENABLE:
  25. if auth_header:
  26. if " " not in auth_header:
  27. raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
  28. auth_scheme, auth_token = auth_header.split(None, 1)
  29. auth_scheme = auth_scheme.lower()
  30. if auth_scheme != "bearer":
  31. raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
  32. admin_api_key = dify_config.ADMIN_API_KEY
  33. if admin_api_key:
  34. if admin_api_key == auth_token:
  35. workspace_id = request.headers.get("X-WORKSPACE-ID")
  36. if workspace_id:
  37. tenant_account_join = (
  38. db.session.query(Tenant, TenantAccountJoin)
  39. .filter(Tenant.id == workspace_id)
  40. .filter(TenantAccountJoin.tenant_id == Tenant.id)
  41. .filter(TenantAccountJoin.role == "owner")
  42. .one_or_none()
  43. )
  44. if tenant_account_join:
  45. tenant, ta = tenant_account_join
  46. # account = Account.query.filter_by(id=ta.account_id).first()
  47. account = None
  48. # Login admin
  49. openid = request.headers.get("X-USER-OPENID")
  50. user = request.headers.get("X-USERNAME")
  51. role = request.headers.get("X-ROLE", "admin")
  52. email = request.headers.get("X-EMAIL")
  53. provider = request.headers.get("X-PROVIDER", "console")
  54. if openid:
  55. account = Account.get_by_openid(provider, openid)
  56. if account:
  57. if ta.account_id == account.id:
  58. role = "owner"
  59. elif user and email and provider:
  60. account = _generate_account(provider, OAuthUserInfo(openid, user, email), workspace_id)
  61. if account:
  62. if role != "owner":
  63. TenantService.create_tenant_member(tenant, account, role)
  64. account.current_tenant = tenant
  65. current_app.login_manager._update_request_context_with_user(account)
  66. user_logged_in.send(current_app._get_current_object(), user=_get_user())
  67. else:
  68. # Login failed
  69. return current_app.login_manager.unauthorized()
  70. else:
  71. openid = request.headers.get("X-USER-OPENID")
  72. tenant = db.session.query(Tenant).filter(Tenant.id == workspace_id).one_or_none()
  73. account = Account.get_by_openid("console", openid)
  74. if tenant and account:
  75. TenantService.create_tenant_member(tenant, account, "owner")
  76. # Login failed
  77. return current_app.login_manager.unauthorized()
  78. if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
  79. pass
  80. elif not current_user.is_authenticated:
  81. return current_app.login_manager.unauthorized()
  82. # flask 1.x compatibility
  83. # current_app.ensure_sync is only available in Flask >= 2.0
  84. if callable(getattr(current_app, "ensure_sync", None)):
  85. return current_app.ensure_sync(func)(*args, **kwargs)
  86. return func(*args, **kwargs)
  87. def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
  88. account = Account.get_by_openid(provider, user_info.id)
  89. if not account:
  90. account = Account.query.filter_by(email=user_info.email).first()
  91. return account
  92. def _generate_account(provider: str, user_info: OAuthUserInfo, workspace_id: str | None = None):
  93. # Get account by openid or email.
  94. account = _get_account_by_openid_or_email(provider, user_info)
  95. if account:
  96. logging.info('account %s', account.name)
  97. if workspace_id:
  98. TenantService.switch_tenant(account, workspace_id)
  99. tenant = TenantService.get_join_tenants(account)
  100. if not tenant:
  101. if not FeatureService.get_system_features().is_allow_create_workspace:
  102. raise WorkSpaceNotAllowedCreateError()
  103. else:
  104. tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
  105. TenantService.create_tenant_member(tenant, account, role="owner")
  106. account.current_tenant = tenant
  107. tenant_was_created.send(tenant)
  108. if not account:
  109. if not FeatureService.get_system_features().is_allow_register:
  110. raise AccountNotFoundError()
  111. account_name = user_info.name or "Dify"
  112. account = RegisterService.register(
  113. email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
  114. )
  115. # Set interface language
  116. preferred_lang = request.accept_languages.best_match(languages)
  117. if preferred_lang and preferred_lang in languages:
  118. interface_language = preferred_lang
  119. else:
  120. interface_language = languages[0]
  121. account.interface_language = interface_language
  122. db.session.commit()
  123. # Link account
  124. AccountService.link_account_integrate(provider, user_info.id, account)
  125. return account
  126. def admin_api_key_required(func):
  127. @wraps(func)
  128. def decorated_view(*args, **kwargs):
  129. current_app.login_manager._update_request_context_with_user(None)
  130. return auth_wrapper(func, *args, **kwargs)
  131. return decorated_view
  132. def login_required(func):
  133. """
  134. If you decorate a view with this, it will ensure that the current user is
  135. logged in and authenticated before calling the actual view. (If they are
  136. not, it calls the :attr:`LoginManager.unauthorized` callback.) For
  137. example::
  138. @app.route('/post')
  139. @login_required
  140. def post():
  141. pass
  142. If there are only certain times you need to require that your user is
  143. logged in, you can do so with::
  144. if not current_user.is_authenticated:
  145. return current_app.login_manager.unauthorized()
  146. ...which is essentially the code that this function adds to your views.
  147. It can be convenient to globally turn off authentication when unit testing.
  148. To enable this, if the application configuration variable `LOGIN_DISABLED`
  149. is set to `True`, this decorator will be ignored.
  150. .. Note ::
  151. Per `W3 guidelines for CORS preflight requests
  152. <http://www.w3.org/TR/cors/#cross-origin-request-with-preflight-0>`_,
  153. HTTP ``OPTIONS`` requests are exempt from login checks.
  154. :param func: The view function to decorate.
  155. :type func: function
  156. """
  157. @wraps(func)
  158. def decorated_view(*args, **kwargs):
  159. return auth_wrapper(func, *args, **kwargs)
  160. return decorated_view
  161. def _get_user():
  162. if has_request_context():
  163. if "_login_user" not in g:
  164. current_app.login_manager._load_user()
  165. return g._login_user
  166. return None