tool_providers.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import json
  2. from controllers.console import api
  3. from controllers.console.setup import setup_required
  4. from controllers.console.wraps import account_initialization_required
  5. from core.tool.provider.errors import ToolValidateFailedError
  6. from core.tool.provider.tool_provider_service import ToolProviderService
  7. from extensions.ext_database import db
  8. from flask_login import current_user
  9. from flask_restful import Resource, abort, reqparse
  10. from libs.login import login_required
  11. from models.tool import ToolProvider, ToolProviderName
  12. from werkzeug.exceptions import Forbidden
  13. class ToolProviderListApi(Resource):
  14. @setup_required
  15. @login_required
  16. @account_initialization_required
  17. def get(self):
  18. tenant_id = current_user.current_tenant_id
  19. tool_credential_dict = {}
  20. for tool_name in ToolProviderName:
  21. tool_credential_dict[tool_name.value] = {
  22. 'tool_name': tool_name.value,
  23. 'is_enabled': False,
  24. 'credentials': None
  25. }
  26. tool_providers = db.session.query(ToolProvider).filter(ToolProvider.tenant_id == tenant_id).all()
  27. for p in tool_providers:
  28. if p.is_enabled:
  29. tool_credential_dict[p.tool_name] = {
  30. 'tool_name': p.tool_name,
  31. 'is_enabled': p.is_enabled,
  32. 'credentials': ToolProviderService(tenant_id, p.tool_name).get_credentials(obfuscated=True)
  33. }
  34. return list(tool_credential_dict.values())
  35. class ToolProviderCredentialsApi(Resource):
  36. @setup_required
  37. @login_required
  38. @account_initialization_required
  39. def post(self, provider):
  40. if provider not in [p.value for p in ToolProviderName]:
  41. abort(404)
  42. # The role of the current user in the ta table must be admin or owner
  43. if current_user.current_tenant.current_role not in ['admin', 'owner']:
  44. raise Forbidden(f'User {current_user.id} is not authorized to update provider token, '
  45. f'current_role is {current_user.current_tenant.current_role}')
  46. parser = reqparse.RequestParser()
  47. parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
  48. args = parser.parse_args()
  49. tenant_id = current_user.current_tenant_id
  50. tool_provider_service = ToolProviderService(tenant_id, provider)
  51. try:
  52. tool_provider_service.credentials_validate(args['credentials'])
  53. except ToolValidateFailedError as ex:
  54. raise ValueError(str(ex))
  55. encrypted_credentials = json.dumps(tool_provider_service.encrypt_credentials(args['credentials']))
  56. tenant = current_user.current_tenant
  57. tool_provider_model = db.session.query(ToolProvider).filter(
  58. ToolProvider.tenant_id == tenant.id,
  59. ToolProvider.tool_name == provider,
  60. ).first()
  61. # Only allow updating token for CUSTOM provider type
  62. if tool_provider_model:
  63. tool_provider_model.encrypted_credentials = encrypted_credentials
  64. tool_provider_model.is_enabled = True
  65. else:
  66. tool_provider_model = ToolProvider(
  67. tenant_id=tenant.id,
  68. tool_name=provider,
  69. encrypted_credentials=encrypted_credentials,
  70. is_enabled=True
  71. )
  72. db.session.add(tool_provider_model)
  73. db.session.commit()
  74. return {'result': 'success'}, 201
  75. class ToolProviderCredentialsValidateApi(Resource):
  76. @setup_required
  77. @login_required
  78. @account_initialization_required
  79. def post(self, provider):
  80. if provider not in [p.value for p in ToolProviderName]:
  81. abort(404)
  82. parser = reqparse.RequestParser()
  83. parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
  84. args = parser.parse_args()
  85. result = True
  86. error = None
  87. tenant_id = current_user.current_tenant_id
  88. tool_provider_service = ToolProviderService(tenant_id, provider)
  89. try:
  90. tool_provider_service.credentials_validate(args['credentials'])
  91. except ToolValidateFailedError as ex:
  92. result = False
  93. error = str(ex)
  94. response = {'result': 'success' if result else 'error'}
  95. if not result:
  96. response['error'] = error
  97. return response
  98. api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers')
  99. api.add_resource(ToolProviderCredentialsApi, '/workspaces/current/tool-providers/<provider>/credentials')
  100. api.add_resource(ToolProviderCredentialsValidateApi,
  101. '/workspaces/current/tool-providers/<provider>/credentials-validate')