Ver Fonte

feat: optimize performance (#1928)

takatost há 1 ano atrás
pai
commit
3fa5204b0c

+ 33 - 0
api/core/entities/provider_configuration.py

@@ -10,6 +10,7 @@ from pydantic import BaseModel
 from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity
 from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
 from core.helper import encrypter
+from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
 from core.model_runtime.model_providers import model_provider_factory
@@ -171,6 +172,14 @@ class ProviderConfiguration(BaseModel):
             db.session.add(provider_record)
             db.session.commit()
 
+        provider_model_credentials_cache = ProviderCredentialsCache(
+            tenant_id=self.tenant_id,
+            identity_id=provider_record.id,
+            cache_type=ProviderCredentialsCacheType.PROVIDER
+        )
+
+        provider_model_credentials_cache.delete()
+
         self.switch_preferred_provider_type(ProviderType.CUSTOM)
 
     def delete_custom_credentials(self) -> None:
@@ -193,6 +202,14 @@ class ProviderConfiguration(BaseModel):
             db.session.delete(provider_record)
             db.session.commit()
 
+            provider_model_credentials_cache = ProviderCredentialsCache(
+                tenant_id=self.tenant_id,
+                identity_id=provider_record.id,
+                cache_type=ProviderCredentialsCacheType.PROVIDER
+            )
+
+            provider_model_credentials_cache.delete()
+
     def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
             -> Optional[dict]:
         """
@@ -314,6 +331,14 @@ class ProviderConfiguration(BaseModel):
             db.session.add(provider_model_record)
             db.session.commit()
 
+        provider_model_credentials_cache = ProviderCredentialsCache(
+            tenant_id=self.tenant_id,
+            identity_id=provider_model_record.id,
+            cache_type=ProviderCredentialsCacheType.MODEL
+        )
+
+        provider_model_credentials_cache.delete()
+
     def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
         """
         Delete custom model credentials.
@@ -335,6 +360,14 @@ class ProviderConfiguration(BaseModel):
             db.session.delete(provider_model_record)
             db.session.commit()
 
+            provider_model_credentials_cache = ProviderCredentialsCache(
+                tenant_id=self.tenant_id,
+                identity_id=provider_model_record.id,
+                cache_type=ProviderCredentialsCacheType.MODEL
+            )
+
+            provider_model_credentials_cache.delete()
+
     def get_provider_instance(self) -> ModelProvider:
         """
         Get provider instance.

+ 51 - 0
api/core/helper/model_provider_cache.py

@@ -0,0 +1,51 @@
+import json
+from enum import Enum
+from json import JSONDecodeError
+from typing import Optional
+
+from extensions.ext_redis import redis_client
+
+
+class ProviderCredentialsCacheType(Enum):
+    PROVIDER = "provider"
+    MODEL = "provider_model"
+
+
+class ProviderCredentialsCache:
+    def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
+        self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
+
+    def get(self) -> Optional[dict]:
+        """
+        Get cached model provider credentials.
+
+        :return:
+        """
+        cached_provider_credentials = redis_client.get(self.cache_key)
+        if cached_provider_credentials:
+            try:
+                cached_provider_credentials = cached_provider_credentials.decode('utf-8')
+                cached_provider_credentials = json.loads(cached_provider_credentials)
+            except JSONDecodeError:
+                return None
+
+            return cached_provider_credentials
+        else:
+            return None
+
+    def set(self, credentials: dict) -> None:
+        """
+        Cache model provider credentials.
+
+        :param credentials: provider credentials
+        :return:
+        """
+        redis_client.setex(self.cache_key, 3600, json.dumps(credentials))
+
+    def delete(self) -> None:
+        """
+        Delete cached model provider credentials.
+
+        :return:
+        """
+        redis_client.delete(self.cache_key)

+ 4 - 0
api/core/model_runtime/model_providers/model_provider_factory.py

@@ -30,6 +30,10 @@ class ModelProviderExtension(BaseModel):
 class ModelProviderFactory:
     model_provider_extensions: dict[str, ModelProviderExtension] = None
 
+    def __init__(self) -> None:
+        # for cache in memory
+        self.get_providers()
+
     def get_providers(self) -> list[ProviderEntity]:
         """
         Get all providers

+ 126 - 74
api/core/provider_manager.py

@@ -10,6 +10,7 @@ from core.entities.provider_configuration import ProviderConfigurations, Provide
 from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, CustomModelConfiguration, \
     SystemConfiguration, QuotaConfiguration
 from core.helper import encrypter
+from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
 from core.model_runtime.model_providers import model_provider_factory
@@ -79,9 +80,6 @@ class ProviderManager:
         # Get All preferred provider types of the workspace
         provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
 
-        # Get decoding rsa key and cipher for decrypting credentials
-        decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
-
         provider_configurations = ProviderConfigurations(
             tenant_id=tenant_id
         )
@@ -100,19 +98,17 @@ class ProviderManager:
 
             # Convert to custom configuration
             custom_configuration = self._to_custom_configuration(
+                tenant_id,
                 provider_entity,
                 provider_records,
-                provider_model_records,
-                decoding_rsa_key,
-                decoding_cipher_rsa
+                provider_model_records
             )
 
             # Convert to system configuration
             system_configuration = self._to_system_configuration(
+                tenant_id,
                 provider_entity,
-                provider_records,
-                decoding_rsa_key,
-                decoding_cipher_rsa
+                provider_records
             )
 
             # Get preferred provider type
@@ -413,19 +409,17 @@ class ProviderManager:
         return provider_name_to_provider_records_dict
 
     def _to_custom_configuration(self,
+                                 tenant_id: str,
                                  provider_entity: ProviderEntity,
                                  provider_records: list[Provider],
-                                 provider_model_records: list[ProviderModel],
-                                 decoding_rsa_key,
-                                 decoding_cipher_rsa) -> CustomConfiguration:
+                                 provider_model_records: list[ProviderModel]) -> CustomConfiguration:
         """
         Convert to custom configuration.
 
+        :param tenant_id: workspace id
         :param provider_entity: provider entity
         :param provider_records: provider records
         :param provider_model_records: provider model records
-        :param decoding_rsa_key: decoding rsa key
-        :param decoding_cipher_rsa: decoding cipher rsa
         :return:
         """
         # Get provider credential secret variables
@@ -448,28 +442,48 @@ class ProviderManager:
         # Get custom provider credentials
         custom_provider_configuration = None
         if custom_provider_record:
-            try:
-                # fix origin data
-                if (custom_provider_record.encrypted_config
-                        and not custom_provider_record.encrypted_config.startswith("{")):
-                    provider_credentials = {
-                        "openai_api_key": custom_provider_record.encrypted_config
-                    }
-                else:
-                    provider_credentials = json.loads(custom_provider_record.encrypted_config)
-            except JSONDecodeError:
-                provider_credentials = {}
+            provider_credentials_cache = ProviderCredentialsCache(
+                tenant_id=tenant_id,
+                identity_id=custom_provider_record.id,
+                cache_type=ProviderCredentialsCacheType.PROVIDER
+            )
 
-            for variable in provider_credential_secret_variables:
-                if variable in provider_credentials:
-                    try:
-                        provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
-                            provider_credentials.get(variable),
-                            decoding_rsa_key,
-                            decoding_cipher_rsa
-                        )
-                    except ValueError:
-                        pass
+            # Get cached provider credentials
+            cached_provider_credentials = provider_credentials_cache.get()
+
+            if not cached_provider_credentials:
+                try:
+                    # fix origin data
+                    if (custom_provider_record.encrypted_config
+                            and not custom_provider_record.encrypted_config.startswith("{")):
+                        provider_credentials = {
+                            "openai_api_key": custom_provider_record.encrypted_config
+                        }
+                    else:
+                        provider_credentials = json.loads(custom_provider_record.encrypted_config)
+                except JSONDecodeError:
+                    provider_credentials = {}
+
+                # Get decoding rsa key and cipher for decrypting credentials
+                decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
+
+                for variable in provider_credential_secret_variables:
+                    if variable in provider_credentials:
+                        try:
+                            provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
+                                provider_credentials.get(variable),
+                                decoding_rsa_key,
+                                decoding_cipher_rsa
+                            )
+                        except ValueError:
+                            pass
+
+                # cache provider credentials
+                provider_credentials_cache.set(
+                    credentials=provider_credentials
+                )
+            else:
+                provider_credentials = cached_provider_credentials
 
             custom_provider_configuration = CustomProviderConfiguration(
                 credentials=provider_credentials
@@ -487,21 +501,41 @@ class ProviderManager:
             if not provider_model_record.encrypted_config:
                 continue
 
-            try:
-                provider_model_credentials = json.loads(provider_model_record.encrypted_config)
-            except JSONDecodeError:
-                continue
+            provider_model_credentials_cache = ProviderCredentialsCache(
+                tenant_id=tenant_id,
+                identity_id=provider_model_record.id,
+                cache_type=ProviderCredentialsCacheType.MODEL
+            )
 
-            for variable in model_credential_secret_variables:
-                if variable in provider_model_credentials:
-                    try:
-                        provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
-                            provider_model_credentials.get(variable),
-                            decoding_rsa_key,
-                            decoding_cipher_rsa
-                        )
-                    except ValueError:
-                        pass
+            # Get cached provider model credentials
+            cached_provider_model_credentials = provider_model_credentials_cache.get()
+
+            if not cached_provider_model_credentials:
+                try:
+                    provider_model_credentials = json.loads(provider_model_record.encrypted_config)
+                except JSONDecodeError:
+                    continue
+
+                # Get decoding rsa key and cipher for decrypting credentials
+                decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
+
+                for variable in model_credential_secret_variables:
+                    if variable in provider_model_credentials:
+                        try:
+                            provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
+                                provider_model_credentials.get(variable),
+                                decoding_rsa_key,
+                                decoding_cipher_rsa
+                            )
+                        except ValueError:
+                            pass
+
+                # cache provider model credentials
+                provider_model_credentials_cache.set(
+                    credentials=provider_model_credentials
+                )
+            else:
+                provider_model_credentials = cached_provider_model_credentials
 
             custom_model_configurations.append(
                 CustomModelConfiguration(
@@ -517,17 +551,15 @@ class ProviderManager:
         )
 
     def _to_system_configuration(self,
+                                 tenant_id: str,
                                  provider_entity: ProviderEntity,
-                                 provider_records: list[Provider],
-                                 decoding_rsa_key,
-                                 decoding_cipher_rsa) -> SystemConfiguration:
+                                 provider_records: list[Provider]) -> SystemConfiguration:
         """
         Convert to system configuration.
 
+        :param tenant_id: workspace id
         :param provider_entity: provider entity
         :param provider_records: provider records
-        :param decoding_rsa_key: decoding rsa key
-        :param decoding_cipher_rsa: decoding cipher rsa
         :return:
         """
         # Get hosting configuration
@@ -580,29 +612,49 @@ class ProviderManager:
             provider_record = quota_type_to_provider_records_dict.get(current_quota_type)
 
             if provider_record:
-                try:
-                    provider_credentials = json.loads(provider_record.encrypted_config)
-                except JSONDecodeError:
-                    provider_credentials = {}
-
-                # Get provider credential secret variables
-                provider_credential_secret_variables = self._extract_secret_variables(
-                    provider_entity.provider_credential_schema.credential_form_schemas
-                    if provider_entity.provider_credential_schema else []
+                provider_credentials_cache = ProviderCredentialsCache(
+                    tenant_id=tenant_id,
+                    identity_id=provider_record.id,
+                    cache_type=ProviderCredentialsCacheType.PROVIDER
                 )
 
-                for variable in provider_credential_secret_variables:
-                    if variable in provider_credentials:
-                        try:
-                            provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
-                                provider_credentials.get(variable),
-                                decoding_rsa_key,
-                                decoding_cipher_rsa
-                            )
-                        except ValueError:
-                            pass
+                # Get cached provider credentials
+                cached_provider_credentials = provider_credentials_cache.get()
 
-                current_using_credentials = provider_credentials
+                if not cached_provider_credentials:
+                    try:
+                        provider_credentials = json.loads(provider_record.encrypted_config)
+                    except JSONDecodeError:
+                        provider_credentials = {}
+
+                    # Get provider credential secret variables
+                    provider_credential_secret_variables = self._extract_secret_variables(
+                        provider_entity.provider_credential_schema.credential_form_schemas
+                        if provider_entity.provider_credential_schema else []
+                    )
+
+                    # Get decoding rsa key and cipher for decrypting credentials
+                    decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
+
+                    for variable in provider_credential_secret_variables:
+                        if variable in provider_credentials:
+                            try:
+                                provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
+                                    provider_credentials.get(variable),
+                                    decoding_rsa_key,
+                                    decoding_cipher_rsa
+                                )
+                            except ValueError:
+                                pass
+
+                    current_using_credentials = provider_credentials
+
+                    # cache provider credentials
+                    provider_credentials_cache.set(
+                        credentials=current_using_credentials
+                    )
+                else:
+                    current_using_credentials = cached_provider_credentials
             else:
                 current_using_credentials = {}
 

+ 1 - 1
api/services/model_provider_service.py

@@ -6,7 +6,7 @@ from typing import Optional, cast, Tuple
 import requests
 from flask import current_app
 
-from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, DefaultModelEntity
+from core.entities.model_entities import ModelStatus
 from core.model_runtime.entities.model_entities import ModelType, ParameterRule
 from core.model_runtime.model_providers import model_provider_factory
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel