models.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import logging
  2. from flask_login import current_user
  3. from libs.login import login_required
  4. from flask_restful import Resource, reqparse
  5. from controllers.console import api
  6. from controllers.console.setup import setup_required
  7. from controllers.console.wraps import account_initialization_required
  8. from core.model_providers.model_provider_factory import ModelProviderFactory
  9. from core.model_providers.models.entity.model_params import ModelType
  10. from models.provider import ProviderType
  11. from services.provider_service import ProviderService
  12. class DefaultModelApi(Resource):
  13. @setup_required
  14. @login_required
  15. @account_initialization_required
  16. def get(self):
  17. parser = reqparse.RequestParser()
  18. parser.add_argument('model_type', type=str, required=True, nullable=False,
  19. choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
  20. args = parser.parse_args()
  21. tenant_id = current_user.current_tenant_id
  22. provider_service = ProviderService()
  23. default_model = provider_service.get_default_model_of_model_type(
  24. tenant_id=tenant_id,
  25. model_type=args['model_type']
  26. )
  27. if not default_model:
  28. return None
  29. model_provider = ModelProviderFactory.get_preferred_model_provider(
  30. tenant_id,
  31. default_model.provider_name
  32. )
  33. if not model_provider:
  34. return {
  35. 'model_name': default_model.model_name,
  36. 'model_type': default_model.model_type,
  37. 'model_provider': {
  38. 'provider_name': default_model.provider_name
  39. }
  40. }
  41. provider = model_provider.provider
  42. rst = {
  43. 'model_name': default_model.model_name,
  44. 'model_type': default_model.model_type,
  45. 'model_provider': {
  46. 'provider_name': provider.provider_name,
  47. 'provider_type': provider.provider_type
  48. }
  49. }
  50. model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)
  51. if provider.provider_type == ProviderType.SYSTEM.value:
  52. rst['model_provider']['quota_type'] = provider.quota_type
  53. rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit']
  54. rst['model_provider']['quota_limit'] = provider.quota_limit
  55. rst['model_provider']['quota_used'] = provider.quota_used
  56. return rst
  57. @setup_required
  58. @login_required
  59. @account_initialization_required
  60. def post(self):
  61. parser = reqparse.RequestParser()
  62. parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
  63. args = parser.parse_args()
  64. provider_service = ProviderService()
  65. model_settings = args['model_settings']
  66. for model_setting in model_settings:
  67. try:
  68. provider_service.update_default_model_of_model_type(
  69. tenant_id=current_user.current_tenant_id,
  70. model_type=model_setting['model_type'],
  71. provider_name=model_setting['provider_name'],
  72. model_name=model_setting['model_name']
  73. )
  74. except Exception:
  75. logging.warning(f"{model_setting['model_type']} save error")
  76. return {'result': 'success'}
  77. class ValidModelApi(Resource):
  78. @setup_required
  79. @login_required
  80. @account_initialization_required
  81. def get(self, model_type):
  82. ModelType.value_of(model_type)
  83. provider_service = ProviderService()
  84. valid_models = provider_service.get_valid_model_list(
  85. tenant_id=current_user.current_tenant_id,
  86. model_type=model_type
  87. )
  88. return valid_models
  89. api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
  90. api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>')