# models.py (3.9 KB)
  1. from flask_login import login_required, current_user
  2. from flask_restful import Resource, reqparse
  3. from controllers.console import api
  4. from controllers.console.setup import setup_required
  5. from controllers.console.wraps import account_initialization_required
  6. from core.model_providers.model_provider_factory import ModelProviderFactory
  7. from core.model_providers.models.entity.model_params import ModelType
  8. from models.provider import ProviderType
  9. from services.provider_service import ProviderService
  10. class DefaultModelApi(Resource):
  11. @setup_required
  12. @login_required
  13. @account_initialization_required
  14. def get(self):
  15. parser = reqparse.RequestParser()
  16. parser.add_argument('model_type', type=str, required=True, nullable=False,
  17. choices=['text-generation', 'embeddings', 'speech2text'], location='args')
  18. args = parser.parse_args()
  19. tenant_id = current_user.current_tenant_id
  20. provider_service = ProviderService()
  21. default_model = provider_service.get_default_model_of_model_type(
  22. tenant_id=tenant_id,
  23. model_type=args['model_type']
  24. )
  25. if not default_model:
  26. return None
  27. model_provider = ModelProviderFactory.get_preferred_model_provider(
  28. tenant_id,
  29. default_model.provider_name
  30. )
  31. if not model_provider:
  32. return {
  33. 'model_name': default_model.model_name,
  34. 'model_type': default_model.model_type,
  35. 'model_provider': {
  36. 'provider_name': default_model.provider_name
  37. }
  38. }
  39. provider = model_provider.provider
  40. rst = {
  41. 'model_name': default_model.model_name,
  42. 'model_type': default_model.model_type,
  43. 'model_provider': {
  44. 'provider_name': provider.provider_name,
  45. 'provider_type': provider.provider_type
  46. }
  47. }
  48. model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)
  49. if provider.provider_type == ProviderType.SYSTEM.value:
  50. rst['model_provider']['quota_type'] = provider.quota_type
  51. rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit']
  52. rst['model_provider']['quota_limit'] = provider.quota_limit
  53. rst['model_provider']['quota_used'] = provider.quota_used
  54. return rst
  55. @setup_required
  56. @login_required
  57. @account_initialization_required
  58. def post(self):
  59. parser = reqparse.RequestParser()
  60. parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
  61. parser.add_argument('model_type', type=str, required=True, nullable=False,
  62. choices=['text-generation', 'embeddings', 'speech2text'], location='json')
  63. parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json')
  64. args = parser.parse_args()
  65. provider_service = ProviderService()
  66. provider_service.update_default_model_of_model_type(
  67. tenant_id=current_user.current_tenant_id,
  68. model_type=args['model_type'],
  69. provider_name=args['provider_name'],
  70. model_name=args['model_name']
  71. )
  72. return {'result': 'success'}
  73. class ValidModelApi(Resource):
  74. @setup_required
  75. @login_required
  76. @account_initialization_required
  77. def get(self, model_type):
  78. ModelType.value_of(model_type)
  79. provider_service = ProviderService()
  80. valid_models = provider_service.get_valid_model_list(
  81. tenant_id=current_user.current_tenant_id,
  82. model_type=model_type
  83. )
  84. return valid_models
  85. api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
  86. api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>')