Ver Fonte

feat: reuse decoding_rsa_key & decoding_cipher_rsa & optimize construct (#1937)

takatost há 1 ano atrás
pai
commit
296bf443a8

+ 21 - 3
api/core/entities/provider_configuration.py

@@ -520,7 +520,13 @@ class ProviderConfiguration(BaseModel):
             provider_models.extend(
                 [
                     ModelWithProviderEntity(
-                        **m.dict(),
+                        model=m.model,
+                        label=m.label,
+                        model_type=m.model_type,
+                        features=m.features,
+                        fetch_from=m.fetch_from,
+                        model_properties=m.model_properties,
+                        deprecated=m.deprecated,
                         provider=SimpleModelProviderEntity(self.provider),
                         status=ModelStatus.ACTIVE
                     )
@@ -569,7 +575,13 @@ class ProviderConfiguration(BaseModel):
             for m in models:
                 provider_models.append(
                     ModelWithProviderEntity(
-                        **m.dict(),
+                        model=m.model,
+                        label=m.label,
+                        model_type=m.model_type,
+                        features=m.features,
+                        fetch_from=m.fetch_from,
+                        model_properties=m.model_properties,
+                        deprecated=m.deprecated,
                         provider=SimpleModelProviderEntity(self.provider),
                         status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
                     )
@@ -597,7 +609,13 @@ class ProviderConfiguration(BaseModel):
 
             provider_models.append(
                 ModelWithProviderEntity(
-                    **custom_model_schema.dict(),
+                    model=custom_model_schema.model,
+                    label=custom_model_schema.label,
+                    model_type=custom_model_schema.model_type,
+                    features=custom_model_schema.features,
+                    fetch_from=custom_model_schema.fetch_from,
+                    model_properties=custom_model_schema.model_properties,
+                    deprecated=custom_model_schema.deprecated,
                     provider=SimpleModelProviderEntity(self.provider),
                     status=ModelStatus.ACTIVE
                 )

+ 15 - 9
api/core/provider_manager.py

@@ -24,6 +24,9 @@ class ProviderManager:
     """
     ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
     """
+    def __init__(self) -> None:
+        self.decoding_rsa_key = None
+        self.decoding_cipher_rsa = None
 
     def get_configurations(self, tenant_id: str) -> ProviderConfigurations:
         """
@@ -472,15 +475,16 @@ class ProviderManager:
                     provider_credentials = {}
 
                 # Get decoding rsa key and cipher for decrypting credentials
-                decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
+                if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
+                    self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
 
                 for variable in provider_credential_secret_variables:
                     if variable in provider_credentials:
                         try:
                             provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
                                 provider_credentials.get(variable),
-                                decoding_rsa_key,
-                                decoding_cipher_rsa
+                                self.decoding_rsa_key,
+                                self.decoding_cipher_rsa
                             )
                         except ValueError:
                             pass
@@ -524,15 +528,16 @@ class ProviderManager:
                     continue
 
                 # Get decoding rsa key and cipher for decrypting credentials
-                decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
+                if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
+                    self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
 
                 for variable in model_credential_secret_variables:
                     if variable in provider_model_credentials:
                         try:
                             provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
                                 provider_model_credentials.get(variable),
-                                decoding_rsa_key,
-                                decoding_cipher_rsa
+                                self.decoding_rsa_key,
+                                self.decoding_cipher_rsa
                             )
                         except ValueError:
                             pass
@@ -641,15 +646,16 @@ class ProviderManager:
                     )
 
                     # Get decoding rsa key and cipher for decrypting credentials
-                    decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
+                    if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
+                        self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
 
                     for variable in provider_credential_secret_variables:
                         if variable in provider_credentials:
                             try:
                                 provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
                                     provider_credentials.get(variable),
-                                    decoding_rsa_key,
-                                    decoding_cipher_rsa
+                                    self.decoding_rsa_key,
+                                    self.decoding_cipher_rsa
                                 )
                             except ValueError:
                                 pass

+ 24 - 4
api/services/model_provider_service.py

@@ -14,7 +14,7 @@ from core.provider_manager import ProviderManager
 from models.provider import ProviderType
 from services.entities.model_provider_entities import ProviderResponse, CustomConfigurationResponse, \
     SystemConfigurationResponse, CustomConfigurationStatus, ProviderWithModelsResponse, ModelResponse, \
-    DefaultModelResponse, ModelWithProviderEntityResponse
+    DefaultModelResponse, ModelWithProviderEntityResponse, SimpleProviderEntityResponse
 
 logger = logging.getLogger(__name__)
 
@@ -45,7 +45,17 @@ class ModelProviderService:
                     continue
 
             provider_response = ProviderResponse(
-                **provider_configuration.provider.dict(),
+                provider=provider_configuration.provider.provider,
+                label=provider_configuration.provider.label,
+                description=provider_configuration.provider.description,
+                icon_small=provider_configuration.provider.icon_small,
+                icon_large=provider_configuration.provider.icon_large,
+                background=provider_configuration.provider.background,
+                help=provider_configuration.provider.help,
+                supported_model_types=provider_configuration.provider.supported_model_types,
+                configurate_methods=provider_configuration.provider.configurate_methods,
+                provider_credential_schema=provider_configuration.provider.provider_credential_schema,
+                model_credential_schema=provider_configuration.provider.model_credential_schema,
                 preferred_provider_type=provider_configuration.preferred_provider_type,
                 custom_configuration=CustomConfigurationResponse(
                     status=CustomConfigurationStatus.ACTIVE
@@ -53,7 +63,9 @@ class ModelProviderService:
                     else CustomConfigurationStatus.NO_CONFIGURE
                 ),
                 system_configuration=SystemConfigurationResponse(
-                    **provider_configuration.system_configuration.dict()
+                    enabled=provider_configuration.system_configuration.enabled,
+                    current_quota_type=provider_configuration.system_configuration.current_quota_type,
+                    quota_configurations=provider_configuration.system_configuration.quota_configurations
                 )
             )
 
@@ -369,7 +381,15 @@ class ModelProviderService:
         )
 
         return DefaultModelResponse(
-            **result.dict()
+            model=result.model,
+            model_type=result.model_type,
+            provider=SimpleProviderEntityResponse(
+                provider=result.provider.provider,
+                label=result.provider.label,
+                icon_small=result.provider.icon_small,
+                icon_large=result.provider.icon_large,
+                supported_model_types=result.provider.supported_model_types
+            )
         ) if result else None
 
     def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None: