test_text_embedding.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import os
  2. import pytest
  3. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  4. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  5. from core.model_runtime.model_providers.replicate.text_embedding.text_embedding import ReplicateEmbeddingModel
  6. def test_validate_credentials_one():
  7. model = ReplicateEmbeddingModel()
  8. with pytest.raises(CredentialsValidateFailedError):
  9. model.validate_credentials(
  10. model="replicate/all-mpnet-base-v2",
  11. credentials={
  12. "replicate_api_token": "invalid_key",
  13. "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305",
  14. },
  15. )
  16. model.validate_credentials(
  17. model="replicate/all-mpnet-base-v2",
  18. credentials={
  19. "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
  20. "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305",
  21. },
  22. )
  23. def test_validate_credentials_two():
  24. model = ReplicateEmbeddingModel()
  25. with pytest.raises(CredentialsValidateFailedError):
  26. model.validate_credentials(
  27. model="nateraw/bge-large-en-v1.5",
  28. credentials={
  29. "replicate_api_token": "invalid_key",
  30. "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1",
  31. },
  32. )
  33. model.validate_credentials(
  34. model="nateraw/bge-large-en-v1.5",
  35. credentials={
  36. "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
  37. "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1",
  38. },
  39. )
  40. def test_invoke_model_one():
  41. model = ReplicateEmbeddingModel()
  42. result = model.invoke(
  43. model="nateraw/bge-large-en-v1.5",
  44. credentials={
  45. "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
  46. "model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1",
  47. },
  48. texts=["hello", "world"],
  49. user="abc-123",
  50. )
  51. assert isinstance(result, TextEmbeddingResult)
  52. assert len(result.embeddings) == 2
  53. assert result.usage.total_tokens == 2
  54. def test_invoke_model_two():
  55. model = ReplicateEmbeddingModel()
  56. result = model.invoke(
  57. model="andreasjansson/clip-features",
  58. credentials={
  59. "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
  60. "model_version": "75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a",
  61. },
  62. texts=["hello", "world"],
  63. user="abc-123",
  64. )
  65. assert isinstance(result, TextEmbeddingResult)
  66. assert len(result.embeddings) == 2
  67. assert result.usage.total_tokens == 2
  68. def test_invoke_model_three():
  69. model = ReplicateEmbeddingModel()
  70. result = model.invoke(
  71. model="replicate/all-mpnet-base-v2",
  72. credentials={
  73. "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
  74. "model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305",
  75. },
  76. texts=["hello", "world"],
  77. user="abc-123",
  78. )
  79. assert isinstance(result, TextEmbeddingResult)
  80. assert len(result.embeddings) == 2
  81. assert result.usage.total_tokens == 2
  82. def test_invoke_model_four():
  83. model = ReplicateEmbeddingModel()
  84. result = model.invoke(
  85. model="nateraw/jina-embeddings-v2-base-en",
  86. credentials={
  87. "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
  88. "model_version": "f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e",
  89. },
  90. texts=["hello", "world"],
  91. user="abc-123",
  92. )
  93. assert isinstance(result, TextEmbeddingResult)
  94. assert len(result.embeddings) == 2
  95. assert result.usage.total_tokens == 2
  96. def test_get_num_tokens():
  97. model = ReplicateEmbeddingModel()
  98. num_tokens = model.get_num_tokens(
  99. model="nateraw/jina-embeddings-v2-base-en",
  100. credentials={
  101. "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
  102. "model_version": "f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e",
  103. },
  104. texts=["hello", "world"],
  105. )
  106. assert num_tokens == 2