tool_providers.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import json
  2. from flask_login import login_required, current_user
  3. from flask_restful import Resource, abort, reqparse
  4. from werkzeug.exceptions import Forbidden
  5. from controllers.console import api
  6. from controllers.console.setup import setup_required
  7. from controllers.console.wraps import account_initialization_required
  8. from core.tool.provider.errors import ToolValidateFailedError
  9. from core.tool.provider.tool_provider_service import ToolProviderService
  10. from extensions.ext_database import db
  11. from models.tool import ToolProvider, ToolProviderName
  12. class ToolProviderListApi(Resource):
  13. @setup_required
  14. @login_required
  15. @account_initialization_required
  16. def get(self):
  17. tenant_id = current_user.current_tenant_id
  18. tool_credential_dict = {}
  19. for tool_name in ToolProviderName:
  20. tool_credential_dict[tool_name.value] = {
  21. 'tool_name': tool_name.value,
  22. 'is_enabled': False,
  23. 'credentials': None
  24. }
  25. tool_providers = db.session.query(ToolProvider).filter(ToolProvider.tenant_id == tenant_id).all()
  26. for p in tool_providers:
  27. if p.is_enabled:
  28. tool_credential_dict[p.tool_name] = {
  29. 'tool_name': p.tool_name,
  30. 'is_enabled': p.is_enabled,
  31. 'credentials': ToolProviderService(tenant_id, p.tool_name).get_credentials(obfuscated=True)
  32. }
  33. return list(tool_credential_dict.values())
  34. class ToolProviderCredentialsApi(Resource):
  35. @setup_required
  36. @login_required
  37. @account_initialization_required
  38. def post(self, provider):
  39. if provider not in [p.value for p in ToolProviderName]:
  40. abort(404)
  41. # The role of the current user in the ta table must be admin or owner
  42. if current_user.current_tenant.current_role not in ['admin', 'owner']:
  43. raise Forbidden(f'User {current_user.id} is not authorized to update provider token, '
  44. f'current_role is {current_user.current_tenant.current_role}')
  45. parser = reqparse.RequestParser()
  46. parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
  47. args = parser.parse_args()
  48. tenant_id = current_user.current_tenant_id
  49. tool_provider_service = ToolProviderService(tenant_id, provider)
  50. try:
  51. tool_provider_service.credentials_validate(args['credentials'])
  52. except ToolValidateFailedError as ex:
  53. raise ValueError(str(ex))
  54. encrypted_credentials = json.dumps(tool_provider_service.encrypt_credentials(args['credentials']))
  55. tenant = current_user.current_tenant
  56. tool_provider_model = db.session.query(ToolProvider).filter(
  57. ToolProvider.tenant_id == tenant.id,
  58. ToolProvider.tool_name == provider,
  59. ).first()
  60. # Only allow updating token for CUSTOM provider type
  61. if tool_provider_model:
  62. tool_provider_model.encrypted_credentials = encrypted_credentials
  63. tool_provider_model.is_enabled = True
  64. else:
  65. tool_provider_model = ToolProvider(
  66. tenant_id=tenant.id,
  67. tool_name=provider,
  68. encrypted_credentials=encrypted_credentials,
  69. is_enabled=True
  70. )
  71. db.session.add(tool_provider_model)
  72. db.session.commit()
  73. return {'result': 'success'}, 201
  74. class ToolProviderCredentialsValidateApi(Resource):
  75. @setup_required
  76. @login_required
  77. @account_initialization_required
  78. def post(self, provider):
  79. if provider not in [p.value for p in ToolProviderName]:
  80. abort(404)
  81. parser = reqparse.RequestParser()
  82. parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
  83. args = parser.parse_args()
  84. result = True
  85. error = None
  86. tenant_id = current_user.current_tenant_id
  87. tool_provider_service = ToolProviderService(tenant_id, provider)
  88. try:
  89. tool_provider_service.credentials_validate(args['credentials'])
  90. except ToolValidateFailedError as ex:
  91. result = False
  92. error = str(ex)
  93. response = {'result': 'success' if result else 'error'}
  94. if not result:
  95. response['error'] = error
  96. return response
  97. api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
  98. api.add_resource(ToolProviderCredentialsApi, '/workspaces/current/tool-providers/<provider>/credentials')
  99. api.add_resource(ToolProviderCredentialsValidateApi,
  100. '/workspaces/current/tool-providers/<provider>/credentials-validate')