model_providers.py 8.2 KB


  1. import io
  2. from flask import send_file
  3. from flask_login import current_user
  4. from flask_restful import Resource, reqparse
  5. from werkzeug.exceptions import Forbidden
  6. from controllers.console import api
  7. from controllers.console.setup import setup_required
  8. from controllers.console.wraps import account_initialization_required
  9. from core.model_runtime.entities.model_entities import ModelType
  10. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  11. from core.model_runtime.utils.encoders import jsonable_encoder
  12. from libs.login import login_required
  13. from services.billing_service import BillingService
  14. from services.model_provider_service import ModelProviderService
  class ModelProviderListApi(Resource):
      """List the model providers visible to the current workspace."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          """Return ``{"data": [...]}`` with the workspace's providers.

          An optional ``model_type`` query argument narrows the list; when
          supplied it must be one of the ``ModelType`` enum values.
          """
          workspace_id = current_user.current_tenant_id

          request_parser = reqparse.RequestParser()
          request_parser.add_argument(
              'model_type',
              type=str,
              required=False,
              nullable=True,
              choices=[model_type.value for model_type in ModelType],
              location='args',
          )
          parsed = request_parser.parse_args()

          service = ModelProviderService()
          providers = service.get_provider_list(
              tenant_id=workspace_id,
              model_type=parsed.get('model_type'),
          )
          return jsonable_encoder({"data": providers})
  class ModelProviderCredentialApi(Resource):
      """Read the credentials configured for a single model provider."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider: str):
          """Return ``{"credentials": ...}`` for *provider* in the current workspace.

          The value is whatever ``ModelProviderService.get_provider_credentials``
          returns; no additional filtering happens here.
          """
          workspace_id = current_user.current_tenant_id
          service = ModelProviderService()
          return {
              "credentials": service.get_provider_credentials(
                  tenant_id=workspace_id,
                  provider=provider,
              )
          }
  class ModelProviderValidateApi(Resource):
      """Validate a candidate set of credentials for a model provider."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Run credential validation for *provider* using the JSON ``credentials`` object.

          Returns:
              ``{'result': 'success'}`` when validation passes, otherwise
              ``{'result': 'error', 'error': <message>}`` when the service raises
              ``CredentialsValidateFailedError``.
          """
          request_parser = reqparse.RequestParser()
          request_parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
          parsed = request_parser.parse_args()

          workspace_id = current_user.current_tenant_id
          service = ModelProviderService()

          try:
              service.provider_credentials_validate(
                  tenant_id=workspace_id,
                  provider=provider,
                  credentials=parsed['credentials'],
              )
          except CredentialsValidateFailedError as exc:
              # A failed validation is reported in the payload, not as an HTTP error.
              return {'result': 'error', 'error': str(exc)}

          return {'result': 'success'}
  class ModelProviderApi(Resource):
      """Save or remove the credentials of a model provider.

      Both operations require the current user to be a workspace admin or owner.
      """

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Store the JSON ``credentials`` object for *provider* in the current workspace.

          Returns:
              ``({'result': 'success'}, 201)`` on success.

          Raises:
              Forbidden: the current user is not an admin or owner.
              ValueError: the service raised ``CredentialsValidateFailedError``;
                  the original exception is chained as ``__cause__``.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
          args = parser.parse_args()

          model_provider_service = ModelProviderService()
          try:
              model_provider_service.save_provider_credentials(
                  tenant_id=current_user.current_tenant_id,
                  provider=provider,
                  credentials=args['credentials']
              )
          except CredentialsValidateFailedError as ex:
              # Surface as ValueError (the type this endpoint has always raised) while
              # keeping the validation error as the explicit cause in tracebacks.
              raise ValueError(str(ex)) from ex

          return {'result': 'success'}, 201

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, provider: str):
          """Remove the stored credentials for *provider* in the current workspace.

          Raises:
              Forbidden: the current user is not an admin or owner.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()

          model_provider_service = ModelProviderService()
          model_provider_service.remove_provider_credentials(
              tenant_id=current_user.current_tenant_id,
              provider=provider
          )
          # NOTE(review): HTTP 204 responses may not carry a body, so this payload is
          # presumably dropped before it reaches the client; kept for compatibility.
          return {'result': 'success'}, 204
  class ModelProviderIconApi(Resource):
      """Serve a model provider's icon image."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider: str, icon_type: str, lang: str):
          """Send the icon bytes for *provider* using the service-supplied mimetype.

          ``icon_type`` and ``lang`` come from the URL and are passed through
          unchanged to ``ModelProviderService.get_model_provider_icon``.
          """
          service = ModelProviderService()
          icon_bytes, content_type = service.get_model_provider_icon(
              provider=provider,
              icon_type=icon_type,
              lang=lang,
          )
          # send_file needs a file-like object, so wrap the raw bytes.
          buffer = io.BytesIO(icon_bytes)
          return send_file(buffer, mimetype=content_type)
  class PreferredProviderTypeUpdateApi(Resource):
      """Choose whether a provider's ``system`` or ``custom`` configuration is preferred."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Set the JSON ``preferred_provider_type`` (``'system'`` or ``'custom'``) for *provider*.

          Raises:
              Forbidden: the current user is not an admin or owner.
          """
          if not current_user.is_admin_or_owner:
              raise Forbidden()

          workspace_id = current_user.current_tenant_id

          request_parser = reqparse.RequestParser()
          request_parser.add_argument(
              'preferred_provider_type',
              type=str,
              required=True,
              nullable=False,
              choices=['system', 'custom'],
              location='json',
          )
          parsed = request_parser.parse_args()

          service = ModelProviderService()
          service.switch_preferred_provider(
              tenant_id=workspace_id,
              provider=provider,
              preferred_provider_type=parsed['preferred_provider_type'],
          )
          return {'result': 'success'}
  class ModelProviderPaymentCheckoutUrlApi(Resource):
      """Fetch a billing checkout link for a paid model provider."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider: str):
          """Return the payment-link data from ``BillingService`` for *provider*.

          Raises:
              ValueError: *provider* is anything other than ``'anthropic'``.
          """
          if provider != 'anthropic':
              raise ValueError(f'provider name {provider} is invalid')

          # NOTE(review): the return value is ignored, so this call presumably raises
          # when the user lacks the owner/admin role -- confirm in BillingService.
          BillingService.is_tenant_owner_or_admin(current_user)

          account = current_user
          return BillingService.get_model_provider_payment_link(
              provider_name=provider,
              tenant_id=account.current_tenant_id,
              account_id=account.id,
              prefilled_email=account.email,
          )
  class ModelProviderFreeQuotaSubmitApi(Resource):
      """Submit a free-quota request for a model provider."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider: str):
          """Forward the free-quota submission for *provider*; return the service result as-is."""
          service = ModelProviderService()
          return service.free_quota_submit(
              tenant_id=current_user.current_tenant_id,
              provider=provider,
          )
  class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
      """Check whether the workspace qualifies for a provider's free quota."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider: str):
          """Verify free-quota qualification for *provider*.

          Reads an optional ``token`` query argument and returns the service
          result unchanged.
          """
          request_parser = reqparse.RequestParser()
          request_parser.add_argument('token', type=str, required=False, nullable=True, location='args')
          parsed = request_parser.parse_args()

          service = ModelProviderService()
          return service.free_quota_qualification_verify(
              tenant_id=current_user.current_tenant_id,
              provider=provider,
              token=parsed['token'],
          )
  # Route table for the workspace model-provider console endpoints,
  # registered in this exact order.
  for _resource, _url in (
      (ModelProviderListApi, '/workspaces/current/model-providers'),
      (ModelProviderCredentialApi, '/workspaces/current/model-providers/<string:provider>/credentials'),
      (ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider>/credentials/validate'),
      (ModelProviderApi, '/workspaces/current/model-providers/<string:provider>'),
      (ModelProviderIconApi, '/workspaces/current/model-providers/<string:provider>/'
                             '<string:icon_type>/<string:lang>'),
      (PreferredProviderTypeUpdateApi,
       '/workspaces/current/model-providers/<string:provider>/preferred-provider-type'),
      (ModelProviderPaymentCheckoutUrlApi,
       '/workspaces/current/model-providers/<string:provider>/checkout-url'),
      (ModelProviderFreeQuotaSubmitApi,
       '/workspaces/current/model-providers/<string:provider>/free-quota-submit'),
      (ModelProviderFreeQuotaQualificationVerifyApi,
       '/workspaces/current/model-providers/<string:provider>/free-quota-qualification-verify'),
  ):
      api.add_resource(_resource, _url)
  # Drop the loop temporaries so the module namespace matches explicit calls.
  del _resource, _url