test_chatglm_provider.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import pytest
  2. from unittest.mock import patch
  3. import json
  4. import requests
  5. from langchain.schema import LLMResult, Generation, AIMessage, ChatResult, ChatGeneration
  6. from requests import Response
  7. from core.model_providers.providers.base import CredentialsValidateFailedError
  8. from core.model_providers.providers.chatglm_provider import ChatGLMProvider
  9. from core.model_providers.providers.spark_provider import SparkProvider
  10. from models.provider import ProviderType, Provider
  11. PROVIDER_NAME = 'chatglm'
  12. MODEL_PROVIDER_CLASS = ChatGLMProvider
  13. VALIDATE_CREDENTIAL = {
  14. 'api_base': 'valid_api_base',
  15. }
  16. def encrypt_side_effect(tenant_id, encrypt_key):
  17. return f'encrypted_{encrypt_key}'
  18. def decrypt_side_effect(tenant_id, encrypted_key):
  19. return encrypted_key.replace('encrypted_', '')
  20. def test_is_provider_credentials_valid_or_raise_valid(mocker):
  21. mock_response = Response()
  22. mock_response.status_code = 200
  23. mock_response._content = json.dumps({'models': []}).encode('utf-8')
  24. mocker.patch('requests.get',
  25. return_value=mock_response)
  26. MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
  27. def test_is_provider_credentials_valid_or_raise_invalid():
  28. # raise CredentialsValidateFailedError if api_key is not in credentials
  29. with pytest.raises(CredentialsValidateFailedError):
  30. MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
  31. credential = VALIDATE_CREDENTIAL.copy()
  32. credential['api_base'] = 'invalid_api_base'
  33. # raise CredentialsValidateFailedError if api_key is invalid
  34. with pytest.raises(CredentialsValidateFailedError):
  35. MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
  36. @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
  37. def test_encrypt_credentials(mock_encrypt):
  38. result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
  39. assert result['api_base'] == f'encrypted_{VALIDATE_CREDENTIAL["api_base"]}'
  40. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  41. def test_get_credentials_custom(mock_decrypt):
  42. encrypted_credential = VALIDATE_CREDENTIAL.copy()
  43. encrypted_credential['api_base'] = 'encrypted_' + encrypted_credential['api_base']
  44. provider = Provider(
  45. id='provider_id',
  46. tenant_id='tenant_id',
  47. provider_name=PROVIDER_NAME,
  48. provider_type=ProviderType.CUSTOM.value,
  49. encrypted_config=json.dumps(encrypted_credential),
  50. is_valid=True,
  51. )
  52. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  53. result = model_provider.get_provider_credentials()
  54. assert result['api_base'] == 'valid_api_base'
  55. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  56. def test_get_credentials_obfuscated(mock_decrypt):
  57. encrypted_credential = VALIDATE_CREDENTIAL.copy()
  58. encrypted_credential['api_base'] = 'encrypted_' + encrypted_credential['api_base']
  59. provider = Provider(
  60. id='provider_id',
  61. tenant_id='tenant_id',
  62. provider_name=PROVIDER_NAME,
  63. provider_type=ProviderType.CUSTOM.value,
  64. encrypted_config=json.dumps(encrypted_credential),
  65. is_valid=True,
  66. )
  67. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  68. result = model_provider.get_provider_credentials(obfuscated=True)
  69. middle_token = result['api_base'][6:-2]
  70. assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_base']) - 8, 0)
  71. assert all(char == '*' for char in middle_token)