123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511 |
- import datetime
- import json
- from collections import defaultdict
- from typing import Optional
- from core.model_providers.model_factory import ModelFactory
- from extensions.ext_database import db
- from core.model_providers.model_provider_factory import ModelProviderFactory
- from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
- from models.provider import Provider, ProviderModel, TenantPreferredModelProvider, ProviderType, ProviderQuotaType, \
- TenantDefaultModel
class ProviderService:
    """
    Service layer for a tenant's model providers.

    Lists providers together with their quota and credential state, saves and
    deletes custom provider / custom model credentials, switches the preferred
    provider type, and looks up default and valid models. Methods that write
    to the database commit through ``db.session`` before returning.
    """

    def get_provider_list(self, tenant_id: str) -> dict:
        """
        get provider list of tenant.

        For every provider rule known to ``ModelProviderFactory`` this builds a
        config dict holding the tenant's preferred provider type, the rule's
        model flexibility, and a ``providers`` list with one entry per supported
        system quota type and/or the custom provider type. The entries start
        as placeholders and are then overlaid with the tenant's stored rows.

        :param tenant_id: id of the tenant whose providers are listed
        :return: dict mapping model provider name to a dict with the keys
            ``preferred_provider_type``, ``model_flexibility`` and ``providers``
        """
        # get rules for all providers
        model_provider_rules = ModelProviderFactory.get_provider_rules()
        model_provider_names = list(model_provider_rules)

        # Kept ahead of the queries below (original ordering): it may create Provider
        # rows they should see. NOTE(review): presumed trial provisioning — confirm.
        self._init_trial_system_providers(tenant_id, model_provider_rules)

        # providers whose custom type lets the tenant configure individual models
        configurable_model_provider_names = [
            name
            for name, rule in model_provider_rules.items()
            if 'custom' in rule['support_provider_types']
            and rule['model_flexibility'] == 'configurable'
        ]

        # get all valid providers for the tenant, newest first
        providers = db.session.query(Provider) \
            .filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name.in_(model_provider_names),
                Provider.is_valid == True
            ).order_by(Provider.created_at.desc()).all()

        provider_name_to_provider_dict = defaultdict(list)
        for provider in providers:
            provider_name_to_provider_dict[provider.provider_name].append(provider)

        # get all valid provider models of configurable providers for the tenant, newest first
        provider_models = db.session.query(ProviderModel) \
            .filter(
                ProviderModel.tenant_id == tenant_id,
                ProviderModel.provider_name.in_(configurable_model_provider_names),
                ProviderModel.is_valid == True
            ).order_by(ProviderModel.created_at.desc()).all()

        provider_name_to_provider_model_dict = defaultdict(list)
        for provider_model in provider_models:
            provider_name_to_provider_model_dict[provider_model.provider_name].append(provider_model)

        # get all preferred provider types for the tenant
        preferred_provider_types = db.session.query(TenantPreferredModelProvider) \
            .filter(
                TenantPreferredModelProvider.tenant_id == tenant_id,
                TenantPreferredModelProvider.provider_name.in_(model_provider_names)
            ).all()

        provider_name_to_preferred_provider_type_dict = {
            preferred_provider_type.provider_name: preferred_provider_type
            for preferred_provider_type in preferred_provider_types
        }

        providers_list = {}
        for model_provider_name, model_provider_rule in model_provider_rules.items():
            # the stored preference may be missing; .get() then passes None to the factory
            preferred_provider_type = ModelProviderFactory.get_preferred_type_by_preferred_model_provider(
                tenant_id,
                model_provider_name,
                provider_name_to_preferred_provider_type_dict.get(model_provider_name)
            )

            provider_config_dict = {
                "preferred_provider_type": preferred_provider_type,
                "model_flexibility": model_provider_rule['model_flexibility'],
            }

            provider_parameter_dict = self._build_default_provider_parameters(
                model_provider_name,
                model_provider_rule
            )
            self._apply_tenant_providers(
                provider_parameter_dict,
                model_provider_name,
                model_provider_rule,
                provider_name_to_provider_dict[model_provider_name],
                provider_name_to_provider_model_dict[model_provider_name]
            )

            provider_config_dict['providers'] = list(provider_parameter_dict.values())
            providers_list[model_provider_name] = provider_config_dict

        return providers_list

    @staticmethod
    def _init_trial_system_providers(tenant_id: str, model_provider_rules: dict) -> None:
        """
        Call ``ModelProviderFactory.get_preferred_model_provider`` for each provider
        whose rule supports the SYSTEM type with a ``trial`` quota; the result is discarded.

        NOTE(review): the call is made only for its side effects, presumably
        provisioning the tenant's trial provider row — confirm in ModelProviderFactory.

        :param tenant_id: id of the tenant
        :param model_provider_rules: mapping of provider name to provider rule dict
        :return: None
        """
        for model_provider_name, model_provider_rule in model_provider_rules.items():
            if ProviderType.SYSTEM.value in model_provider_rule['support_provider_types'] \
                    and 'system_config' in model_provider_rule and model_provider_rule['system_config'] \
                    and 'supported_quota_types' in model_provider_rule['system_config'] \
                    and 'trial' in model_provider_rule['system_config']['supported_quota_types']:
                ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

    @staticmethod
    def _build_default_provider_parameters(model_provider_name: str, model_provider_rule: dict) -> dict:
        """
        Build placeholder provider entries for every provider/quota type a rule supports.

        :param model_provider_name: name of the model provider
        :param model_provider_rule: the provider's rule dict
        :return: dict keyed by ``f"{ProviderType.SYSTEM.value}:{quota_type}"`` for each
            supported system quota type, and by ``ProviderType.CUSTOM.value`` when the
            custom type is supported. Every entry starts out invalid and unused.
        """
        provider_parameter_dict = {}

        if ProviderType.SYSTEM.value in model_provider_rule['support_provider_types']:
            system_config = model_provider_rule['system_config']
            for quota_type_enum in ProviderQuotaType:
                quota_type = quota_type_enum.value
                if quota_type not in system_config['supported_quota_types']:
                    continue

                provider_parameter_dict[f'{ProviderType.SYSTEM.value}:{quota_type}'] = {
                    "provider_name": model_provider_name,
                    "provider_type": ProviderType.SYSTEM.value,
                    "config": None,
                    "is_valid": False,
                    "quota_type": quota_type,
                    "quota_unit": system_config['quota_unit'],
                    # only the trial quota takes its limit from the rule; others start at 0
                    # until a stored Provider row overrides them
                    "quota_limit": system_config['quota_limit']
                    if quota_type == ProviderQuotaType.TRIAL.value else 0,
                    "quota_used": 0,
                    "last_used": None
                }

        if ProviderType.CUSTOM.value in model_provider_rule['support_provider_types']:
            provider_parameter_dict[ProviderType.CUSTOM.value] = {
                "provider_name": model_provider_name,
                "provider_type": ProviderType.CUSTOM.value,
                "config": None,
                "models": [],
                "is_valid": False,
                "last_used": None
            }

        return provider_parameter_dict

    @staticmethod
    def _apply_tenant_providers(provider_parameter_dict: dict,
                                model_provider_name: str,
                                model_provider_rule: dict,
                                current_providers: list,
                                provider_models: list) -> None:
        """
        Overwrite placeholder entries in ``provider_parameter_dict`` (in place)
        with the state of the tenant's stored providers.

        :param provider_parameter_dict: placeholders from ``_build_default_provider_parameters``
        :param model_provider_name: name of the model provider
        :param model_provider_rule: the provider's rule dict
        :param current_providers: the tenant's valid Provider rows for this provider
        :param provider_models: the tenant's valid ProviderModel rows for this provider
        :return: None
        """
        model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
        custom_key = ProviderType.CUSTOM.value

        for provider in current_providers:
            if provider.provider_type == ProviderType.SYSTEM.value:
                key = f'{ProviderType.SYSTEM.value}:{provider.quota_type}'
                # stored quota types the rule no longer supports are ignored
                if key in provider_parameter_dict:
                    provider_parameter_dict[key].update(
                        is_valid=provider.is_valid,
                        quota_used=provider.quota_used,
                        quota_limit=provider.quota_limit,
                        last_used=provider.last_used
                    )
            elif provider.provider_type == custom_key and custom_key in provider_parameter_dict:
                custom_parameters = provider_parameter_dict[custom_key]
                custom_parameters['last_used'] = provider.last_used
                custom_parameters['is_valid'] = provider.is_valid

                if model_provider_rule['model_flexibility'] == 'fixed':
                    # fixed: report provider-level credentials, obfuscated
                    custom_parameters['config'] = model_provider_class(provider=provider) \
                        .get_provider_credentials(obfuscated=True)
                else:
                    # otherwise: report per-model credentials, obfuscated
                    custom_parameters['models'] = [
                        {
                            "model_name": provider_model.model_name,
                            "model_type": provider_model.model_type,
                            "config": model_provider_class(provider=provider).get_model_credentials(
                                provider_model.model_name,
                                ModelType.value_of(provider_model.model_type),
                                obfuscated=True
                            ),
                            "is_valid": provider_model.is_valid
                        }
                        for provider_model in provider_models
                    ]

    def custom_provider_config_validate(self, provider_name: str, config: dict) -> None:
        """
        validate custom provider config.

        Provider-level credentials are only accepted for providers whose rule has
        ``model_flexibility == 'fixed'`` and that support the CUSTOM provider type.

        :param provider_name: name of the model provider
        :param config: credentials to validate
        :return: None
        :raises ValueError: if the provider is not fixed-flexibility or does not support CUSTOM
        :raises CredentialsValidateFailedError: When the config credential verification fails.
        """
        model_provider_rule = ModelProviderFactory.get_provider_rule(provider_name)
        if model_provider_rule['model_flexibility'] != 'fixed':
            raise ValueError('Only support fixed model provider')

        # only support provider type CUSTOM
        if ProviderType.CUSTOM.value not in model_provider_rule['support_provider_types']:
            raise ValueError('Only support provider type CUSTOM')

        # validate provider config
        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
        model_provider_class.is_provider_credentials_valid_or_raise(config)

    def save_custom_provider_config(self, tenant_id: str, provider_name: str, config: dict) -> None:
        """
        save custom provider config.

        Validates the credentials, encrypts them, and upserts the tenant's CUSTOM
        Provider row, marking it valid.

        :param tenant_id: id of the tenant
        :param provider_name: name of the model provider
        :param config: plain-text credentials to store
        :return: None
        :raises ValueError: see ``custom_provider_config_validate``
        :raises CredentialsValidateFailedError: When the config credential verification fails.
        """
        self.custom_provider_config_validate(provider_name, config)

        provider = db.session.query(Provider) \
            .filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name == provider_name,
                Provider.provider_type == ProviderType.CUSTOM.value
            ).first()

        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
        encrypted_config_json = json.dumps(model_provider_class.encrypt_provider_credentials(tenant_id, config))

        if provider:
            provider.encrypted_config = encrypted_config_json
            provider.is_valid = True
            provider.updated_at = datetime.datetime.utcnow()
        else:
            db.session.add(Provider(
                tenant_id=tenant_id,
                provider_name=provider_name,
                provider_type=ProviderType.CUSTOM.value,
                encrypted_config=encrypted_config_json,
                is_valid=True
            ))

        db.session.commit()

    def delete_custom_provider(self, tenant_id: str, provider_name: str) -> None:
        """
        delete custom provider.

        Before deleting the tenant's CUSTOM Provider row, tries to switch the
        tenant's preferred provider type back to SYSTEM. Does nothing if no
        CUSTOM row exists.

        :param tenant_id: id of the tenant
        :param provider_name: name of the model provider
        :return: None
        """
        provider = db.session.query(Provider) \
            .filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name == provider_name,
                Provider.provider_type == ProviderType.CUSTOM.value
            ).first()

        if not provider:
            return

        # Best effort: switch_preferred_provider raises ValueError e.g. when the provider
        # does not support SYSTEM; the custom row is deleted regardless.
        try:
            self.switch_preferred_provider(tenant_id, provider_name, ProviderType.SYSTEM.value)
        except ValueError:
            pass

        db.session.delete(provider)
        db.session.commit()

    def custom_provider_model_config_validate(self,
                                              provider_name: str,
                                              model_name: str,
                                              model_type: str,
                                              config: dict) -> None:
        """
        validate custom provider model config.

        Per-model credentials are only accepted for providers whose rule has
        ``model_flexibility == 'configurable'`` and that support the CUSTOM provider type.

        :param provider_name: name of the model provider
        :param model_name: name of the model
        :param model_type: model type value, converted with ``ModelType.value_of``
        :param config: credentials to validate
        :return: None
        :raises ValueError: if the provider is not configurable or does not support CUSTOM
        :raises CredentialsValidateFailedError: When the config credential verification fails.
        """
        model_provider_rule = ModelProviderFactory.get_provider_rule(provider_name)
        if model_provider_rule['model_flexibility'] != 'configurable':
            raise ValueError('Only support configurable model provider')

        # only support provider type CUSTOM
        if ProviderType.CUSTOM.value not in model_provider_rule['support_provider_types']:
            raise ValueError('Only support provider type CUSTOM')

        # validate provider model config
        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
        model_provider_class.is_model_credentials_valid_or_raise(model_name, ModelType.value_of(model_type), config)

    def add_or_save_custom_provider_model_config(self,
                                                 tenant_id: str,
                                                 provider_name: str,
                                                 model_name: str,
                                                 model_type: str,
                                                 config: dict) -> None:
        """
        Add or save custom provider model config.

        Validates the model credentials and makes sure a valid CUSTOM Provider row
        exists. It then encrypts the credentials and upserts the matching
        ProviderModel row, marking it valid.

        :param tenant_id: id of the tenant
        :param provider_name: name of the model provider
        :param model_name: name of the model
        :param model_type: model type value, converted with ``ModelType.value_of``
        :param config: plain-text credentials to store
        :return: None
        :raises ValueError: see ``custom_provider_model_config_validate``
        :raises CredentialsValidateFailedError: When the config credential verification fails.
        """
        self.custom_provider_model_config_validate(provider_name, model_name, model_type, config)

        provider = db.session.query(Provider) \
            .filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name == provider_name,
                Provider.provider_type == ProviderType.CUSTOM.value
            ).first()

        # Ensure a valid CUSTOM provider row exists; an invalidated one is revived
        # with its provider-level config cleared.
        if not provider:
            provider = Provider(
                tenant_id=tenant_id,
                provider_name=provider_name,
                provider_type=ProviderType.CUSTOM.value,
                is_valid=True
            )
            db.session.add(provider)
            db.session.commit()
        elif not provider.is_valid:
            provider.is_valid = True
            provider.encrypted_config = None
            db.session.commit()

        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
        encrypted_config_json = json.dumps(model_provider_class.encrypt_model_credentials(
            tenant_id,
            model_name,
            ModelType.value_of(model_type),
            config
        ))

        provider_model = db.session.query(ProviderModel) \
            .filter(
                ProviderModel.tenant_id == tenant_id,
                ProviderModel.provider_name == provider_name,
                ProviderModel.model_name == model_name,
                ProviderModel.model_type == model_type
            ).first()

        if provider_model:
            provider_model.encrypted_config = encrypted_config_json
            provider_model.is_valid = True
        else:
            db.session.add(ProviderModel(
                tenant_id=tenant_id,
                provider_name=provider_name,
                model_name=model_name,
                model_type=model_type,
                encrypted_config=encrypted_config_json,
                is_valid=True
            ))

        db.session.commit()

    def delete_custom_provider_model(self,
                                     tenant_id: str,
                                     provider_name: str,
                                     model_name: str,
                                     model_type: str) -> None:
        """
        delete custom provider model.

        Deletes the tenant's matching ProviderModel row; does nothing if none exists.

        :param tenant_id: id of the tenant
        :param provider_name: name of the model provider
        :param model_name: name of the model
        :param model_type: model type value as stored on the row
        :return: None
        """
        provider_model = db.session.query(ProviderModel) \
            .filter(
                ProviderModel.tenant_id == tenant_id,
                ProviderModel.provider_name == provider_name,
                ProviderModel.model_name == model_name,
                ProviderModel.model_type == model_type
            ).first()

        if not provider_model:
            return

        db.session.delete(provider_model)
        db.session.commit()

    def switch_preferred_provider(self, tenant_id: str, provider_name: str, preferred_provider_type: str) -> None:
        """
        switch preferred provider.

        Upserts the tenant's TenantPreferredModelProvider row for ``provider_name``.
        If the provider does not support the SYSTEM type, nothing is stored and the
        method returns silently.

        :param tenant_id: id of the tenant
        :param provider_name: name of the model provider
        :param preferred_provider_type: provider type value to prefer
        :return: None
        :raises ValueError: if the type is invalid or not supported by the provider's rule
        """
        provider_type = ProviderType.value_of(preferred_provider_type)
        if not provider_type:
            raise ValueError(f'Invalid preferred provider type: {preferred_provider_type}')

        model_provider_rule = ModelProviderFactory.get_provider_rule(provider_name)
        if preferred_provider_type not in model_provider_rule['support_provider_types']:
            raise ValueError(f'Not support provider type: {preferred_provider_type}')

        model_provider = ModelProviderFactory.get_model_provider_class(provider_name)
        if not model_provider.is_provider_type_system_supported():
            return

        preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
            .filter(
                TenantPreferredModelProvider.tenant_id == tenant_id,
                TenantPreferredModelProvider.provider_name == provider_name
            ).first()

        if preferred_model_provider:
            preferred_model_provider.preferred_provider_type = preferred_provider_type
        else:
            db.session.add(TenantPreferredModelProvider(
                tenant_id=tenant_id,
                provider_name=provider_name,
                preferred_provider_type=preferred_provider_type
            ))

        db.session.commit()

    def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[TenantDefaultModel]:
        """
        get default model of model type.

        Delegates to ``ModelFactory.get_default_model``.

        :param tenant_id: id of the tenant
        :param model_type: model type value, converted with ``ModelType.value_of``
        :return: the tenant's default model for that type, or None
        """
        return ModelFactory.get_default_model(tenant_id, ModelType.value_of(model_type))

    def update_default_model_of_model_type(self,
                                           tenant_id: str,
                                           model_type: str,
                                           provider_name: str,
                                           model_name: str) -> TenantDefaultModel:
        """
        update default model of model type.

        Delegates to ``ModelFactory.update_default_model``.

        :param tenant_id: id of the tenant
        :param model_type: model type value, converted with ``ModelType.value_of``
        :param provider_name: name of the model provider
        :param model_name: name of the model
        :return: the updated default model record
        """
        return ModelFactory.update_default_model(tenant_id, ModelType.value_of(model_type), provider_name, model_name)

    def get_valid_model_list(self, tenant_id: str, model_type: str) -> list:
        """
        get valid model list.

        For each provider that has a preferred model provider for the tenant,
        lists its supported models of ``model_type``.

        :param tenant_id: id of the tenant
        :param model_type: model type value, converted with ``ModelType.value_of``
        :return: list of dicts with ``model_name``, ``model_type``, ``model_provider``
            (plus quota fields for SYSTEM providers) and ``features``
        """
        # Loop-invariant: convert once, before any provider lookups are made.
        model_type_enum = ModelType.value_of(model_type)

        valid_model_list = []
        for model_provider_name, model_provider_rule in ModelProviderFactory.get_provider_rules().items():
            model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
            if not model_provider:
                continue

            model_list = model_provider.get_supported_model_list(model_type_enum)
            provider = model_provider.provider

            for model in model_list:
                valid_model_dict = {
                    "model_name": model['id'],
                    "model_type": model_type,
                    "model_provider": {
                        "provider_name": provider.provider_name,
                        "provider_type": provider.provider_type
                    },
                    'features': model['features'] if 'features' in model else []
                }

                if provider.provider_type == ProviderType.SYSTEM.value:
                    valid_model_dict['model_provider'].update(
                        quota_type=provider.quota_type,
                        quota_unit=model_provider_rule['system_config']['quota_unit'],
                        quota_limit=provider.quota_limit,
                        quota_used=provider.quota_used
                    )

                valid_model_list.append(valid_model_dict)

        return valid_model_list

    def get_model_parameter_rules(self, tenant_id: str, model_provider_name: str, model_name: str, model_type: str) \
            -> ModelKwargsRules:
        """
        get model parameter rules.
        It depends on preferred provider in use.

        :param tenant_id: id of the tenant
        :param model_provider_name: name of the model provider
        :param model_name: name of the model
        :param model_type: model type value, converted with ``ModelType.value_of``
        :return: the preferred provider's parameter rules for the model, or an
            empty ``ModelKwargsRules()`` when the tenant has no preferred provider
        """
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
        if not model_provider:
            return ModelKwargsRules()

        return model_provider.get_model_parameter_rules(model_name, ModelType.value_of(model_type))
|