test_rerank.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import os
  2. import pytest
  3. from core.model_runtime.entities.rerank_entities import RerankResult
  4. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  5. from core.model_runtime.model_providers.cohere.rerank.rerank import CohereRerankModel
  6. def test_validate_credentials():
  7. model = CohereRerankModel()
  8. with pytest.raises(CredentialsValidateFailedError):
  9. model.validate_credentials(
  10. model='rerank-english-v2.0',
  11. credentials={
  12. 'api_key': 'invalid_key'
  13. }
  14. )
  15. model.validate_credentials(
  16. model='rerank-english-v2.0',
  17. credentials={
  18. 'api_key': os.environ.get('COHERE_API_KEY')
  19. }
  20. )
  21. def test_invoke_model():
  22. model = CohereRerankModel()
  23. result = model.invoke(
  24. model='rerank-english-v2.0',
  25. credentials={
  26. 'api_key': os.environ.get('COHERE_API_KEY')
  27. },
  28. query="What is the capital of the United States?",
  29. docs=[
  30. "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
  31. "Census, Carson City had a population of 55,274.",
  32. "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) "
  33. "is the capital of the United States. It is a federal district. The President of the USA and many major "
  34. "national government offices are in the territory. This makes it the political center of the United "
  35. "States of America."
  36. ],
  37. score_threshold=0.8
  38. )
  39. assert isinstance(result, RerankResult)
  40. assert len(result.docs) == 1
  41. assert result.docs[0].index == 1
  42. assert result.docs[0].score >= 0.8