# test_llm.py — integration tests for the Replicate large-language-model provider.
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
  6. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  7. from core.model_runtime.model_providers.replicate.llm.llm import ReplicateLargeLanguageModel
  8. def test_validate_credentials():
  9. model = ReplicateLargeLanguageModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(
  12. model="meta/llama-2-13b-chat",
  13. credentials={
  14. "replicate_api_token": "invalid_key",
  15. "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d",
  16. },
  17. )
  18. model.validate_credentials(
  19. model="meta/llama-2-13b-chat",
  20. credentials={
  21. "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
  22. "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d",
  23. },
  24. )
  25. def test_invoke_model():
  26. model = ReplicateLargeLanguageModel()
  27. response = model.invoke(
  28. model="meta/llama-2-13b-chat",
  29. credentials={
  30. "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
  31. "model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d",
  32. },
  33. prompt_messages=[
  34. SystemPromptMessage(
  35. content="You are a helpful AI assistant.",
  36. ),
  37. UserPromptMessage(content="Who are you?"),
  38. ],
  39. model_parameters={
  40. "temperature": 1.0,
  41. "top_k": 2,
  42. "top_p": 0.5,
  43. },
  44. stop=["How"],
  45. stream=False,
  46. user="abc-123",
  47. )
  48. assert isinstance(response, LLMResult)
  49. assert len(response.message.content) > 0
  50. def test_invoke_stream_model():
  51. model = ReplicateLargeLanguageModel()
  52. response = model.invoke(
  53. model="mistralai/mixtral-8x7b-instruct-v0.1",
  54. credentials={
  55. "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
  56. "model_version": "2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e",
  57. },
  58. prompt_messages=[
  59. SystemPromptMessage(
  60. content="You are a helpful AI assistant.",
  61. ),
  62. UserPromptMessage(content="Who are you?"),
  63. ],
  64. model_parameters={
  65. "temperature": 1.0,
  66. "top_k": 2,
  67. "top_p": 0.5,
  68. },
  69. stop=["How"],
  70. stream=True,
  71. user="abc-123",
  72. )
  73. assert isinstance(response, Generator)
  74. for chunk in response:
  75. assert isinstance(chunk, LLMResultChunk)
  76. assert isinstance(chunk.delta, LLMResultChunkDelta)
  77. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  78. def test_get_num_tokens():
  79. model = ReplicateLargeLanguageModel()
  80. num_tokens = model.get_num_tokens(
  81. model="",
  82. credentials={
  83. "replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
  84. "model_version": "2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e",
  85. },
  86. prompt_messages=[
  87. SystemPromptMessage(
  88. content="You are a helpful AI assistant.",
  89. ),
  90. UserPromptMessage(content="Hello World!"),
  91. ],
  92. )
  93. assert num_tokens == 14