models.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. import logging
  2. from flask_login import current_user
  3. from flask_restful import Resource, reqparse
  4. from werkzeug.exceptions import Forbidden
  5. from controllers.console import api
  6. from controllers.console.setup import setup_required
  7. from controllers.console.wraps import account_initialization_required
  8. from core.model_runtime.entities.model_entities import ModelType
  9. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  10. from core.model_runtime.utils.encoders import jsonable_encoder
  11. from libs.login import login_required
  12. from models.account import TenantAccountRole
  13. from services.model_provider_service import ModelProviderService
  14. class DefaultModelApi(Resource):
  15. @setup_required
  16. @login_required
  17. @account_initialization_required
  18. def get(self):
  19. parser = reqparse.RequestParser()
  20. parser.add_argument('model_type', type=str, required=True, nullable=False,
  21. choices=[mt.value for mt in ModelType], location='args')
  22. args = parser.parse_args()
  23. tenant_id = current_user.current_tenant_id
  24. model_provider_service = ModelProviderService()
  25. default_model_entity = model_provider_service.get_default_model_of_model_type(
  26. tenant_id=tenant_id,
  27. model_type=args['model_type']
  28. )
  29. return jsonable_encoder({
  30. "data": default_model_entity
  31. })
  32. @setup_required
  33. @login_required
  34. @account_initialization_required
  35. def post(self):
  36. parser = reqparse.RequestParser()
  37. parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
  38. args = parser.parse_args()
  39. tenant_id = current_user.current_tenant_id
  40. model_provider_service = ModelProviderService()
  41. model_settings = args['model_settings']
  42. for model_setting in model_settings:
  43. if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]:
  44. raise ValueError('invalid model type')
  45. if 'provider' not in model_setting:
  46. continue
  47. if 'model' not in model_setting:
  48. raise ValueError('invalid model')
  49. try:
  50. model_provider_service.update_default_model_of_model_type(
  51. tenant_id=tenant_id,
  52. model_type=model_setting['model_type'],
  53. provider=model_setting['provider'],
  54. model=model_setting['model']
  55. )
  56. except Exception:
  57. logging.warning(f"{model_setting['model_type']} save error")
  58. return {'result': 'success'}
  59. class ModelProviderModelApi(Resource):
  60. @setup_required
  61. @login_required
  62. @account_initialization_required
  63. def get(self, provider):
  64. tenant_id = current_user.current_tenant_id
  65. model_provider_service = ModelProviderService()
  66. models = model_provider_service.get_models_by_provider(
  67. tenant_id=tenant_id,
  68. provider=provider
  69. )
  70. return jsonable_encoder({
  71. "data": models
  72. })
  73. @setup_required
  74. @login_required
  75. @account_initialization_required
  76. def post(self, provider: str):
  77. if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
  78. raise Forbidden()
  79. tenant_id = current_user.current_tenant_id
  80. parser = reqparse.RequestParser()
  81. parser.add_argument('model', type=str, required=True, nullable=False, location='json')
  82. parser.add_argument('model_type', type=str, required=True, nullable=False,
  83. choices=[mt.value for mt in ModelType], location='json')
  84. parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
  85. args = parser.parse_args()
  86. model_provider_service = ModelProviderService()
  87. try:
  88. model_provider_service.save_model_credentials(
  89. tenant_id=tenant_id,
  90. provider=provider,
  91. model=args['model'],
  92. model_type=args['model_type'],
  93. credentials=args['credentials']
  94. )
  95. except CredentialsValidateFailedError as ex:
  96. raise ValueError(str(ex))
  97. return {'result': 'success'}, 200
  98. @setup_required
  99. @login_required
  100. @account_initialization_required
  101. def delete(self, provider: str):
  102. if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
  103. raise Forbidden()
  104. tenant_id = current_user.current_tenant_id
  105. parser = reqparse.RequestParser()
  106. parser.add_argument('model', type=str, required=True, nullable=False, location='json')
  107. parser.add_argument('model_type', type=str, required=True, nullable=False,
  108. choices=[mt.value for mt in ModelType], location='json')
  109. args = parser.parse_args()
  110. model_provider_service = ModelProviderService()
  111. model_provider_service.remove_model_credentials(
  112. tenant_id=tenant_id,
  113. provider=provider,
  114. model=args['model'],
  115. model_type=args['model_type']
  116. )
  117. return {'result': 'success'}, 204
  118. class ModelProviderModelCredentialApi(Resource):
  119. @setup_required
  120. @login_required
  121. @account_initialization_required
  122. def get(self, provider: str):
  123. tenant_id = current_user.current_tenant_id
  124. parser = reqparse.RequestParser()
  125. parser.add_argument('model', type=str, required=True, nullable=False, location='args')
  126. parser.add_argument('model_type', type=str, required=True, nullable=False,
  127. choices=[mt.value for mt in ModelType], location='args')
  128. args = parser.parse_args()
  129. model_provider_service = ModelProviderService()
  130. credentials = model_provider_service.get_model_credentials(
  131. tenant_id=tenant_id,
  132. provider=provider,
  133. model_type=args['model_type'],
  134. model=args['model']
  135. )
  136. return {
  137. "credentials": credentials
  138. }
  139. class ModelProviderModelValidateApi(Resource):
  140. @setup_required
  141. @login_required
  142. @account_initialization_required
  143. def post(self, provider: str):
  144. tenant_id = current_user.current_tenant_id
  145. parser = reqparse.RequestParser()
  146. parser.add_argument('model', type=str, required=True, nullable=False, location='json')
  147. parser.add_argument('model_type', type=str, required=True, nullable=False,
  148. choices=[mt.value for mt in ModelType], location='json')
  149. parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
  150. args = parser.parse_args()
  151. model_provider_service = ModelProviderService()
  152. result = True
  153. error = None
  154. try:
  155. model_provider_service.model_credentials_validate(
  156. tenant_id=tenant_id,
  157. provider=provider,
  158. model=args['model'],
  159. model_type=args['model_type'],
  160. credentials=args['credentials']
  161. )
  162. except CredentialsValidateFailedError as ex:
  163. result = False
  164. error = str(ex)
  165. response = {'result': 'success' if result else 'error'}
  166. if not result:
  167. response['error'] = error
  168. return response
  169. class ModelProviderModelParameterRuleApi(Resource):
  170. @setup_required
  171. @login_required
  172. @account_initialization_required
  173. def get(self, provider: str):
  174. parser = reqparse.RequestParser()
  175. parser.add_argument('model', type=str, required=True, nullable=False, location='args')
  176. args = parser.parse_args()
  177. tenant_id = current_user.current_tenant_id
  178. model_provider_service = ModelProviderService()
  179. parameter_rules = model_provider_service.get_model_parameter_rules(
  180. tenant_id=tenant_id,
  181. provider=provider,
  182. model=args['model']
  183. )
  184. return jsonable_encoder({
  185. "data": parameter_rules
  186. })
  187. class ModelProviderAvailableModelApi(Resource):
  188. @setup_required
  189. @login_required
  190. @account_initialization_required
  191. def get(self, model_type):
  192. tenant_id = current_user.current_tenant_id
  193. model_provider_service = ModelProviderService()
  194. models = model_provider_service.get_models_by_model_type(
  195. tenant_id=tenant_id,
  196. model_type=model_type
  197. )
  198. return jsonable_encoder({
  199. "data": models
  200. })
  201. api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers/<string:provider>/models')
  202. api.add_resource(ModelProviderModelCredentialApi,
  203. '/workspaces/current/model-providers/<string:provider>/models/credentials')
  204. api.add_resource(ModelProviderModelValidateApi,
  205. '/workspaces/current/model-providers/<string:provider>/models/credentials/validate')
  206. api.add_resource(ModelProviderModelParameterRuleApi,
  207. '/workspaces/current/model-providers/<string:provider>/models/parameter-rules')
  208. api.add_resource(ModelProviderAvailableModelApi, '/workspaces/current/models/model-types/<string:model_type>')
  209. api.add_resource(DefaultModelApi, '/workspaces/current/default-model')