test_replicate_provider.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. import pytest
  2. from unittest.mock import patch, MagicMock
  3. import json
  4. from core.model_providers.models.entity.model_params import ModelType
  5. from core.model_providers.providers.base import CredentialsValidateFailedError
  6. from core.model_providers.providers.replicate_provider import ReplicateProvider
  7. from models.provider import ProviderType, Provider, ProviderModel
  8. PROVIDER_NAME = 'replicate'
  9. MODEL_PROVIDER_CLASS = ReplicateProvider
  10. VALIDATE_CREDENTIAL = {
  11. 'model_version': 'fake-version',
  12. 'replicate_api_token': 'valid_key'
  13. }
  14. def encrypt_side_effect(tenant_id, encrypt_key):
  15. return f'encrypted_{encrypt_key}'
  16. def decrypt_side_effect(tenant_id, encrypted_key):
  17. return encrypted_key.replace('encrypted_', '')
  18. def test_is_credentials_valid_or_raise_valid(mocker):
  19. mock_query = MagicMock()
  20. mock_query.return_value = None
  21. mocker.patch('replicate.model.ModelCollection.get', return_value=mock_query)
  22. mocker.patch('replicate.model.Model.versions', return_value=mock_query)
  23. MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
  24. model_name='test_model_name',
  25. model_type=ModelType.TEXT_GENERATION,
  26. credentials=VALIDATE_CREDENTIAL.copy()
  27. )
  28. def test_is_credentials_valid_or_raise_invalid():
  29. # raise CredentialsValidateFailedError if replicate_api_token is not in credentials
  30. with pytest.raises(CredentialsValidateFailedError):
  31. MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
  32. model_name='test_model_name',
  33. model_type=ModelType.TEXT_GENERATION,
  34. credentials={}
  35. )
  36. # raise CredentialsValidateFailedError if replicate_api_token is invalid
  37. with pytest.raises(CredentialsValidateFailedError):
  38. MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
  39. model_name='test_model_name',
  40. model_type=ModelType.TEXT_GENERATION,
  41. credentials={'replicate_api_token': 'invalid_key'})
  42. @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
  43. def test_encrypt_model_credentials(mock_encrypt):
  44. api_key = 'valid_key'
  45. result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
  46. tenant_id='tenant_id',
  47. model_name='test_model_name',
  48. model_type=ModelType.TEXT_GENERATION,
  49. credentials=VALIDATE_CREDENTIAL.copy()
  50. )
  51. mock_encrypt.assert_called_with('tenant_id', api_key)
  52. assert result['replicate_api_token'] == f'encrypted_{api_key}'
  53. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  54. def test_get_model_credentials_custom(mock_decrypt, mocker):
  55. provider = Provider(
  56. id='provider_id',
  57. tenant_id='tenant_id',
  58. provider_name=PROVIDER_NAME,
  59. provider_type=ProviderType.CUSTOM.value,
  60. encrypted_config=None,
  61. is_valid=True,
  62. )
  63. encrypted_credential = VALIDATE_CREDENTIAL.copy()
  64. encrypted_credential['replicate_api_token'] = 'encrypted_' + encrypted_credential['replicate_api_token']
  65. mock_query = MagicMock()
  66. mock_query.filter.return_value.first.return_value = ProviderModel(
  67. encrypted_config=json.dumps(encrypted_credential)
  68. )
  69. mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
  70. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  71. result = model_provider.get_model_credentials(
  72. model_name='test_model_name',
  73. model_type=ModelType.TEXT_GENERATION
  74. )
  75. assert result['replicate_api_token'] == 'valid_key'
  76. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  77. def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
  78. provider = Provider(
  79. id='provider_id',
  80. tenant_id='tenant_id',
  81. provider_name=PROVIDER_NAME,
  82. provider_type=ProviderType.CUSTOM.value,
  83. encrypted_config=None,
  84. is_valid=True,
  85. )
  86. encrypted_credential = VALIDATE_CREDENTIAL.copy()
  87. encrypted_credential['replicate_api_token'] = 'encrypted_' + encrypted_credential['replicate_api_token']
  88. mock_query = MagicMock()
  89. mock_query.filter.return_value.first.return_value = ProviderModel(
  90. encrypted_config=json.dumps(encrypted_credential)
  91. )
  92. mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
  93. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  94. result = model_provider.get_model_credentials(
  95. model_name='test_model_name',
  96. model_type=ModelType.TEXT_GENERATION,
  97. obfuscated=True
  98. )
  99. middle_token = result['replicate_api_token'][6:-2]
  100. assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['replicate_api_token']) - 8, 0)
  101. assert all(char == '*' for char in middle_token)