| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588 | 
							- import datetime
 
- import json
 
- import logging
 
- import os
 
- from collections import defaultdict
 
- from typing import Optional
 
- import requests
 
- from core.model_providers.model_factory import ModelFactory
 
- from extensions.ext_database import db
 
- from core.model_providers.model_provider_factory import ModelProviderFactory
 
- from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
 
- from models.provider import Provider, ProviderModel, TenantPreferredModelProvider, ProviderType, ProviderQuotaType, \
 
-     TenantDefaultModel
 
class ProviderService:
    """Tenant-level operations on model providers: listing providers, managing
    custom provider / model credentials, switching the preferred provider type,
    default models, and free-quota applications."""

    # Seconds to wait on the free-quota apply service; bounds both connect and
    # read so a stalled service cannot block the caller indefinitely.
    _FREE_QUOTA_API_TIMEOUT = 30

    def get_provider_list(self, tenant_id: str):
        """
        get provider list of tenant.

        For every known model provider, report the tenant's preferred provider
        type and one entry per provider type the provider supports (one per
        supported system quota type, plus one for custom credentials), filled
        in from the tenant's stored ``Provider`` / ``ProviderModel`` rows.

        :param tenant_id: tenant whose providers are listed
        :return: dict keyed by model provider name; each value holds
            ``preferred_provider_type``, ``model_flexibility`` and ``providers``
            (list of per-provider-type dicts)
        """
        model_provider_rules = ModelProviderFactory.get_provider_rules()
        model_provider_names = list(model_provider_rules)

        # The result of get_preferred_model_provider is discarded; the call is
        # made only for its side effects. NOTE(review): presumably it provisions
        # the tenant's trial-quota provider row before the queries below — confirm.
        for model_provider_name, model_provider_rule in model_provider_rules.items():
            if self._supports_trial_quota(model_provider_rule):
                ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)

        configurable_model_provider_names = [
            name
            for name, rule in model_provider_rules.items()
            if 'custom' in rule['support_provider_types']
            and rule['model_flexibility'] == 'configurable'
        ]

        # all valid providers of the tenant, newest first
        providers = db.session.query(Provider) \
            .filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name.in_(model_provider_names),
                Provider.is_valid == True  # noqa: E712 -- builds a SQL expression
            ).order_by(Provider.created_at.desc()).all()
        provider_name_to_providers = defaultdict(list)
        for provider in providers:
            provider_name_to_providers[provider.provider_name].append(provider)

        # all valid models of configurable providers, newest first
        provider_models = db.session.query(ProviderModel) \
            .filter(
                ProviderModel.tenant_id == tenant_id,
                ProviderModel.provider_name.in_(configurable_model_provider_names),
                ProviderModel.is_valid == True  # noqa: E712 -- builds a SQL expression
            ).order_by(ProviderModel.created_at.desc()).all()
        provider_name_to_provider_models = defaultdict(list)
        for provider_model in provider_models:
            provider_name_to_provider_models[provider_model.provider_name].append(provider_model)

        # the tenant's stored preferred provider type per provider (if any)
        preferred_model_providers = db.session.query(TenantPreferredModelProvider) \
            .filter(
                TenantPreferredModelProvider.tenant_id == tenant_id,
                TenantPreferredModelProvider.provider_name.in_(model_provider_names)
            ).all()
        provider_name_to_preferred_model_provider = {
            preferred.provider_name: preferred for preferred in preferred_model_providers
        }

        providers_list = {}
        for model_provider_name, model_provider_rule in model_provider_rules.items():
            preferred_provider_type = ModelProviderFactory.get_preferred_type_by_preferred_model_provider(
                tenant_id,
                model_provider_name,
                provider_name_to_preferred_model_provider.get(model_provider_name)
            )

            provider_parameter_dict = self._build_default_provider_parameters(
                model_provider_name, model_provider_rule
            )

            model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name)
            for provider in provider_name_to_providers[model_provider_name]:
                if provider.provider_type == ProviderType.SYSTEM.value:
                    self._apply_system_provider(provider_parameter_dict, provider)
                elif provider.provider_type == ProviderType.CUSTOM.value \
                        and ProviderType.CUSTOM.value in provider_parameter_dict:
                    self._apply_custom_provider(
                        provider_parameter_dict[ProviderType.CUSTOM.value],
                        provider,
                        model_provider_rule,
                        model_provider_class,
                        provider_name_to_provider_models[model_provider_name]
                    )

            providers_list[model_provider_name] = {
                "preferred_provider_type": preferred_provider_type,
                "model_flexibility": model_provider_rule['model_flexibility'],
                "providers": list(provider_parameter_dict.values()),
            }

        return providers_list

    @staticmethod
    def _supports_trial_quota(model_provider_rule: dict) -> bool:
        """Return True when the rule supports the system provider type with a 'trial' quota."""
        system_config = model_provider_rule.get('system_config')
        return ProviderType.SYSTEM.value in model_provider_rule['support_provider_types'] \
            and bool(system_config) \
            and 'trial' in (system_config.get('supported_quota_types') or ())

    @staticmethod
    def _build_default_provider_parameters(model_provider_name: str, model_provider_rule: dict) -> dict:
        """
        Build placeholder entries for every provider type the rule supports.

        Keys are ``f'{ProviderType.SYSTEM.value}:{quota_type}'`` for each supported
        system quota type and ``ProviderType.CUSTOM.value`` for custom credentials.
        Values carry defaults that are overwritten from the tenant's DB rows.
        """
        provider_parameter_dict = {}
        support_provider_types = model_provider_rule['support_provider_types']

        if ProviderType.SYSTEM.value in support_provider_types:
            # Guard a missing/empty system_config the same way _supports_trial_quota does.
            system_config = model_provider_rule.get('system_config') or {}
            supported_quota_types = system_config.get('supported_quota_types') or ()
            for quota_type_enum in ProviderQuotaType:
                quota_type = quota_type_enum.value
                if quota_type not in supported_quota_types:
                    continue
                provider_parameter_dict[f'{ProviderType.SYSTEM.value}:{quota_type}'] = {
                    "provider_name": model_provider_name,
                    "provider_type": ProviderType.SYSTEM.value,
                    "config": None,
                    "is_valid": False,  # overwritten from the tenant's Provider row
                    "quota_type": quota_type,
                    "quota_unit": system_config['quota_unit'],
                    # only the trial quota has a rule-defined default limit
                    "quota_limit": system_config['quota_limit']
                    if quota_type == ProviderQuotaType.TRIAL.value else 0,
                    "quota_used": 0,  # overwritten from the tenant's Provider row
                    "last_used": None  # overwritten from the tenant's Provider row
                }

        if ProviderType.CUSTOM.value in support_provider_types:
            provider_parameter_dict[ProviderType.CUSTOM.value] = {
                "provider_name": model_provider_name,
                "provider_type": ProviderType.CUSTOM.value,
                "config": None,  # filled for fixed-model providers
                "models": [],  # filled for configurable-model providers
                "is_valid": False,
                "last_used": None
            }

        return provider_parameter_dict

    @staticmethod
    def _to_timestamp(value) -> Optional[int]:
        """Convert a datetime to integer epoch seconds; falsy values map to None."""
        return int(value.timestamp()) if value else None

    def _apply_system_provider(self, provider_parameter_dict: dict, provider) -> None:
        """Copy quota usage and validity from a SYSTEM ``Provider`` row into its placeholder, if one exists."""
        entry = provider_parameter_dict.get(f'{ProviderType.SYSTEM.value}:{provider.quota_type}')
        if entry is None:
            return
        entry['is_valid'] = provider.is_valid
        entry['quota_used'] = provider.quota_used
        entry['quota_limit'] = provider.quota_limit
        entry['last_used'] = self._to_timestamp(provider.last_used)

    def _apply_custom_provider(self,
                               entry: dict,
                               provider,
                               model_provider_rule: dict,
                               model_provider_class,
                               provider_models: list) -> None:
        """
        Fill the CUSTOM placeholder from a custom ``Provider`` row.

        Attaches obfuscated credentials: provider-level for 'fixed' providers,
        per model row otherwise.
        """
        entry['last_used'] = self._to_timestamp(provider.last_used)
        entry['is_valid'] = provider.is_valid

        if model_provider_rule['model_flexibility'] == 'fixed':
            entry['config'] = model_provider_class(provider=provider) \
                .get_provider_credentials(obfuscated=True)
            return

        models = []
        if provider_models:
            # One instance serves every model row of this provider.
            model_provider = model_provider_class(provider=provider)
            models = [
                {
                    "model_name": provider_model.model_name,
                    "model_type": provider_model.model_type,
                    "config": model_provider.get_model_credentials(
                        provider_model.model_name,
                        ModelType.value_of(provider_model.model_type),
                        obfuscated=True
                    ),
                    "is_valid": provider_model.is_valid
                }
                for provider_model in provider_models
            ]
        entry['models'] = models

    @staticmethod
    def _get_custom_provider(tenant_id: str, provider_name: str):
        """Return the tenant's CUSTOM ``Provider`` row for ``provider_name`` (valid or not), or None."""
        return db.session.query(Provider) \
            .filter(
                Provider.tenant_id == tenant_id,
                Provider.provider_name == provider_name,
                Provider.provider_type == ProviderType.CUSTOM.value
            ).first()

    @staticmethod
    def _get_provider_model(tenant_id: str, provider_name: str, model_name: str, model_type: str):
        """Return the tenant's ``ProviderModel`` row matching provider, model name and type, or None."""
        return db.session.query(ProviderModel) \
            .filter(
                ProviderModel.tenant_id == tenant_id,
                ProviderModel.provider_name == provider_name,
                ProviderModel.model_name == model_name,
                ProviderModel.model_type == model_type
            ).first()

    def custom_provider_config_validate(self, provider_name: str, config: dict) -> None:
        """
        validate custom provider config.

        Only fixed-model providers that support the CUSTOM provider type accept
        a provider-level config.

        :param provider_name:
        :param config:
        :return:
        :raises ValueError: when the provider is not fixed-model or does not support CUSTOM.
        :raises CredentialsValidateFailedError: When the config credential verification fails.
        """
        model_provider_rules = ModelProviderFactory.get_provider_rule(provider_name)
        if model_provider_rules['model_flexibility'] != 'fixed':
            raise ValueError('Only support fixed model provider')

        if ProviderType.CUSTOM.value not in model_provider_rules['support_provider_types']:
            raise ValueError('Only support provider type CUSTOM')

        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
        model_provider_class.is_provider_credentials_valid_or_raise(config)

    def save_custom_provider_config(self, tenant_id: str, provider_name: str, config: dict) -> None:
        """
        save custom provider config.

        Validates ``config``, encrypts it and upserts the tenant's CUSTOM provider
        row, marking it valid.

        :param tenant_id:
        :param provider_name:
        :param config:
        :return:
        """
        self.custom_provider_config_validate(provider_name, config)

        provider = self._get_custom_provider(tenant_id, provider_name)

        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
        encrypted_config = json.dumps(model_provider_class.encrypt_provider_credentials(tenant_id, config))

        if provider:
            provider.encrypted_config = encrypted_config
            provider.is_valid = True
            provider.updated_at = datetime.datetime.utcnow()
        else:
            db.session.add(Provider(
                tenant_id=tenant_id,
                provider_name=provider_name,
                provider_type=ProviderType.CUSTOM.value,
                encrypted_config=encrypted_config,
                is_valid=True
            ))
        db.session.commit()

    def delete_custom_provider(self, tenant_id: str, provider_name: str) -> None:
        """
        delete custom provider.

        Switches the tenant's preferred type back to SYSTEM when possible, then
        deletes the CUSTOM provider row. No-op when the row does not exist.

        :param tenant_id:
        :param provider_name:
        :return:
        """
        provider = self._get_custom_provider(tenant_id, provider_name)
        if not provider:
            return

        try:
            self.switch_preferred_provider(tenant_id, provider_name, ProviderType.SYSTEM.value)
        except ValueError:
            # switch_preferred_provider raises ValueError e.g. when SYSTEM is not a
            # supported type for this provider; then there is nothing to fall back to.
            pass

        db.session.delete(provider)
        db.session.commit()

    def custom_provider_model_config_validate(self,
                                              provider_name: str,
                                              model_name: str,
                                              model_type: str,
                                              config: dict) -> None:
        """
        validate custom provider model config.

        Only configurable-model providers that support the CUSTOM provider type
        accept per-model configs.

        :param provider_name:
        :param model_name:
        :param model_type:
        :param config:
        :return:
        :raises ValueError: when the provider is not configurable or does not support CUSTOM.
        :raises CredentialsValidateFailedError: When the config credential verification fails.
        """
        model_provider_rules = ModelProviderFactory.get_provider_rule(provider_name)
        if model_provider_rules['model_flexibility'] != 'configurable':
            raise ValueError('Only support configurable model provider')

        if ProviderType.CUSTOM.value not in model_provider_rules['support_provider_types']:
            raise ValueError('Only support provider type CUSTOM')

        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
        model_provider_class.is_model_credentials_valid_or_raise(
            model_name, ModelType.value_of(model_type), config
        )

    def add_or_save_custom_provider_model_config(self,
                                                 tenant_id: str,
                                                 provider_name: str,
                                                 model_name: str,
                                                 model_type: str,
                                                 config: dict) -> None:
        """
        Add or save custom provider model config.

        Validates ``config``, ensures a valid CUSTOM provider row exists for the
        tenant, then upserts the encrypted model credentials.

        :param tenant_id:
        :param provider_name:
        :param model_name:
        :param model_type:
        :param config:
        :return:
        """
        self.custom_provider_model_config_validate(provider_name, model_name, model_type, config)

        provider = self._get_custom_provider(tenant_id, provider_name)
        if not provider:
            provider = Provider(
                tenant_id=tenant_id,
                provider_name=provider_name,
                provider_type=ProviderType.CUSTOM.value,
                is_valid=True
            )
            db.session.add(provider)
            db.session.commit()
        elif not provider.is_valid:
            # Re-validate the provider and clear its provider-level config.
            provider.is_valid = True
            provider.encrypted_config = None
            db.session.commit()

        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
        encrypted_config = json.dumps(model_provider_class.encrypt_model_credentials(
            tenant_id,
            model_name,
            ModelType.value_of(model_type),
            config
        ))

        provider_model = self._get_provider_model(tenant_id, provider_name, model_name, model_type)
        if provider_model:
            provider_model.encrypted_config = encrypted_config
            provider_model.is_valid = True
        else:
            db.session.add(ProviderModel(
                tenant_id=tenant_id,
                provider_name=provider_name,
                model_name=model_name,
                model_type=model_type,
                encrypted_config=encrypted_config,
                is_valid=True
            ))
        db.session.commit()

    def delete_custom_provider_model(self,
                                     tenant_id: str,
                                     provider_name: str,
                                     model_name: str,
                                     model_type: str) -> None:
        """
        delete custom provider model.

        No-op when the model row does not exist.

        :param tenant_id:
        :param provider_name:
        :param model_name:
        :param model_type:
        :return:
        """
        provider_model = self._get_provider_model(tenant_id, provider_name, model_name, model_type)
        if provider_model:
            db.session.delete(provider_model)
            db.session.commit()

    def switch_preferred_provider(self, tenant_id: str, provider_name: str, preferred_provider_type: str) -> None:
        """
        switch preferred provider.

        Upserts the tenant's preferred provider type for ``provider_name``.
        Nothing is stored for providers that do not support the system provider type.

        :param tenant_id:
        :param provider_name:
        :param preferred_provider_type:
        :return:
        :raises ValueError: when the type is unknown or not supported by the provider.
        """
        if not ProviderType.value_of(preferred_provider_type):
            raise ValueError(f'Invalid preferred provider type: {preferred_provider_type}')

        model_provider_rules = ModelProviderFactory.get_provider_rule(provider_name)
        if preferred_provider_type not in model_provider_rules['support_provider_types']:
            raise ValueError(f'Not support provider type: {preferred_provider_type}')

        model_provider = ModelProviderFactory.get_model_provider_class(provider_name)
        if not model_provider.is_provider_type_system_supported():
            return

        preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
            .filter(
                TenantPreferredModelProvider.tenant_id == tenant_id,
                TenantPreferredModelProvider.provider_name == provider_name
            ).first()

        if preferred_model_provider:
            preferred_model_provider.preferred_provider_type = preferred_provider_type
        else:
            db.session.add(TenantPreferredModelProvider(
                tenant_id=tenant_id,
                provider_name=provider_name,
                preferred_provider_type=preferred_provider_type
            ))
        db.session.commit()

    def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional['TenantDefaultModel']:
        """
        get default model of model type.

        :param tenant_id:
        :param model_type: model type value, converted with ``ModelType.value_of``
        :return: the tenant's default model, or None
        """
        return ModelFactory.get_default_model(tenant_id, ModelType.value_of(model_type))

    def update_default_model_of_model_type(self,
                                           tenant_id: str,
                                           model_type: str,
                                           provider_name: str,
                                           model_name: str) -> 'TenantDefaultModel':
        """
        update default model of model type.

        :param tenant_id:
        :param model_type: model type value, converted with ``ModelType.value_of``
        :param provider_name:
        :param model_name:
        :return: the updated default model
        """
        return ModelFactory.update_default_model(tenant_id, ModelType.value_of(model_type), provider_name, model_name)

    def get_valid_model_list(self, tenant_id: str, model_type: str) -> list:
        """
        get valid model list.

        Lists the supported models of ``model_type`` for every provider the tenant
        can currently use (its preferred provider), with quota details for
        system providers.

        :param tenant_id:
        :param model_type:
        :return: list of model dicts
        """
        model_type_enum = ModelType.value_of(model_type)
        valid_model_list = []

        model_provider_rules = ModelProviderFactory.get_provider_rules()
        for model_provider_name, model_provider_rule in model_provider_rules.items():
            model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
            if not model_provider:
                continue

            model_list = model_provider.get_supported_model_list(model_type_enum)
            provider = model_provider.provider
            for model in model_list:
                model_provider_info = {
                    "provider_name": provider.provider_name,
                    "provider_type": provider.provider_type
                }
                if provider.provider_type == ProviderType.SYSTEM.value:
                    model_provider_info['quota_type'] = provider.quota_type
                    model_provider_info['quota_unit'] = model_provider_rule['system_config']['quota_unit']
                    model_provider_info['quota_limit'] = provider.quota_limit
                    model_provider_info['quota_used'] = provider.quota_used

                valid_model_list.append({
                    "model_name": model['id'],
                    "model_display_name": model['name'],
                    "model_type": model_type,
                    "model_provider": model_provider_info,
                    'features': model.get('features', [])
                })

        return valid_model_list

    def get_model_parameter_rules(self, tenant_id: str, model_provider_name: str, model_name: str, model_type: str) \
            -> 'ModelKwargsRules':
        """
        get model parameter rules.

        It depends on preferred provider in use.

        :param tenant_id:
        :param model_provider_name:
        :param model_name:
        :param model_type:
        :return: the provider's rules, or an empty ``ModelKwargsRules`` when the
            tenant has no usable provider
        """
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
        if not model_provider:
            return ModelKwargsRules()

        return model_provider.get_model_parameter_rules(model_name, ModelType.value_of(model_type))

    def _post_free_quota_api(self, path: str, payload: dict) -> dict:
        """
        POST ``payload`` to the free-quota apply service and return its decoded JSON body.

        The service location and credentials come from the
        ``FREE_QUOTA_APPLY_BASE_URL`` and ``FREE_QUOTA_APPLY_API_KEY`` environment variables.

        :raises ValueError: when either variable is unset, the response status is
            not OK, or the body's ``code`` is not ``'success'``.
        :raises requests.RequestException: on connection errors or timeout.
        """
        api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
        api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
        if not api_base_url:
            raise ValueError("FREE_QUOTA_APPLY_BASE_URL is not configured")
        if not api_key:
            raise ValueError("FREE_QUOTA_APPLY_API_KEY is not configured")

        headers = {
            'Content-Type': 'application/json',
            'Authorization': f"Bearer {api_key}"
        }
        response = requests.post(
            api_base_url + path,
            headers=headers,
            json=payload,
            timeout=self._FREE_QUOTA_API_TIMEOUT
        )
        if not response.ok:
            logging.error("Request FREE QUOTA APPLY SERVER Error: %s ", response.status_code)
            raise ValueError(f"Error: {response.status_code} ")

        rst = response.json()
        if rst["code"] != 'success':
            raise ValueError(
                f"error: {rst['message']}"
            )
        return rst

    def free_quota_submit(self, tenant_id: str, provider_name: str):
        """
        Submit a free-quota application for the tenant's provider.

        :param tenant_id: sent to the service as ``workspace_id``
        :param provider_name:
        :return: ``{'type', 'redirect_url'}`` when the service answers with a
            redirect, otherwise ``{'type', 'result': 'success'}``
        :raises ValueError: see ``_post_free_quota_api``
        """
        rst = self._post_free_quota_api(
            '/api/v1/providers/apply',
            {'workspace_id': tenant_id, 'provider_name': provider_name}
        )
        if rst['type'] == 'redirect':
            return {
                'type': rst['type'],
                'redirect_url': rst['redirect_url']
            }
        return {
            'type': rst['type'],
            'result': 'success'
        }

    def free_quota_qualification_verify(self, tenant_id: str, provider_name: str, token: Optional[str]):
        """
        Ask the free-quota apply service whether the tenant qualifies for free quota.

        :param tenant_id: sent to the service as ``workspace_id``
        :param provider_name:
        :param token: optional token forwarded to the service when truthy
        :return: ``{'result', 'provider_name', 'flag'}``; when not qualified,
            ``flag`` is False and ``reason`` carries the service's explanation
        :raises ValueError: see ``_post_free_quota_api``
        """
        json_data = {'workspace_id': tenant_id, 'provider_name': provider_name}
        if token:
            json_data['token'] = token

        data = self._post_free_quota_api('/api/v1/providers/qualification-verify', json_data)['data']
        if data['qualified'] is True:
            return {
                'result': 'success',
                'provider_name': provider_name,
                'flag': True
            }
        return {
            'result': 'success',
            'provider_name': provider_name,
            'flag': False,
            'reason': data.get('reason')
        }
 
  |