| 
					
				 | 
			
			
				@@ -1,7 +1,7 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import datetime 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import json 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import logging 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import time 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from json import JSONDecodeError 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from typing import Optional, List, Dict, Tuple, Iterator 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -11,8 +11,9 @@ from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, S 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.helper import encrypter 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from core.model_runtime.entities.model_entities import ModelType 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.model_runtime.entities.model_entities import ModelType, FetchFrom 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType, \ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ConfigurateMethod 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.model_runtime.model_providers import model_provider_factory 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.model_runtime.model_providers.__base.ai_model import AIModel 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from core.model_runtime.model_providers.__base.model_provider import ModelProvider 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -22,6 +23,8 @@ from models.provider import ProviderType, Provider, ProviderModel, TenantPreferr 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 logger = logging.getLogger(__name__) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+original_provider_configurate_methods = {} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 class ProviderConfiguration(BaseModel): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     """ 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -34,6 +37,20 @@ class ProviderConfiguration(BaseModel): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     system_configuration: SystemConfiguration 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     custom_configuration: CustomConfiguration 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def __init__(self, **data): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        super().__init__(**data) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if self.provider.provider not in original_provider_configurate_methods: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            original_provider_configurate_methods[self.provider.provider] = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for configurate_method in self.provider.configurate_methods: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                original_provider_configurate_methods[self.provider.provider].append(configurate_method) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if (any([len(quota_configuration.restrict_models) > 0 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                     for quota_configuration in self.system_configuration.quota_configurations]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         Get current credentials. 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -123,7 +140,8 @@ class ProviderConfiguration(BaseModel): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if provider_record: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                original_credentials = json.loads(provider_record.encrypted_config) if provider_record.encrypted_config else {} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                original_credentials = json.loads( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    provider_record.encrypted_config) if provider_record.encrypted_config else {} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             except JSONDecodeError: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 original_credentials = {} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -265,7 +283,8 @@ class ProviderConfiguration(BaseModel): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if provider_model_record: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                original_credentials = json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                original_credentials = json.loads( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             except JSONDecodeError: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 original_credentials = {} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -534,21 +553,70 @@ class ProviderConfiguration(BaseModel): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 ] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if self.provider.provider not in original_provider_configurate_methods: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            original_provider_configurate_methods[self.provider.provider] = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            for configurate_method in provider_instance.get_provider_schema().configurate_methods: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                original_provider_configurate_methods[self.provider.provider].append(configurate_method) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        should_use_custom_model = False 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            should_use_custom_model = True 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         for quota_configuration in self.system_configuration.quota_configurations: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if self.system_configuration.current_quota_type != quota_configuration.quota_type: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            restrict_llms = quota_configuration.restrict_llms 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            if not restrict_llms: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            restrict_models = quota_configuration.restrict_models 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if len(restrict_models) == 0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 break 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if should_use_custom_model: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    # only customizable model 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    for restrict_model in restrict_models: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        copy_credentials = self.system_configuration.credentials.copy() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        if restrict_model.base_model_name: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            copy_credentials['base_model_name'] = restrict_model.base_model_name 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        try: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            custom_model_schema = ( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                provider_instance.get_model_instance(restrict_model.model_type) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                .get_customizable_model_schema_from_credentials( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                    restrict_model.model, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                    copy_credentials 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        except Exception as ex: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            logger.warning(f'get custom model schema failed, {ex}') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        if not custom_model_schema: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        if custom_model_schema.model_type not in model_types: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            continue 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        provider_models.append( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            ModelWithProviderEntity( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                model=custom_model_schema.model, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                label=custom_model_schema.label, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                model_type=custom_model_schema.model_type, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                features=custom_model_schema.features, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                fetch_from=FetchFrom.PREDEFINED_MODEL, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                model_properties=custom_model_schema.model_properties, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                deprecated=custom_model_schema.deprecated, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                provider=SimpleModelProviderEntity(self.provider), 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                                status=ModelStatus.ACTIVE 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             # if llm name not in restricted llm list, remove it 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            restrict_model_names = [rm.model for rm in restrict_models] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             for m in provider_models: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                if m.model_type == ModelType.LLM and m.model not in restrict_llms: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                if m.model_type == ModelType.LLM and m.model not in restrict_model_names: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     m.status = ModelStatus.NO_PERMISSION 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 elif not quota_configuration.is_valid: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     m.status = ModelStatus.QUOTA_EXCEEDED 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return provider_models 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def _get_custom_provider_models(self, 
			 |