# test_text_embedding.py
import os

import pytest

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.replicate.text_embedding.text_embedding import ReplicateEmbeddingModel
  6. def test_validate_credentials_one():
  7. model = ReplicateEmbeddingModel()
  8. with pytest.raises(CredentialsValidateFailedError):
  9. model.validate_credentials(
  10. model='replicate/all-mpnet-base-v2',
  11. credentials={
  12. 'replicate_api_token': 'invalid_key',
  13. 'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
  14. }
  15. )
  16. model.validate_credentials(
  17. model='replicate/all-mpnet-base-v2',
  18. credentials={
  19. 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
  20. 'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
  21. }
  22. )
  23. def test_validate_credentials_two():
  24. model = ReplicateEmbeddingModel()
  25. with pytest.raises(CredentialsValidateFailedError):
  26. model.validate_credentials(
  27. model='nateraw/bge-large-en-v1.5',
  28. credentials={
  29. 'replicate_api_token': 'invalid_key',
  30. 'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1'
  31. }
  32. )
  33. model.validate_credentials(
  34. model='nateraw/bge-large-en-v1.5',
  35. credentials={
  36. 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
  37. 'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1'
  38. }
  39. )
  40. def test_invoke_model_one():
  41. model = ReplicateEmbeddingModel()
  42. result = model.invoke(
  43. model='nateraw/bge-large-en-v1.5',
  44. credentials={
  45. 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
  46. 'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1'
  47. },
  48. texts=[
  49. "hello",
  50. "world"
  51. ],
  52. user="abc-123"
  53. )
  54. assert isinstance(result, TextEmbeddingResult)
  55. assert len(result.embeddings) == 2
  56. assert result.usage.total_tokens == 2
  57. def test_invoke_model_two():
  58. model = ReplicateEmbeddingModel()
  59. result = model.invoke(
  60. model='andreasjansson/clip-features',
  61. credentials={
  62. 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
  63. 'model_version': '75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a'
  64. },
  65. texts=[
  66. "hello",
  67. "world"
  68. ],
  69. user="abc-123"
  70. )
  71. assert isinstance(result, TextEmbeddingResult)
  72. assert len(result.embeddings) == 2
  73. assert result.usage.total_tokens == 2
  74. def test_invoke_model_three():
  75. model = ReplicateEmbeddingModel()
  76. result = model.invoke(
  77. model='replicate/all-mpnet-base-v2',
  78. credentials={
  79. 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
  80. 'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
  81. },
  82. texts=[
  83. "hello",
  84. "world"
  85. ],
  86. user="abc-123"
  87. )
  88. assert isinstance(result, TextEmbeddingResult)
  89. assert len(result.embeddings) == 2
  90. assert result.usage.total_tokens == 2
  91. def test_invoke_model_four():
  92. model = ReplicateEmbeddingModel()
  93. result = model.invoke(
  94. model='nateraw/jina-embeddings-v2-base-en',
  95. credentials={
  96. 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
  97. 'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e'
  98. },
  99. texts=[
  100. "hello",
  101. "world"
  102. ],
  103. user="abc-123"
  104. )
  105. assert isinstance(result, TextEmbeddingResult)
  106. assert len(result.embeddings) == 2
  107. assert result.usage.total_tokens == 2
  108. def test_get_num_tokens():
  109. model = ReplicateEmbeddingModel()
  110. num_tokens = model.get_num_tokens(
  111. model='nateraw/jina-embeddings-v2-base-en',
  112. credentials={
  113. 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
  114. 'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e'
  115. },
  116. texts=[
  117. "hello",
  118. "world"
  119. ]
  120. )
  121. assert num_tokens == 2