|
@@ -61,7 +61,7 @@ class ModelProviderFactory:
|
|
|
# return providers
|
|
|
return providers
|
|
|
|
|
|
- def provider_credentials_validate(self, provider: str, credentials: dict) -> None:
|
|
|
+ def provider_credentials_validate(self, provider: str, credentials: dict) -> dict:
|
|
|
"""
|
|
|
Validate provider credentials
|
|
|
|
|
@@ -80,13 +80,15 @@ class ModelProviderFactory:
|
|
|
|
|
|
# validate provider credential schema
|
|
|
validator = ProviderCredentialSchemaValidator(provider_credential_schema)
|
|
|
- validator.validate_and_filter(credentials)
|
|
|
+ filtered_credentials = validator.validate_and_filter(credentials)
|
|
|
|
|
|
# validate the credentials, raise exception if validation failed
|
|
|
- model_provider_instance.validate_provider_credentials(credentials)
|
|
|
+ model_provider_instance.validate_provider_credentials(filtered_credentials)
|
|
|
+
|
|
|
+ return filtered_credentials
|
|
|
|
|
|
def model_credentials_validate(self, provider: str, model_type: ModelType,
|
|
|
- model: str, credentials: dict) -> None:
|
|
|
+ model: str, credentials: dict) -> dict:
|
|
|
"""
|
|
|
Validate model credentials
|
|
|
|
|
@@ -107,13 +109,15 @@ class ModelProviderFactory:
|
|
|
|
|
|
# validate model credential schema
|
|
|
validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
|
|
|
- validator.validate_and_filter(credentials)
|
|
|
+ filtered_credentials = validator.validate_and_filter(credentials)
|
|
|
|
|
|
# get model instance of the model type
|
|
|
model_instance = model_provider_instance.get_model_instance(model_type)
|
|
|
|
|
|
# call validate_credentials method of model type to validate credentials, raise exception if validation failed
|
|
|
- model_instance.validate_credentials(model, credentials)
|
|
|
+ model_instance.validate_credentials(model, filtered_credentials)
|
|
|
+
|
|
|
+ return filtered_credentials
|
|
|
|
|
|
def get_models(self,
|
|
|
provider: Optional[str] = None,
|