test_embedding.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import os
  2. import pytest
  3. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  4. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  5. from core.model_runtime.model_providers.volcengine_maas.text_embedding.text_embedding import (
  6. VolcengineMaaSTextEmbeddingModel,
  7. )
  8. def test_validate_credentials():
  9. model = VolcengineMaaSTextEmbeddingModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(
  12. model='NOT IMPORTANT',
  13. credentials={
  14. 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
  15. 'volc_region': 'cn-beijing',
  16. 'volc_access_key_id': 'INVALID',
  17. 'volc_secret_access_key': 'INVALID',
  18. 'endpoint_id': 'INVALID',
  19. }
  20. )
  21. model.validate_credentials(
  22. model='NOT IMPORTANT',
  23. credentials={
  24. 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
  25. 'volc_region': 'cn-beijing',
  26. 'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
  27. 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
  28. 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
  29. },
  30. )
  31. def test_invoke_model():
  32. model = VolcengineMaaSTextEmbeddingModel()
  33. result = model.invoke(
  34. model='NOT IMPORTANT',
  35. credentials={
  36. 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
  37. 'volc_region': 'cn-beijing',
  38. 'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
  39. 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
  40. 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
  41. },
  42. texts=[
  43. "hello",
  44. "world"
  45. ],
  46. user="abc-123"
  47. )
  48. assert isinstance(result, TextEmbeddingResult)
  49. assert len(result.embeddings) == 2
  50. assert result.usage.total_tokens > 0
  51. def test_get_num_tokens():
  52. model = VolcengineMaaSTextEmbeddingModel()
  53. num_tokens = model.get_num_tokens(
  54. model='NOT IMPORTANT',
  55. credentials={
  56. 'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
  57. 'volc_region': 'cn-beijing',
  58. 'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
  59. 'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
  60. 'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
  61. },
  62. texts=[
  63. "hello",
  64. "world"
  65. ]
  66. )
  67. assert num_tokens == 2