123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- import logging
- from flask_login import current_user
- from libs.login import login_required
- from flask_restful import Resource, reqparse
- from controllers.console import api
- from controllers.console.setup import setup_required
- from controllers.console.wraps import account_initialization_required
- from core.model_providers.model_provider_factory import ModelProviderFactory
- from core.model_providers.models.entity.model_params import ModelType
- from models.provider import ProviderType
- from services.provider_service import ProviderService
class DefaultModelApi(Resource):
    """Read and update the current workspace's default model per model type."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """Return the default model configured for the requested model type.

        Query args:
            model_type: required; one of 'text-generation', 'embeddings',
                'speech2text', 'reranking'.

        Returns:
            ``None`` when no default model is configured for the type.
            Otherwise a dict with ``model_name``, ``model_type`` and a
            ``model_provider`` dict. If the preferred provider cannot be
            resolved, ``model_provider`` only carries ``provider_name``;
            otherwise it carries ``provider_name`` and ``provider_type``, plus
            ``quota_type``/``quota_unit``/``quota_limit``/``quota_used`` for
            system-type providers.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('model_type', type=str, required=True, nullable=False,
                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'],
                            location='args')
        args = parser.parse_args()

        tenant_id = current_user.current_tenant_id

        provider_service = ProviderService()
        default_model = provider_service.get_default_model_of_model_type(
            tenant_id=tenant_id,
            model_type=args['model_type']
        )

        if not default_model:
            return None

        model_provider = ModelProviderFactory.get_preferred_model_provider(
            tenant_id,
            default_model.provider_name
        )

        if not model_provider:
            # No provider instance could be resolved; still report the
            # configured model and the provider name it refers to.
            return {
                'model_name': default_model.model_name,
                'model_type': default_model.model_type,
                'model_provider': {
                    'provider_name': default_model.provider_name
                }
            }

        provider = model_provider.provider
        rst = {
            'model_name': default_model.model_name,
            'model_type': default_model.model_type,
            'model_provider': {
                'provider_name': provider.provider_name,
                'provider_type': provider.provider_type
            }
        }

        if provider.provider_type == ProviderType.SYSTEM.value:
            # The provider rule is only consulted for the quota unit, so it is
            # looked up only when quota details are actually reported.
            model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)
            rst['model_provider']['quota_type'] = provider.quota_type
            rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit']
            rst['model_provider']['quota_limit'] = provider.quota_limit
            rst['model_provider']['quota_used'] = provider.quota_used

        return rst

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Set the default model for one or more model types.

        JSON body:
            model_settings: required list; each entry is expected to be a dict
                with ``model_type``, ``provider_name`` and ``model_name``.

        Each entry is saved independently on a best-effort basis. An entry
        that fails is logged (with its traceback) and skipped. The response
        is always ``{'result': 'success'}``.
        """
        parser = reqparse.RequestParser()
        parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
        args = parser.parse_args()

        # Loop-invariant: the tenant does not change between entries.
        tenant_id = current_user.current_tenant_id
        provider_service = ProviderService()

        for model_setting in args['model_settings']:
            try:
                provider_service.update_default_model_of_model_type(
                    tenant_id=tenant_id,
                    model_type=model_setting['model_type'],
                    provider_name=model_setting['provider_name'],
                    model_name=model_setting['model_name']
                )
            except Exception:
                # Read model_type defensively: if the entry itself is malformed
                # (not a dict, or missing the key), indexing it again here would
                # raise from inside the handler and abort the whole request.
                model_type = model_setting.get('model_type') if isinstance(model_setting, dict) else None
                logging.warning("%s save error", model_type, exc_info=True)

        return {'result': 'success'}
class ValidModelApi(Resource):
    """Expose the models of a given type that are valid for the current workspace."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, model_type):
        """Return the valid model list for ``model_type`` in the current tenant.

        ``model_type`` comes from the URL path segment.
        """
        # The result of value_of is discarded; presumably this call exists to
        # reject unknown model types before querying (TODO confirm what it raises).
        ModelType.value_of(model_type)

        return ProviderService().get_valid_model_list(
            tenant_id=current_user.current_tenant_id,
            model_type=model_type,
        )
# Route registration for the workspace model endpoints.
api.add_resource(
    DefaultModelApi,
    '/workspaces/current/default-model',
)
# The <string:model_type> segment is passed to ValidModelApi.get as `model_type`.
api.add_resource(
    ValidModelApi,
    '/workspaces/current/models/model-type/<string:model_type>',
)
|