test_openai_provider.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import pytest
  2. from unittest.mock import patch, MagicMock
  3. import json
  4. from openai.error import AuthenticationError
  5. from core.model_providers.providers.base import CredentialsValidateFailedError
  6. from core.model_providers.providers.openai_provider import OpenAIProvider
  7. from models.provider import ProviderType, Provider
  8. PROVIDER_NAME = 'openai'
  9. MODEL_PROVIDER_CLASS = OpenAIProvider
  10. VALIDATE_CREDENTIAL_KEY = 'openai_api_key'
  11. def moderation_side_effect(*args, **kwargs):
  12. if kwargs['api_key'] == 'valid_key':
  13. mock_instance = MagicMock()
  14. mock_instance.request = MagicMock()
  15. return mock_instance, {}
  16. else:
  17. raise AuthenticationError('Invalid credentials')
  18. def encrypt_side_effect(tenant_id, encrypt_key):
  19. return f'encrypted_{encrypt_key}'
  20. def decrypt_side_effect(tenant_id, encrypted_key):
  21. return encrypted_key.replace('encrypted_', '')
  22. @patch('openai.ChatCompletion.create', side_effect=moderation_side_effect)
  23. def test_is_provider_credentials_valid_or_raise_valid(mock_create):
  24. # assert True if api_key is valid
  25. credentials = {VALIDATE_CREDENTIAL_KEY: 'valid_key'}
  26. assert MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credentials) is None
  27. @patch('openai.ChatCompletion.create', side_effect=moderation_side_effect)
  28. def test_is_provider_credentials_valid_or_raise_invalid(mock_create):
  29. # raise CredentialsValidateFailedError if api_key is not in credentials
  30. with pytest.raises(CredentialsValidateFailedError):
  31. MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
  32. # raise CredentialsValidateFailedError if api_key is invalid
  33. with pytest.raises(CredentialsValidateFailedError):
  34. MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({VALIDATE_CREDENTIAL_KEY: 'invalid_key'})
  35. @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
  36. def test_encrypt_credentials(mock_encrypt):
  37. api_key = 'valid_key'
  38. result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', {VALIDATE_CREDENTIAL_KEY: api_key})
  39. mock_encrypt.assert_called_with('tenant_id', api_key)
  40. assert result[VALIDATE_CREDENTIAL_KEY] == f'encrypted_{api_key}'
  41. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  42. def test_get_credentials_custom(mock_decrypt):
  43. provider = Provider(
  44. id='provider_id',
  45. tenant_id='tenant_id',
  46. provider_name=PROVIDER_NAME,
  47. provider_type=ProviderType.CUSTOM.value,
  48. encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: 'encrypted_valid_key'}),
  49. is_valid=True,
  50. )
  51. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  52. result = model_provider.get_provider_credentials()
  53. assert result[VALIDATE_CREDENTIAL_KEY] == 'valid_key'
  54. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  55. def test_get_credentials_custom_str(mock_decrypt):
  56. """
  57. Only the OpenAI provider needs to be compatible with the previous case where the encrypted_config was stored as a plain string.
  58. :param mock_decrypt:
  59. :return:
  60. """
  61. provider = Provider(
  62. id='provider_id',
  63. tenant_id='tenant_id',
  64. provider_name=PROVIDER_NAME,
  65. provider_type=ProviderType.CUSTOM.value,
  66. encrypted_config='encrypted_valid_key',
  67. is_valid=True,
  68. )
  69. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  70. result = model_provider.get_provider_credentials()
  71. assert result[VALIDATE_CREDENTIAL_KEY] == 'valid_key'
  72. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  73. def test_get_credentials_obfuscated(mock_decrypt):
  74. openai_api_key = 'valid_key'
  75. provider = Provider(
  76. id='provider_id',
  77. tenant_id='tenant_id',
  78. provider_name=PROVIDER_NAME,
  79. provider_type=ProviderType.CUSTOM.value,
  80. encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: f'encrypted_{openai_api_key}'}),
  81. is_valid=True,
  82. )
  83. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  84. result = model_provider.get_provider_credentials(obfuscated=True)
  85. middle_token = result[VALIDATE_CREDENTIAL_KEY][6:-2]
  86. assert len(middle_token) == max(len(openai_api_key) - 8, 0)
  87. assert all(char == '*' for char in middle_token)
  88. @patch('core.model_providers.providers.hosted.hosted_model_providers.openai')
  89. def test_get_credentials_hosted(mock_hosted):
  90. provider = Provider(
  91. id='provider_id',
  92. tenant_id='tenant_id',
  93. provider_name=PROVIDER_NAME,
  94. provider_type=ProviderType.SYSTEM.value,
  95. encrypted_config='',
  96. is_valid=True
  97. )
  98. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  99. mock_hosted.api_key = 'hosted_key'
  100. result = model_provider.get_provider_credentials()
  101. assert result[VALIDATE_CREDENTIAL_KEY] == 'hosted_key'