瀏覽代碼

fix: model provider credentials null value validate failed (#2009)

takatost 1 年之前
父節點
當前提交
1779cea6e3

+ 2 - 13
api/core/entities/provider_configuration.py

@@ -165,7 +165,7 @@ class ProviderConfiguration(BaseModel):
                     if value == '[__HIDDEN__]' and key in original_credentials:
                         credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
 
-        model_provider_factory.provider_credentials_validate(
+        credentials = model_provider_factory.provider_credentials_validate(
             self.provider.provider,
             credentials
         )
@@ -308,24 +308,13 @@ class ProviderConfiguration(BaseModel):
                     if value == '[__HIDDEN__]' and key in original_credentials:
                         credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])
 
-        model_provider_factory.model_credentials_validate(
+        credentials = model_provider_factory.model_credentials_validate(
             provider=self.provider.provider,
             model_type=model_type,
             model=model,
             credentials=credentials
         )
 
-        model_schema = (
-            model_provider_factory.get_provider_instance(self.provider.provider)
-            .get_model_instance(model_type)._get_customizable_model_schema(
-                model=model,
-                credentials=credentials
-            )
-        )
-
-        if model_schema:
-            credentials['schema'] = json.dumps(encoders.jsonable_encoder(model_schema))
-
         for key, value in credentials.items():
             if key in provider_credential_secret_variables:
                 credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

+ 10 - 6
api/core/model_runtime/model_providers/model_provider_factory.py

@@ -61,7 +61,7 @@ class ModelProviderFactory:
         # return providers
         return providers
 
-    def provider_credentials_validate(self, provider: str, credentials: dict) -> None:
+    def provider_credentials_validate(self, provider: str, credentials: dict) -> dict:
         """
         Validate provider credentials
 
@@ -80,13 +80,15 @@ class ModelProviderFactory:
 
         # validate provider credential schema
         validator = ProviderCredentialSchemaValidator(provider_credential_schema)
-        validator.validate_and_filter(credentials)
+        filtered_credentials = validator.validate_and_filter(credentials)
 
         # validate the credentials, raise exception if validation failed
-        model_provider_instance.validate_provider_credentials(credentials)
+        model_provider_instance.validate_provider_credentials(filtered_credentials)
+
+        return filtered_credentials
 
     def model_credentials_validate(self, provider: str, model_type: ModelType,
-                                   model: str, credentials: dict) -> None:
+                                   model: str, credentials: dict) -> dict:
         """
         Validate model credentials
 
@@ -107,13 +109,15 @@ class ModelProviderFactory:
 
         # validate model credential schema
         validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
-        validator.validate_and_filter(credentials)
+        filtered_credentials = validator.validate_and_filter(credentials)
 
         # get model instance of the model type
         model_instance = model_provider_instance.get_model_instance(model_type)
 
         # call validate_credentials method of model type to validate credentials, raise exception if validation failed
-        model_instance.validate_credentials(model, credentials)
+        model_instance.validate_credentials(model, filtered_credentials)
+
+        return filtered_credentials
 
     def get_models(self,
                    provider: Optional[str] = None,

+ 1 - 1
api/core/model_runtime/schema_validators/common_validator.py

@@ -46,7 +46,7 @@ class CommonValidator:
         :return: validated credential form schema value
         """
         #  If the variable does not exist in credentials
-        if credential_form_schema.variable not in credentials:
+        if credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]:
             # If required is True, an exception is thrown
             if credential_form_schema.required:
                 raise ValueError(f'Variable {credential_form_schema.variable} is required')