# test_embedding.py
# Tests for MinimaxTextEmbeddingModel; credentials are read from the environment.

import os

import pytest

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.minimax.text_embedding.text_embedding import MinimaxTextEmbeddingModel


  6. def test_validate_credentials():
  7. model = MinimaxTextEmbeddingModel()
  8. with pytest.raises(CredentialsValidateFailedError):
  9. model.validate_credentials(
  10. model='embo-01',
  11. credentials={
  12. 'minimax_api_key': 'invalid_key',
  13. 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
  14. }
  15. )
  16. model.validate_credentials(
  17. model='embo-01',
  18. credentials={
  19. 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
  20. 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
  21. }
  22. )
  23. def test_invoke_model():
  24. model = MinimaxTextEmbeddingModel()
  25. result = model.invoke(
  26. model='embo-01',
  27. credentials={
  28. 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
  29. 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
  30. },
  31. texts=[
  32. "hello",
  33. "world"
  34. ],
  35. user="abc-123"
  36. )
  37. assert isinstance(result, TextEmbeddingResult)
  38. assert len(result.embeddings) == 2
  39. assert result.usage.total_tokens == 16
  40. def test_get_num_tokens():
  41. model = MinimaxTextEmbeddingModel()
  42. num_tokens = model.get_num_tokens(
  43. model='embo-01',
  44. credentials={
  45. 'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
  46. 'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
  47. },
  48. texts=[
  49. "hello",
  50. "world"
  51. ]
  52. )
  53. assert num_tokens == 2