oauth.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import logging
  2. from datetime import datetime
  3. from typing import Optional
  4. import requests
  5. from extensions.ext_database import db
  6. from flask import current_app, redirect, request
  7. from flask_restful import Resource
  8. from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
  9. from models.account import Account, AccountStatus
  10. from services.account_service import AccountService, RegisterService
  11. from .. import api
  12. def get_oauth_providers():
  13. with current_app.app_context():
  14. github_oauth = GitHubOAuth(client_id=current_app.config.get('GITHUB_CLIENT_ID'),
  15. client_secret=current_app.config.get(
  16. 'GITHUB_CLIENT_SECRET'),
  17. redirect_uri=current_app.config.get(
  18. 'CONSOLE_API_URL') + '/console/api/oauth/authorize/github')
  19. google_oauth = GoogleOAuth(client_id=current_app.config.get('GOOGLE_CLIENT_ID'),
  20. client_secret=current_app.config.get(
  21. 'GOOGLE_CLIENT_SECRET'),
  22. redirect_uri=current_app.config.get(
  23. 'CONSOLE_API_URL') + '/console/api/oauth/authorize/google')
  24. OAUTH_PROVIDERS = {
  25. 'github': github_oauth,
  26. 'google': google_oauth
  27. }
  28. return OAUTH_PROVIDERS
  29. class OAuthLogin(Resource):
  30. def get(self, provider: str):
  31. OAUTH_PROVIDERS = get_oauth_providers()
  32. with current_app.app_context():
  33. oauth_provider = OAUTH_PROVIDERS.get(provider)
  34. print(vars(oauth_provider))
  35. if not oauth_provider:
  36. return {'error': 'Invalid provider'}, 400
  37. auth_url = oauth_provider.get_authorization_url()
  38. return redirect(auth_url)
  39. class OAuthCallback(Resource):
  40. def get(self, provider: str):
  41. OAUTH_PROVIDERS = get_oauth_providers()
  42. with current_app.app_context():
  43. oauth_provider = OAUTH_PROVIDERS.get(provider)
  44. if not oauth_provider:
  45. return {'error': 'Invalid provider'}, 400
  46. code = request.args.get('code')
  47. try:
  48. token = oauth_provider.get_access_token(code)
  49. user_info = oauth_provider.get_user_info(token)
  50. except requests.exceptions.HTTPError as e:
  51. logging.exception(
  52. f"An error occurred during the OAuth process with {provider}: {e.response.text}")
  53. return {'error': 'OAuth process failed'}, 400
  54. account = _generate_account(provider, user_info)
  55. # Check account status
  56. if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
  57. return {'error': 'Account is banned or closed.'}, 403
  58. if account.status == AccountStatus.PENDING.value:
  59. account.status = AccountStatus.ACTIVE.value
  60. account.initialized_at = datetime.utcnow()
  61. db.session.commit()
  62. AccountService.update_last_login(account, request)
  63. token = AccountService.get_account_jwt_token(account)
  64. return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?console_token={token}')
  65. def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
  66. account = Account.get_by_openid(provider, user_info.id)
  67. if not account:
  68. account = Account.query.filter_by(email=user_info.email).first()
  69. return account
  70. def _generate_account(provider: str, user_info: OAuthUserInfo):
  71. # Get account by openid or email.
  72. account = _get_account_by_openid_or_email(provider, user_info)
  73. if not account:
  74. # Create account
  75. account_name = user_info.name if user_info.name else 'Dify'
  76. account = RegisterService.register(
  77. email=user_info.email,
  78. name=account_name,
  79. password=None,
  80. open_id=user_info.id,
  81. provider=provider
  82. )
  83. # Set interface language
  84. preferred_lang = request.accept_languages.best_match(['zh', 'en'])
  85. if preferred_lang == 'zh':
  86. interface_language = 'zh-Hans'
  87. else:
  88. interface_language = 'en-US'
  89. account.interface_language = interface_language
  90. db.session.commit()
  91. # Link account
  92. AccountService.link_account_integrate(provider, user_info.id, account)
  93. return account
  94. api.add_resource(OAuthLogin, '/oauth/login/<provider>')
  95. api.add_resource(OAuthCallback, '/oauth/authorize/<provider>')