|
@@ -10,6 +10,7 @@ from core.entities.provider_configuration import ProviderConfigurations, Provide
|
|
|
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, CustomModelConfiguration, \
|
|
|
SystemConfiguration, QuotaConfiguration
|
|
|
from core.helper import encrypter
|
|
|
+from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
|
|
from core.model_runtime.entities.model_entities import ModelType
|
|
|
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
|
|
|
from core.model_runtime.model_providers import model_provider_factory
|
|
@@ -79,9 +80,6 @@ class ProviderManager:
|
|
|
# Get All preferred provider types of the workspace
|
|
|
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
|
|
|
|
|
|
- # Get decoding rsa key and cipher for decrypting credentials
|
|
|
- decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
|
|
-
|
|
|
provider_configurations = ProviderConfigurations(
|
|
|
tenant_id=tenant_id
|
|
|
)
|
|
@@ -100,19 +98,17 @@ class ProviderManager:
|
|
|
|
|
|
# Convert to custom configuration
|
|
|
custom_configuration = self._to_custom_configuration(
|
|
|
+ tenant_id,
|
|
|
provider_entity,
|
|
|
provider_records,
|
|
|
- provider_model_records,
|
|
|
- decoding_rsa_key,
|
|
|
- decoding_cipher_rsa
|
|
|
+ provider_model_records
|
|
|
)
|
|
|
|
|
|
# Convert to system configuration
|
|
|
system_configuration = self._to_system_configuration(
|
|
|
+ tenant_id,
|
|
|
provider_entity,
|
|
|
- provider_records,
|
|
|
- decoding_rsa_key,
|
|
|
- decoding_cipher_rsa
|
|
|
+ provider_records
|
|
|
)
|
|
|
|
|
|
# Get preferred provider type
|
|
@@ -413,19 +409,17 @@ class ProviderManager:
|
|
|
return provider_name_to_provider_records_dict
|
|
|
|
|
|
def _to_custom_configuration(self,
|
|
|
+ tenant_id: str,
|
|
|
provider_entity: ProviderEntity,
|
|
|
provider_records: list[Provider],
|
|
|
- provider_model_records: list[ProviderModel],
|
|
|
- decoding_rsa_key,
|
|
|
- decoding_cipher_rsa) -> CustomConfiguration:
|
|
|
+ provider_model_records: list[ProviderModel]) -> CustomConfiguration:
|
|
|
"""
|
|
|
Convert to custom configuration.
|
|
|
|
|
|
+ :param tenant_id: workspace id
|
|
|
:param provider_entity: provider entity
|
|
|
:param provider_records: provider records
|
|
|
:param provider_model_records: provider model records
|
|
|
- :param decoding_rsa_key: decoding rsa key
|
|
|
- :param decoding_cipher_rsa: decoding cipher rsa
|
|
|
:return:
|
|
|
"""
|
|
|
# Get provider credential secret variables
|
|
@@ -448,28 +442,48 @@ class ProviderManager:
|
|
|
# Get custom provider credentials
|
|
|
custom_provider_configuration = None
|
|
|
if custom_provider_record:
|
|
|
- try:
|
|
|
- # fix origin data
|
|
|
- if (custom_provider_record.encrypted_config
|
|
|
- and not custom_provider_record.encrypted_config.startswith("{")):
|
|
|
- provider_credentials = {
|
|
|
- "openai_api_key": custom_provider_record.encrypted_config
|
|
|
- }
|
|
|
- else:
|
|
|
- provider_credentials = json.loads(custom_provider_record.encrypted_config)
|
|
|
- except JSONDecodeError:
|
|
|
- provider_credentials = {}
|
|
|
+ provider_credentials_cache = ProviderCredentialsCache(
|
|
|
+ tenant_id=tenant_id,
|
|
|
+ identity_id=custom_provider_record.id,
|
|
|
+ cache_type=ProviderCredentialsCacheType.PROVIDER
|
|
|
+ )
|
|
|
|
|
|
- for variable in provider_credential_secret_variables:
|
|
|
- if variable in provider_credentials:
|
|
|
- try:
|
|
|
- provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
|
|
- provider_credentials.get(variable),
|
|
|
- decoding_rsa_key,
|
|
|
- decoding_cipher_rsa
|
|
|
- )
|
|
|
- except ValueError:
|
|
|
- pass
|
|
|
+ # Get cached provider credentials
|
|
|
+ cached_provider_credentials = provider_credentials_cache.get()
|
|
|
+
|
|
|
+ if not cached_provider_credentials:
|
|
|
+ try:
|
|
|
+ # fix origin data
|
|
|
+ if (custom_provider_record.encrypted_config
|
|
|
+ and not custom_provider_record.encrypted_config.startswith("{")):
|
|
|
+ provider_credentials = {
|
|
|
+ "openai_api_key": custom_provider_record.encrypted_config
|
|
|
+ }
|
|
|
+ else:
|
|
|
+ provider_credentials = json.loads(custom_provider_record.encrypted_config)
|
|
|
+ except JSONDecodeError:
|
|
|
+ provider_credentials = {}
|
|
|
+
|
|
|
+ # Get decoding rsa key and cipher for decrypting credentials
|
|
|
+ decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
|
|
+
|
|
|
+ for variable in provider_credential_secret_variables:
|
|
|
+ if variable in provider_credentials:
|
|
|
+ try:
|
|
|
+ provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
|
|
+ provider_credentials.get(variable),
|
|
|
+ decoding_rsa_key,
|
|
|
+ decoding_cipher_rsa
|
|
|
+ )
|
|
|
+ except ValueError:
|
|
|
+ pass
|
|
|
+
|
|
|
+ # cache provider credentials
|
|
|
+ provider_credentials_cache.set(
|
|
|
+ credentials=provider_credentials
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ provider_credentials = cached_provider_credentials
|
|
|
|
|
|
custom_provider_configuration = CustomProviderConfiguration(
|
|
|
credentials=provider_credentials
|
|
@@ -487,21 +501,41 @@ class ProviderManager:
|
|
|
if not provider_model_record.encrypted_config:
|
|
|
continue
|
|
|
|
|
|
- try:
|
|
|
- provider_model_credentials = json.loads(provider_model_record.encrypted_config)
|
|
|
- except JSONDecodeError:
|
|
|
- continue
|
|
|
+ provider_model_credentials_cache = ProviderCredentialsCache(
|
|
|
+ tenant_id=tenant_id,
|
|
|
+ identity_id=provider_model_record.id,
|
|
|
+ cache_type=ProviderCredentialsCacheType.MODEL
|
|
|
+ )
|
|
|
|
|
|
- for variable in model_credential_secret_variables:
|
|
|
- if variable in provider_model_credentials:
|
|
|
- try:
|
|
|
- provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
|
|
- provider_model_credentials.get(variable),
|
|
|
- decoding_rsa_key,
|
|
|
- decoding_cipher_rsa
|
|
|
- )
|
|
|
- except ValueError:
|
|
|
- pass
|
|
|
+ # Get cached provider model credentials
|
|
|
+ cached_provider_model_credentials = provider_model_credentials_cache.get()
|
|
|
+
|
|
|
+ if not cached_provider_model_credentials:
|
|
|
+ try:
|
|
|
+ provider_model_credentials = json.loads(provider_model_record.encrypted_config)
|
|
|
+ except JSONDecodeError:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # Get decoding rsa key and cipher for decrypting credentials
|
|
|
+ decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
|
|
+
|
|
|
+ for variable in model_credential_secret_variables:
|
|
|
+ if variable in provider_model_credentials:
|
|
|
+ try:
|
|
|
+ provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
|
|
+ provider_model_credentials.get(variable),
|
|
|
+ decoding_rsa_key,
|
|
|
+ decoding_cipher_rsa
|
|
|
+ )
|
|
|
+ except ValueError:
|
|
|
+ pass
|
|
|
+
|
|
|
+ # cache provider model credentials
|
|
|
+ provider_model_credentials_cache.set(
|
|
|
+ credentials=provider_model_credentials
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ provider_model_credentials = cached_provider_model_credentials
|
|
|
|
|
|
custom_model_configurations.append(
|
|
|
CustomModelConfiguration(
|
|
@@ -517,17 +551,15 @@ class ProviderManager:
|
|
|
)
|
|
|
|
|
|
def _to_system_configuration(self,
|
|
|
+ tenant_id: str,
|
|
|
provider_entity: ProviderEntity,
|
|
|
- provider_records: list[Provider],
|
|
|
- decoding_rsa_key,
|
|
|
- decoding_cipher_rsa) -> SystemConfiguration:
|
|
|
+ provider_records: list[Provider]) -> SystemConfiguration:
|
|
|
"""
|
|
|
Convert to system configuration.
|
|
|
|
|
|
+ :param tenant_id: workspace id
|
|
|
:param provider_entity: provider entity
|
|
|
:param provider_records: provider records
|
|
|
- :param decoding_rsa_key: decoding rsa key
|
|
|
- :param decoding_cipher_rsa: decoding cipher rsa
|
|
|
:return:
|
|
|
"""
|
|
|
# Get hosting configuration
|
|
@@ -580,29 +612,49 @@ class ProviderManager:
|
|
|
provider_record = quota_type_to_provider_records_dict.get(current_quota_type)
|
|
|
|
|
|
if provider_record:
|
|
|
- try:
|
|
|
- provider_credentials = json.loads(provider_record.encrypted_config)
|
|
|
- except JSONDecodeError:
|
|
|
- provider_credentials = {}
|
|
|
-
|
|
|
- # Get provider credential secret variables
|
|
|
- provider_credential_secret_variables = self._extract_secret_variables(
|
|
|
- provider_entity.provider_credential_schema.credential_form_schemas
|
|
|
- if provider_entity.provider_credential_schema else []
|
|
|
+ provider_credentials_cache = ProviderCredentialsCache(
|
|
|
+ tenant_id=tenant_id,
|
|
|
+ identity_id=provider_record.id,
|
|
|
+ cache_type=ProviderCredentialsCacheType.PROVIDER
|
|
|
)
|
|
|
|
|
|
- for variable in provider_credential_secret_variables:
|
|
|
- if variable in provider_credentials:
|
|
|
- try:
|
|
|
- provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
|
|
- provider_credentials.get(variable),
|
|
|
- decoding_rsa_key,
|
|
|
- decoding_cipher_rsa
|
|
|
- )
|
|
|
- except ValueError:
|
|
|
- pass
|
|
|
+ # Get cached provider credentials
|
|
|
+ cached_provider_credentials = provider_credentials_cache.get()
|
|
|
|
|
|
- current_using_credentials = provider_credentials
|
|
|
+ if not cached_provider_credentials:
|
|
|
+ try:
|
|
|
+ provider_credentials = json.loads(provider_record.encrypted_config)
|
|
|
+ except JSONDecodeError:
|
|
|
+ provider_credentials = {}
|
|
|
+
|
|
|
+ # Get provider credential secret variables
|
|
|
+ provider_credential_secret_variables = self._extract_secret_variables(
|
|
|
+ provider_entity.provider_credential_schema.credential_form_schemas
|
|
|
+ if provider_entity.provider_credential_schema else []
|
|
|
+ )
|
|
|
+
|
|
|
+ # Get decoding rsa key and cipher for decrypting credentials
|
|
|
+ decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
|
|
+
|
|
|
+ for variable in provider_credential_secret_variables:
|
|
|
+ if variable in provider_credentials:
|
|
|
+ try:
|
|
|
+ provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
|
|
+ provider_credentials.get(variable),
|
|
|
+ decoding_rsa_key,
|
|
|
+ decoding_cipher_rsa
|
|
|
+ )
|
|
|
+ except ValueError:
|
|
|
+ pass
|
|
|
+
|
|
|
+ current_using_credentials = provider_credentials
|
|
|
+
|
|
|
+ # cache provider credentials
|
|
|
+ provider_credentials_cache.set(
|
|
|
+ credentials=current_using_credentials
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ current_using_credentials = cached_provider_credentials
|
|
|
else:
|
|
|
current_using_credentials = {}
|
|
|
|