# login.py
import hmac
import logging
from functools import wraps
from typing import Any
from typing import Optional

from flask import current_app, g, has_request_context, request
from flask_login import user_logged_in  # type: ignore
from flask_login.config import EXEMPT_METHODS  # type: ignore
from werkzeug.exceptions import Unauthorized
from werkzeug.local import LocalProxy

from constants.languages import languages
from events.tenant_event import tenant_was_created
from services.errors.account import AccountNotFoundError
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.feature_service import FeatureService
from services.account_service import AccountService, RegisterService, TenantService
from libs.oauth import OAuthUserInfo
from configs import dify_config
from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin
#: A proxy for the user attached to the current request, resolved through
#: ``_get_user`` on every access. If no user is logged in, this will be an
#: anonymous user; outside a request context ``_get_user`` returns ``None``.
# The lambda defers looking up ``_get_user`` until the proxy is first used;
# passing ``_get_user`` directly would fail at import time because it is
# defined further down in this module.
current_user: Any = LocalProxy(lambda: _get_user())
  23. def auth_wrapper(func, *args, **kwargs):
  24. auth_header = request.headers.get("Authorization")
  25. if dify_config.ADMIN_API_KEY_ENABLE:
  26. if auth_header:
  27. if " " not in auth_header:
  28. raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
  29. auth_scheme, auth_token = auth_header.split(None, 1)
  30. auth_scheme = auth_scheme.lower()
  31. if auth_scheme != "bearer":
  32. raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
  33. admin_api_key = dify_config.ADMIN_API_KEY
  34. if admin_api_key:
  35. if admin_api_key == auth_token:
  36. workspace_id = request.headers.get("X-WORKSPACE-ID")
  37. if workspace_id:
  38. tenant_account_join = (
  39. db.session.query(Tenant, TenantAccountJoin)
  40. .filter(Tenant.id == workspace_id)
  41. .filter(TenantAccountJoin.tenant_id == Tenant.id)
  42. .filter(TenantAccountJoin.role == "owner")
  43. .one_or_none()
  44. )
  45. if tenant_account_join:
  46. tenant, ta = tenant_account_join
  47. # account = Account.query.filter_by(id=ta.account_id).first()
  48. account = None
  49. # Login admin
  50. openid = request.headers.get("X-USER-OPENID")
  51. user = request.headers.get("X-USERNAME")
  52. role = request.headers.get("X-ROLE", "admin")
  53. email = request.headers.get("X-EMAIL")
  54. provider = request.headers.get("X-PROVIDER", "console")
  55. if openid:
  56. account = Account.get_by_openid(provider, openid)
  57. if account:
  58. if ta.account_id == account.id:
  59. role = "owner"
  60. elif user and email and provider:
  61. account = _generate_account(provider, OAuthUserInfo(openid, user, email), workspace_id)
  62. if account:
  63. if role != "owner":
  64. TenantService.create_tenant_member(tenant, account, role)
  65. account.current_tenant = tenant
  66. current_app.login_manager._update_request_context_with_user(account)
  67. user_logged_in.send(current_app._get_current_object(), user=_get_user())
  68. else:
  69. # Login failed
  70. return current_app.login_manager.unauthorized()
  71. else:
  72. openid = request.headers.get("X-USER-OPENID")
  73. tenant = db.session.query(Tenant).filter(Tenant.id == workspace_id).one_or_none()
  74. account = Account.get_by_openid("console", openid)
  75. if tenant and account:
  76. TenantService.create_tenant_member(tenant, account, "owner")
  77. # Login failed
  78. return current_app.login_manager.unauthorized()
  79. if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
  80. pass
  81. elif not current_user.is_authenticated:
  82. return current_app.login_manager.unauthorized()
  83. # flask 1.x compatibility
  84. # current_app.ensure_sync is only available in Flask >= 2.0
  85. if callable(getattr(current_app, "ensure_sync", None)):
  86. return current_app.ensure_sync(func)(*args, **kwargs)
  87. return func(*args, **kwargs)
  88. def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
  89. account = Account.get_by_openid(provider, user_info.id)
  90. if not account:
  91. account = Account.query.filter_by(email=user_info.email).first()
  92. return account
  93. def _generate_account(provider: str, user_info: OAuthUserInfo, workspace_id: str | None = None):
  94. # Get account by openid or email.
  95. account = _get_account_by_openid_or_email(provider, user_info)
  96. if account:
  97. logging.info('account %s', account.name)
  98. if workspace_id:
  99. TenantService.switch_tenant(account, workspace_id)
  100. tenant = TenantService.get_join_tenants(account)
  101. if not tenant:
  102. if not FeatureService.get_system_features().is_allow_create_workspace:
  103. raise WorkSpaceNotAllowedCreateError()
  104. else:
  105. tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
  106. TenantService.create_tenant_member(tenant, account, role="owner")
  107. account.current_tenant = tenant
  108. tenant_was_created.send(tenant)
  109. if not account:
  110. if not FeatureService.get_system_features().is_allow_register:
  111. raise AccountNotFoundError()
  112. account_name = user_info.name or "Dify"
  113. account = RegisterService.register(
  114. email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
  115. )
  116. # Set interface language
  117. preferred_lang = request.accept_languages.best_match(languages)
  118. if preferred_lang and preferred_lang in languages:
  119. interface_language = preferred_lang
  120. else:
  121. interface_language = languages[0]
  122. account.interface_language = interface_language
  123. db.session.commit()
  124. # Link account
  125. AccountService.link_account_integrate(provider, user_info.id, account)
  126. return account
  127. def admin_api_key_required(func):
  128. @wraps(func)
  129. def decorated_view(*args, **kwargs):
  130. current_app.login_manager._update_request_context_with_user(None)
  131. return auth_wrapper(func, *args, **kwargs)
  132. return decorated_view
  133. def auth_wrapper(func, *args, **kwargs):
  134. auth_header = request.headers.get("Authorization")
  135. if dify_config.ADMIN_API_KEY_ENABLE:
  136. if auth_header:
  137. if " " not in auth_header:
  138. raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
  139. auth_scheme, auth_token = auth_header.split(None, 1)
  140. auth_scheme = auth_scheme.lower()
  141. if auth_scheme != "bearer":
  142. raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
  143. admin_api_key = dify_config.ADMIN_API_KEY
  144. if admin_api_key:
  145. if admin_api_key == auth_token:
  146. workspace_id = request.headers.get("X-WORKSPACE-ID")
  147. if workspace_id:
  148. tenant_account_join = (
  149. db.session.query(Tenant, TenantAccountJoin)
  150. .filter(Tenant.id == workspace_id)
  151. .filter(TenantAccountJoin.tenant_id == Tenant.id)
  152. .filter(TenantAccountJoin.role == "owner")
  153. .one_or_none()
  154. )
  155. if tenant_account_join:
  156. tenant, ta = tenant_account_join
  157. # account = Account.query.filter_by(id=ta.account_id).first()
  158. account = None
  159. # Login admin
  160. openid = request.headers.get("X-USER-OPENID")
  161. user = request.headers.get("X-USERNAME")
  162. role = request.headers.get("X-ROLE", "admin")
  163. email = request.headers.get("X-EMAIL")
  164. provider = request.headers.get("X-PROVIDER", "console")
  165. if openid:
  166. account = Account.get_by_openid(provider, openid)
  167. if account:
  168. if ta.account_id == account.id:
  169. role = "owner"
  170. elif user and email and provider:
  171. account = _generate_account(provider, OAuthUserInfo(openid, user, email), workspace_id)
  172. if account:
  173. if role != "owner":
  174. TenantService.create_tenant_member(tenant, account, role)
  175. account.current_tenant = tenant
  176. current_app.login_manager._update_request_context_with_user(account)
  177. user_logged_in.send(current_app._get_current_object(), user=_get_user())
  178. else:
  179. # Login failed
  180. return current_app.login_manager.unauthorized()
  181. else:
  182. openid = request.headers.get("X-USER-OPENID")
  183. tenant = db.session.query(Tenant).filter(Tenant.id == workspace_id).one_or_none()
  184. account = Account.get_by_openid("console", openid)
  185. if tenant and account:
  186. TenantService.create_tenant_member(tenant, account, "owner")
  187. # Login failed
  188. return current_app.login_manager.unauthorized()
  189. if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
  190. pass
  191. elif not current_user.is_authenticated:
  192. return current_app.login_manager.unauthorized()
  193. # flask 1.x compatibility
  194. # current_app.ensure_sync is only available in Flask >= 2.0
  195. if callable(getattr(current_app, "ensure_sync", None)):
  196. return current_app.ensure_sync(func)(*args, **kwargs)
  197. return func(*args, **kwargs)
  198. def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
  199. account = Account.get_by_openid(provider, user_info.id)
  200. if not account:
  201. account = Account.query.filter_by(email=user_info.email).first()
  202. return account
  203. def _generate_account(provider: str, user_info: OAuthUserInfo, workspace_id: str | None = None):
  204. # Get account by openid or email.
  205. account = _get_account_by_openid_or_email(provider, user_info)
  206. if account:
  207. logging.info('account %s', account.name)
  208. if workspace_id:
  209. TenantService.switch_tenant(account, workspace_id)
  210. tenant = TenantService.get_join_tenants(account)
  211. if not tenant:
  212. if not FeatureService.get_system_features().is_allow_create_workspace:
  213. raise WorkSpaceNotAllowedCreateError()
  214. else:
  215. tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
  216. TenantService.create_tenant_member(tenant, account, role="owner")
  217. account.current_tenant = tenant
  218. tenant_was_created.send(tenant)
  219. if not account:
  220. if not FeatureService.get_system_features().is_allow_register:
  221. raise AccountNotFoundError()
  222. account_name = user_info.name or "Dify"
  223. account = RegisterService.register(
  224. email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
  225. )
  226. # Set interface language
  227. preferred_lang = request.accept_languages.best_match(languages)
  228. if preferred_lang and preferred_lang in languages:
  229. interface_language = preferred_lang
  230. else:
  231. interface_language = languages[0]
  232. account.interface_language = interface_language
  233. db.session.commit()
  234. # Link account
  235. AccountService.link_account_integrate(provider, user_info.id, account)
  236. return account
  237. def admin_api_key_required(func):
  238. @wraps(func)
  239. def decorated_view(*args, **kwargs):
  240. current_app.login_manager._update_request_context_with_user(None)
  241. return auth_wrapper(func, *args, **kwargs)
  242. return decorated_view
  243. def login_required(func):
  244. """
  245. If you decorate a view with this, it will ensure that the current user is
  246. logged in and authenticated before calling the actual view. (If they are
  247. not, it calls the :attr:`LoginManager.unauthorized` callback.) For
  248. example::
  249. @app.route('/post')
  250. @login_required
  251. def post():
  252. pass
  253. If there are only certain times you need to require that your user is
  254. logged in, you can do so with::
  255. if not current_user.is_authenticated:
  256. return current_app.login_manager.unauthorized()
  257. ...which is essentially the code that this function adds to your views.
  258. It can be convenient to globally turn off authentication when unit testing.
  259. To enable this, if the application configuration variable `LOGIN_DISABLED`
  260. is set to `True`, this decorator will be ignored.
  261. .. Note ::
  262. Per `W3 guidelines for CORS preflight requests
  263. <http://www.w3.org/TR/cors/#cross-origin-request-with-preflight-0>`_,
  264. HTTP ``OPTIONS`` requests are exempt from login checks.
  265. :param func: The view function to decorate.
  266. :type func: function
  267. """
  268. @wraps(func)
  269. def decorated_view(*args, **kwargs):
  270. return auth_wrapper(func, *args, **kwargs)
  271. return decorated_view
  272. def _get_user():
  273. if has_request_context():
  274. if "_login_user" not in g:
  275. current_app.login_manager._load_user() # type: ignore
  276. return g._login_user
  277. return None