"""Integration tests for the Replicate large-language-model provider.

These tests hit the live Replicate API and require the REPLICATE_API_KEY
environment variable to be set.
"""
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
  6. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  7. from core.model_runtime.model_providers.replicate.llm.llm import ReplicateLargeLanguageModel
  8. def test_validate_credentials():
  9. model = ReplicateLargeLanguageModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(
  12. model='meta/llama-2-13b-chat',
  13. credentials={
  14. 'replicate_api_token': 'invalid_key',
  15. 'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
  16. }
  17. )
  18. model.validate_credentials(
  19. model='meta/llama-2-13b-chat',
  20. credentials={
  21. 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
  22. 'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
  23. }
  24. )
  25. def test_invoke_model():
  26. model = ReplicateLargeLanguageModel()
  27. response = model.invoke(
  28. model='meta/llama-2-13b-chat',
  29. credentials={
  30. 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
  31. 'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
  32. },
  33. prompt_messages=[
  34. SystemPromptMessage(
  35. content='You are a helpful AI assistant.',
  36. ),
  37. UserPromptMessage(
  38. content='Who are you?'
  39. )
  40. ],
  41. model_parameters={
  42. 'temperature': 1.0,
  43. 'top_k': 2,
  44. 'top_p': 0.5,
  45. },
  46. stop=['How'],
  47. stream=False,
  48. user="abc-123"
  49. )
  50. assert isinstance(response, LLMResult)
  51. assert len(response.message.content) > 0
  52. def test_invoke_stream_model():
  53. model = ReplicateLargeLanguageModel()
  54. response = model.invoke(
  55. model='mistralai/mixtral-8x7b-instruct-v0.1',
  56. credentials={
  57. 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
  58. 'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e'
  59. },
  60. prompt_messages=[
  61. SystemPromptMessage(
  62. content='You are a helpful AI assistant.',
  63. ),
  64. UserPromptMessage(
  65. content='Who are you?'
  66. )
  67. ],
  68. model_parameters={
  69. 'temperature': 1.0,
  70. 'top_k': 2,
  71. 'top_p': 0.5,
  72. },
  73. stop=['How'],
  74. stream=True,
  75. user="abc-123"
  76. )
  77. assert isinstance(response, Generator)
  78. for chunk in response:
  79. assert isinstance(chunk, LLMResultChunk)
  80. assert isinstance(chunk.delta, LLMResultChunkDelta)
  81. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  82. def test_get_num_tokens():
  83. model = ReplicateLargeLanguageModel()
  84. num_tokens = model.get_num_tokens(
  85. model='',
  86. credentials={
  87. 'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
  88. 'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e'
  89. },
  90. prompt_messages=[
  91. SystemPromptMessage(
  92. content='You are a helpful AI assistant.',
  93. ),
  94. UserPromptMessage(
  95. content='Hello World!'
  96. )
  97. ]
  98. )
  99. assert num_tokens == 14