# test_text_embedding.py
  1. import os
  2. import pytest
  3. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  4. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  5. from core.model_runtime.model_providers.openai.text_embedding.text_embedding import OpenAITextEmbeddingModel
  6. from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
  7. @pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
  8. def test_validate_credentials(setup_openai_mock):
  9. model = OpenAITextEmbeddingModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(
  12. model='text-embedding-ada-002',
  13. credentials={
  14. 'openai_api_key': 'invalid_key'
  15. }
  16. )
  17. model.validate_credentials(
  18. model='text-embedding-ada-002',
  19. credentials={
  20. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  21. }
  22. )
  23. @pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
  24. def test_invoke_model(setup_openai_mock):
  25. model = OpenAITextEmbeddingModel()
  26. result = model.invoke(
  27. model='text-embedding-ada-002',
  28. credentials={
  29. 'openai_api_key': os.environ.get('OPENAI_API_KEY'),
  30. 'openai_api_base': 'https://api.openai.com'
  31. },
  32. texts=[
  33. "hello",
  34. "world",
  35. " ".join(["long_text"] * 100),
  36. " ".join(["another_long_text"] * 100)
  37. ],
  38. user="abc-123"
  39. )
  40. assert isinstance(result, TextEmbeddingResult)
  41. assert len(result.embeddings) == 4
  42. assert result.usage.total_tokens == 2
  43. def test_get_num_tokens():
  44. model = OpenAITextEmbeddingModel()
  45. num_tokens = model.get_num_tokens(
  46. model='text-embedding-ada-002',
  47. credentials={
  48. 'openai_api_key': os.environ.get('OPENAI_API_KEY'),
  49. 'openai_api_base': 'https://api.openai.com'
  50. },
  51. texts=[
  52. "hello",
  53. "world"
  54. ]
  55. )
  56. assert num_tokens == 2