model_providers.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. from flask_login import current_user
  2. from libs.login import login_required
  3. from flask_restful import Resource, reqparse
  4. from werkzeug.exceptions import Forbidden
  5. from controllers.console import api
  6. from controllers.console.app.error import ProviderNotInitializeError
  7. from controllers.console.setup import setup_required
  8. from controllers.console.wraps import account_initialization_required
  9. from core.model_providers.error import LLMBadRequestError
  10. from core.model_providers.providers.base import CredentialsValidateFailedError
  11. from services.provider_checkout_service import ProviderCheckoutService
  12. from services.provider_service import ProviderService
  class ModelProviderListApi(Resource):
      """List the model providers visible to the current user's workspace."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self):
          """Return the provider list for the current tenant.

          Query parameters:
              model_type: optional filter forwarded unchanged to
                  ``ProviderService.get_provider_list``.
          """
          current_tenant = current_user.current_tenant_id

          query_parser = reqparse.RequestParser()
          query_parser.add_argument('model_type', type=str, required=False, nullable=True, location='args')
          query = query_parser.parse_args()

          service = ProviderService()
          return service.get_provider_list(
              tenant_id=current_tenant,
              model_type=query.get('model_type'),
          )
  class ModelProviderValidateApi(Resource):
      """Validate a provider-level config without persisting it."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider_name: str):
          """Run credential validation for ``provider_name`` against the posted ``config``.

          JSON body:
              config: dict to validate (required, non-null).

          Returns:
              ``{'result': 'success'}`` when validation passes, otherwise
              ``{'result': 'error', 'error': <message>}`` built from the
              ``CredentialsValidateFailedError`` raised by the service.
          """
          body_parser = reqparse.RequestParser()
          body_parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
          payload = body_parser.parse_args()

          service = ProviderService()
          try:
              service.custom_provider_config_validate(
                  provider_name=provider_name,
                  config=payload['config'],
              )
          except CredentialsValidateFailedError as exc:
              # Validation failures are reported in the payload, not as an HTTP error.
              return {'result': 'error', 'error': str(exc)}

          return {'result': 'success'}
  class ModelProviderUpdateApi(Resource):
      """Save or delete a workspace's custom config for a provider.

      Both operations are restricted to users whose role in the current
      tenant is ``admin`` or ``owner``.
      """

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider_name: str):
          """Validate and persist a custom ``config`` for ``provider_name``.

          JSON body:
              config: dict to save (required, non-null).

          Returns:
              ``({'result': 'success'}, 201)``.

          Raises:
              Forbidden: if the current user is not an admin or owner.
              ValueError: if the service raises ``CredentialsValidateFailedError``;
                  the original error is preserved as ``__cause__``.
          """
          if current_user.current_tenant.current_role not in ['admin', 'owner']:
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
          args = parser.parse_args()

          provider_service = ProviderService()

          try:
              provider_service.save_custom_provider_config(
                  tenant_id=current_user.current_tenant_id,
                  provider_name=provider_name,
                  config=args['config']
              )
          except CredentialsValidateFailedError as ex:
              # Chain explicitly so the validation failure's traceback survives the re-raise.
              raise ValueError(str(ex)) from ex

          return {'result': 'success'}, 201

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, provider_name: str):
          """Delete the current tenant's custom config for ``provider_name``.

          Raises:
              Forbidden: if the current user is not an admin or owner.
          """
          if current_user.current_tenant.current_role not in ['admin', 'owner']:
              raise Forbidden()

          provider_service = ProviderService()
          provider_service.delete_custom_provider(
              tenant_id=current_user.current_tenant_id,
              provider_name=provider_name
          )

          # NOTE(review): HTTP 204 responses carry no content (RFC 9110), so this
          # payload may be dropped before reaching clients. Status kept for
          # API compatibility.
          return {'result': 'success'}, 204
  class ModelProviderModelValidateApi(Resource):
      """Validate a model-level config for a provider without persisting it."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider_name: str):
          """Run credential validation for one model of ``provider_name``.

          JSON body:
              model_name: model to validate (required).
              model_type: one of ``text-generation``, ``embeddings``,
                  ``speech2text`` or ``reranking`` (required).
              config: dict to validate (required, non-null).

          Returns:
              ``{'result': 'success'}`` when validation passes, otherwise
              ``{'result': 'error', 'error': <message>}``.
          """
          body_parser = reqparse.RequestParser()
          body_parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
          body_parser.add_argument('model_type', type=str, required=True, nullable=False,
                                   choices=['text-generation', 'embeddings', 'speech2text', 'reranking'],
                                   location='json')
          body_parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
          payload = body_parser.parse_args()

          service = ProviderService()
          try:
              service.custom_provider_model_config_validate(
                  provider_name=provider_name,
                  model_name=payload['model_name'],
                  model_type=payload['model_type'],
                  config=payload['config'],
              )
          except CredentialsValidateFailedError as exc:
              # Validation failures are reported in the payload, not as an HTTP error.
              return {'result': 'error', 'error': str(exc)}

          return {'result': 'success'}
  class ModelProviderModelUpdateApi(Resource):
      """Save or delete a custom config for a single model of a provider.

      Both operations are restricted to users whose role in the current
      tenant is ``admin`` or ``owner``.
      """

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider_name: str):
          """Add or update the custom config of one model of ``provider_name``.

          JSON body:
              model_name: model to configure (required).
              model_type: one of ``text-generation``, ``embeddings``,
                  ``speech2text`` or ``reranking`` (required).
              config: dict to save (required, non-null).

          Returns:
              ``({'result': 'success'}, 200)``.

          Raises:
              Forbidden: if the current user is not an admin or owner.
              ValueError: if the service raises ``CredentialsValidateFailedError``;
                  the original error is preserved as ``__cause__``.
          """
          if current_user.current_tenant.current_role not in ['admin', 'owner']:
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
          parser.add_argument('model_type', type=str, required=True, nullable=False,
                              choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json')
          parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
          args = parser.parse_args()

          provider_service = ProviderService()

          try:
              provider_service.add_or_save_custom_provider_model_config(
                  tenant_id=current_user.current_tenant_id,
                  provider_name=provider_name,
                  model_name=args['model_name'],
                  model_type=args['model_type'],
                  config=args['config']
              )
          except CredentialsValidateFailedError as ex:
              # Chain explicitly so the validation failure's traceback survives the re-raise.
              raise ValueError(str(ex)) from ex

          return {'result': 'success'}, 200

      @setup_required
      @login_required
      @account_initialization_required
      def delete(self, provider_name: str):
          """Delete the custom config of one model of ``provider_name``.

          Query parameters:
              model_name: model whose config is removed (required).
              model_type: one of ``text-generation``, ``embeddings``,
                  ``speech2text`` or ``reranking`` (required).

          Raises:
              Forbidden: if the current user is not an admin or owner.
          """
          if current_user.current_tenant.current_role not in ['admin', 'owner']:
              raise Forbidden()

          parser = reqparse.RequestParser()
          parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
          parser.add_argument('model_type', type=str, required=True, nullable=False,
                              choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
          args = parser.parse_args()

          provider_service = ProviderService()
          provider_service.delete_custom_provider_model(
              tenant_id=current_user.current_tenant_id,
              provider_name=provider_name,
              model_name=args['model_name'],
              model_type=args['model_type']
          )

          # NOTE(review): HTTP 204 responses carry no content (RFC 9110), so this
          # payload may be dropped before reaching clients. Status kept for
          # API compatibility.
          return {'result': 'success'}, 204
  class PreferredProviderTypeUpdateApi(Resource):
      """Switch whether a provider is used via its system or custom config."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider_name: str):
          """Set the preferred provider type for ``provider_name``.

          JSON body:
              preferred_provider_type: ``system`` or ``custom`` (required).

          Raises:
              Forbidden: if the current user's role is neither admin nor owner.
          """
          role = current_user.current_tenant.current_role
          if role not in ['admin', 'owner']:
              raise Forbidden()

          body_parser = reqparse.RequestParser()
          body_parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
                                   choices=['system', 'custom'], location='json')
          payload = body_parser.parse_args()

          service = ProviderService()
          service.switch_preferred_provider(
              tenant_id=current_user.current_tenant_id,
              provider_name=provider_name,
              preferred_provider_type=payload['preferred_provider_type'],
          )

          return {'result': 'success'}
  class ModelProviderModelParameterRuleApi(Resource):
      """Expose the parameter rules of a provider's text-generation model."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider_name: str):
          """Return the parameter rules for ``model_name`` under ``provider_name``.

          Query parameters:
              model_name: model to inspect (required).

          Returns:
              A dict keyed by each attribute name of the rules object returned by
              ``ProviderService.get_model_parameter_rules``. Each value is a dict
              with that rule's ``enabled``, ``min``, ``max``, ``default`` and
              ``precision``.

          Raises:
              ProviderNotInitializeError: if the service raises
                  ``LLMBadRequestError`` for the requested model; the original
                  error is preserved as ``__cause__``.
          """
          parser = reqparse.RequestParser()
          parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
          args = parser.parse_args()

          provider_service = ProviderService()

          try:
              parameter_rules = provider_service.get_model_parameter_rules(
                  tenant_id=current_user.current_tenant_id,
                  model_provider_name=provider_name,
                  model_name=args['model_name'],
                  # This endpoint always queries the text-generation model type.
                  model_type='text-generation'
              )
          except LLMBadRequestError as ex:
              # Plain string: the former f-string had no placeholders.
              raise ProviderNotInitializeError(
                  "Current Text Generation Model is invalid. Please switch to the available model.") from ex

          # Every attribute of the rules object is expected to expose these five fields.
          return {
              name: {
                  'enabled': rule.enabled,
                  'min': rule.min,
                  'max': rule.max,
                  'default': rule.default,
                  'precision': rule.precision
              }
              for name, rule in vars(parameter_rules).items()
          }
  class ModelProviderPaymentCheckoutUrlApi(Resource):
      """Create a checkout for a provider and hand back its URL."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider_name: str):
          """Create a checkout for ``provider_name`` on behalf of the current user.

          Returns:
              ``{'url': <checkout url>}`` as produced by the checkout object's
              ``get_checkout_url()``.
          """
          checkout = ProviderCheckoutService().create_checkout(
              tenant_id=current_user.current_tenant_id,
              provider_name=provider_name,
              account=current_user,
          )
          checkout_url = checkout.get_checkout_url()
          return {'url': checkout_url}
  class ModelProviderFreeQuotaSubmitApi(Resource):
      """Submit a free-quota request for a provider."""

      @setup_required
      @login_required
      @account_initialization_required
      def post(self, provider_name: str):
          """Forward a free-quota submission for ``provider_name`` to ``ProviderService``.

          Returns:
              Whatever ``ProviderService.free_quota_submit`` returns, unchanged.
          """
          return ProviderService().free_quota_submit(
              tenant_id=current_user.current_tenant_id,
              provider_name=provider_name,
          )
  class ModelProviderFreeQuotaQualificationVerifyApi(Resource):
      """Verify the current workspace's qualification for a provider's free quota."""

      @setup_required
      @login_required
      @account_initialization_required
      def get(self, provider_name: str):
          """Run the free-quota qualification check for ``provider_name``.

          Query parameters:
              token: optional token; ``None`` when absent.

          Returns:
              Whatever ``ProviderService.free_quota_qualification_verify``
              returns, unchanged.
          """
          query_parser = reqparse.RequestParser()
          query_parser.add_argument('token', type=str, required=False, nullable=True, location='args')
          query = query_parser.parse_args()

          service = ProviderService()
          return service.free_quota_qualification_verify(
              tenant_id=current_user.current_tenant_id,
              provider_name=provider_name,
              token=query['token'],
          )
  # Route table for the model-provider console endpoints, registered in order.
  _MODEL_PROVIDER_ROUTES = (
      (ModelProviderListApi, '/workspaces/current/model-providers'),
      (ModelProviderValidateApi, '/workspaces/current/model-providers/<string:provider_name>/validate'),
      (ModelProviderUpdateApi, '/workspaces/current/model-providers/<string:provider_name>'),
      (ModelProviderModelValidateApi,
       '/workspaces/current/model-providers/<string:provider_name>/models/validate'),
      (ModelProviderModelUpdateApi,
       '/workspaces/current/model-providers/<string:provider_name>/models'),
      (PreferredProviderTypeUpdateApi,
       '/workspaces/current/model-providers/<string:provider_name>/preferred-provider-type'),
      (ModelProviderModelParameterRuleApi,
       '/workspaces/current/model-providers/<string:provider_name>/models/parameter-rules'),
      (ModelProviderPaymentCheckoutUrlApi,
       '/workspaces/current/model-providers/<string:provider_name>/checkout-url'),
      (ModelProviderFreeQuotaSubmitApi,
       '/workspaces/current/model-providers/<string:provider_name>/free-quota-submit'),
      (ModelProviderFreeQuotaQualificationVerifyApi,
       '/workspaces/current/model-providers/<string:provider_name>/free-quota-qualification-verify'),
  )

  for _resource_cls, _route in _MODEL_PROVIDER_ROUTES:
      api.add_resource(_resource_cls, _route)

  # Keep the loop variables out of the module namespace.
  del _resource_cls, _route