test_text_embedding.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import os
  2. import pytest
  3. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  4. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  5. from core.model_runtime.model_providers.openai.text_embedding.text_embedding import OpenAITextEmbeddingModel
  6. from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
  7. @pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
  8. def test_validate_credentials(setup_openai_mock):
  9. model = OpenAITextEmbeddingModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(
  12. model='text-embedding-ada-002',
  13. credentials={
  14. 'openai_api_key': 'invalid_key'
  15. }
  16. )
  17. model.validate_credentials(
  18. model='text-embedding-ada-002',
  19. credentials={
  20. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  21. }
  22. )
  23. @pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
  24. def test_invoke_model(setup_openai_mock):
  25. model = OpenAITextEmbeddingModel()
  26. result = model.invoke(
  27. model='text-embedding-ada-002',
  28. credentials={
  29. 'openai_api_key': os.environ.get('OPENAI_API_KEY'),
  30. 'openai_api_base': 'https://api.openai.com'
  31. },
  32. texts=[
  33. "hello",
  34. "world"
  35. ],
  36. user="abc-123"
  37. )
  38. assert isinstance(result, TextEmbeddingResult)
  39. assert len(result.embeddings) == 2
  40. assert result.usage.total_tokens == 2
  41. def test_get_num_tokens():
  42. model = OpenAITextEmbeddingModel()
  43. num_tokens = model.get_num_tokens(
  44. model='text-embedding-ada-002',
  45. credentials={
  46. 'openai_api_key': os.environ.get('OPENAI_API_KEY'),
  47. 'openai_api_base': 'https://api.openai.com'
  48. },
  49. texts=[
  50. "hello",
  51. "world"
  52. ]
  53. )
  54. assert num_tokens == 2