providers.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. # -*- coding:utf-8 -*-
  2. from flask_login import current_user
  3. from libs.login import login_required
  4. from flask_restful import Resource, reqparse
  5. from werkzeug.exceptions import Forbidden
  6. from controllers.console import api
  7. from controllers.console.setup import setup_required
  8. from controllers.console.wraps import account_initialization_required
  9. from core.model_providers.providers.base import CredentialsValidateFailedError
  10. from models.provider import ProviderType
  11. from services.provider_service import ProviderService
  12. class ProviderListApi(Resource):
  13. @setup_required
  14. @login_required
  15. @account_initialization_required
  16. def get(self):
  17. tenant_id = current_user.current_tenant_id
  18. """
  19. If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:,
  20. azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the
  21. rest is replaced by * and the last two bits are displayed in plaintext
  22. If the type is other, decode and return the Token field directly, the field displays the first 6 bits in
  23. plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
  24. """
  25. provider_service = ProviderService()
  26. provider_info_list = provider_service.get_provider_list(tenant_id)
  27. provider_list = [
  28. {
  29. 'provider_name': p['provider_name'],
  30. 'provider_type': p['provider_type'],
  31. 'is_valid': p['is_valid'],
  32. 'last_used': p['last_used'],
  33. 'is_enabled': p['is_valid'],
  34. **({
  35. 'quota_type': p['quota_type'],
  36. 'quota_limit': p['quota_limit'],
  37. 'quota_used': p['quota_used']
  38. } if p['provider_type'] == ProviderType.SYSTEM.value else {}),
  39. 'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key'])
  40. if p['config'] else None
  41. }
  42. for name, provider_info in provider_info_list.items()
  43. for p in provider_info['providers']
  44. ]
  45. return provider_list
  46. class ProviderTokenApi(Resource):
  47. @setup_required
  48. @login_required
  49. @account_initialization_required
  50. def post(self, provider):
  51. # The role of the current user in the ta table must be admin or owner
  52. if current_user.current_tenant.current_role not in ['admin', 'owner']:
  53. raise Forbidden()
  54. parser = reqparse.RequestParser()
  55. parser.add_argument('token', required=True, nullable=False, location='json')
  56. args = parser.parse_args()
  57. if provider == 'openai':
  58. args['token'] = {
  59. 'openai_api_key': args['token']
  60. }
  61. provider_service = ProviderService()
  62. try:
  63. provider_service.save_custom_provider_config(
  64. tenant_id=current_user.current_tenant_id,
  65. provider_name=provider,
  66. config=args['token']
  67. )
  68. except CredentialsValidateFailedError as ex:
  69. raise ValueError(str(ex))
  70. return {'result': 'success'}, 201
  71. class ProviderTokenValidateApi(Resource):
  72. @setup_required
  73. @login_required
  74. @account_initialization_required
  75. def post(self, provider):
  76. parser = reqparse.RequestParser()
  77. parser.add_argument('token', required=True, nullable=False, location='json')
  78. args = parser.parse_args()
  79. provider_service = ProviderService()
  80. if provider == 'openai':
  81. args['token'] = {
  82. 'openai_api_key': args['token']
  83. }
  84. result = True
  85. error = None
  86. try:
  87. provider_service.custom_provider_config_validate(
  88. provider_name=provider,
  89. config=args['token']
  90. )
  91. except CredentialsValidateFailedError as ex:
  92. result = False
  93. error = str(ex)
  94. response = {'result': 'success' if result else 'error'}
  95. if not result:
  96. response['error'] = error
  97. return response
  98. api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
  99. endpoint='workspaces_current_providers_token') # PUT for updating provider token
  100. api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
  101. endpoint='workspaces_current_providers_token_validate') # POST for validating provider token
  102. api.add_resource(ProviderListApi, '/workspaces/current/providers') # GET for getting providers list