# data_source_oauth.py (4.6 KB)

  1. import logging
  2. import requests
  3. from controllers.console import api
  4. from flask import current_app, redirect, request
  5. from flask_login import current_user
  6. from flask_restful import Resource
  7. from libs.login import login_required
  8. from libs.oauth_data_source import NotionOAuth
  9. from werkzeug.exceptions import Forbidden
  10. from ..setup import setup_required
  11. from ..wraps import account_initialization_required
  12. def get_oauth_providers():
  13. with current_app.app_context():
  14. notion_oauth = NotionOAuth(client_id=current_app.config.get('NOTION_CLIENT_ID'),
  15. client_secret=current_app.config.get(
  16. 'NOTION_CLIENT_SECRET'),
  17. redirect_uri=current_app.config.get(
  18. 'CONSOLE_API_URL') + '/console/api/oauth/data-source/callback/notion')
  19. OAUTH_PROVIDERS = {
  20. 'notion': notion_oauth
  21. }
  22. return OAUTH_PROVIDERS
  23. class OAuthDataSource(Resource):
  24. def get(self, provider: str):
  25. # The role of the current user in the table must be admin or owner
  26. if current_user.current_tenant.current_role not in ['admin', 'owner']:
  27. raise Forbidden()
  28. OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
  29. with current_app.app_context():
  30. oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
  31. print(vars(oauth_provider))
  32. if not oauth_provider:
  33. return {'error': 'Invalid provider'}, 400
  34. if current_app.config.get('NOTION_INTEGRATION_TYPE') == 'internal':
  35. internal_secret = current_app.config.get('NOTION_INTERNAL_SECRET')
  36. oauth_provider.save_internal_access_token(internal_secret)
  37. return { 'data': '' }
  38. else:
  39. auth_url = oauth_provider.get_authorization_url()
  40. return { 'data': auth_url }, 200
  41. class OAuthDataSourceCallback(Resource):
  42. def get(self, provider: str):
  43. OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
  44. with current_app.app_context():
  45. oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
  46. if not oauth_provider:
  47. return {'error': 'Invalid provider'}, 400
  48. if 'code' in request.args:
  49. code = request.args.get('code')
  50. return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&code={code}')
  51. elif 'error' in request.args:
  52. error = request.args.get('error')
  53. return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&error={error}')
  54. else:
  55. return redirect(f'{current_app.config.get("CONSOLE_WEB_URL")}?type=notion&error=Access denied')
  56. class OAuthDataSourceBinding(Resource):
  57. def get(self, provider: str):
  58. OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
  59. with current_app.app_context():
  60. oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
  61. if not oauth_provider:
  62. return {'error': 'Invalid provider'}, 400
  63. if 'code' in request.args:
  64. code = request.args.get('code')
  65. try:
  66. oauth_provider.get_access_token(code)
  67. except requests.exceptions.HTTPError as e:
  68. logging.exception(
  69. f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
  70. return {'error': 'OAuth data source process failed'}, 400
  71. return {'result': 'success'}, 200
  72. class OAuthDataSourceSync(Resource):
  73. @setup_required
  74. @login_required
  75. @account_initialization_required
  76. def get(self, provider, binding_id):
  77. provider = str(provider)
  78. binding_id = str(binding_id)
  79. OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers()
  80. with current_app.app_context():
  81. oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider)
  82. if not oauth_provider:
  83. return {'error': 'Invalid provider'}, 400
  84. try:
  85. oauth_provider.sync_data_source(binding_id)
  86. except requests.exceptions.HTTPError as e:
  87. logging.exception(
  88. f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}")
  89. return {'error': 'OAuth data source process failed'}, 400
  90. return {'result': 'success'}, 200
  91. api.add_resource(OAuthDataSource, '/oauth/data-source/<string:provider>')
  92. api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/<string:provider>')
  93. api.add_resource(OAuthDataSourceBinding, '/oauth/data-source/binding/<string:provider>')
  94. api.add_resource(OAuthDataSourceSync, '/oauth/data-source/<string:provider>/<uuid:binding_id>/sync')