Browse Source

Fix/no trial provider (#823)

takatost 1 year ago
parent
commit
8e15ba6cd6

+ 28 - 4
api/core/model_providers/model_provider_factory.py

@@ -168,10 +168,34 @@ class ModelProviderFactory:
             model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name)
             for quota_type_enum in ProviderQuotaType:
                 quota_type = quota_type_enum.value
-                if quota_type in model_provider_rules['system_config']['supported_quota_types'] \
-                        and quota_type in quota_type_to_provider_dict.keys():
-                    provider = quota_type_to_provider_dict[quota_type]
-                    if provider.is_valid and provider.quota_limit > provider.quota_used:
+                if quota_type in model_provider_rules['system_config']['supported_quota_types']:
+                    if quota_type in quota_type_to_provider_dict.keys():
+                        provider = quota_type_to_provider_dict[quota_type]
+                        if provider.is_valid and provider.quota_limit > provider.quota_used:
+                            return provider
+                    elif quota_type == ProviderQuotaType.TRIAL.value:
+                        try:
+                            provider = Provider(
+                                tenant_id=tenant_id,
+                                provider_name=model_provider_name,
+                                provider_type=ProviderType.SYSTEM.value,
+                                is_valid=True,
+                                quota_type=ProviderQuotaType.TRIAL.value,
+                                quota_limit=model_provider_rules['system_config']['quota_limit'],
+                                quota_used=0
+                            )
+                            db.session.add(provider)
+                            db.session.commit()
+                        except IntegrityError:
+                            db.session.rollback()
+                            provider = db.session.query(Provider) \
+                                .filter(
+                                Provider.tenant_id == tenant_id,
+                                Provider.provider_name == model_provider_name,
+                                Provider.provider_type == ProviderType.SYSTEM.value,
+                                Provider.quota_type == ProviderQuotaType.TRIAL.value
+                            ).first()
+
                         return provider
 
             no_system_provider = True

+ 8 - 0
api/services/provider_service.py

@@ -23,6 +23,14 @@ class ProviderService:
         # get rules for all providers
         model_provider_rules = ModelProviderFactory.get_provider_rules()
         model_provider_names = [model_provider_name for model_provider_name, _ in model_provider_rules.items()]
+
+        for model_provider_name, model_provider_rule in model_provider_rules.items():
+            if ProviderType.SYSTEM.value in model_provider_rule['support_provider_types'] \
+                    and 'system_config' in model_provider_rule and model_provider_rule['system_config'] \
+                    and 'supported_quota_types' in model_provider_rule['system_config'] \
+                    and 'trial' in model_provider_rule['system_config']['supported_quota_types']:
+                ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
+
         configurable_model_provider_names = [
             model_provider_name
             for model_provider_name, model_provider_rules in model_provider_rules.items()