"""Integration tests for the Bedrock large language model provider.

Requires real AWS credentials in the environment:
AWS_REGION, AWS_ACCESS_KEY, AWS_SECRET_ACCESS_KEY.
"""
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
  6. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  7. from core.model_runtime.model_providers.bedrock.llm.llm import BedrockLargeLanguageModel
  8. def test_validate_credentials():
  9. model = BedrockLargeLanguageModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(model="meta.llama2-13b-chat-v1", credentials={"anthropic_api_key": "invalid_key"})
  12. model.validate_credentials(
  13. model="meta.llama2-13b-chat-v1",
  14. credentials={
  15. "aws_region": os.getenv("AWS_REGION"),
  16. "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
  17. "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
  18. },
  19. )
  20. def test_invoke_model():
  21. model = BedrockLargeLanguageModel()
  22. response = model.invoke(
  23. model="meta.llama2-13b-chat-v1",
  24. credentials={
  25. "aws_region": os.getenv("AWS_REGION"),
  26. "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
  27. "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
  28. },
  29. prompt_messages=[
  30. SystemPromptMessage(
  31. content="You are a helpful AI assistant.",
  32. ),
  33. UserPromptMessage(content="Hello World!"),
  34. ],
  35. model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens_to_sample": 10},
  36. stop=["How"],
  37. stream=False,
  38. user="abc-123",
  39. )
  40. assert isinstance(response, LLMResult)
  41. assert len(response.message.content) > 0
  42. def test_invoke_stream_model():
  43. model = BedrockLargeLanguageModel()
  44. response = model.invoke(
  45. model="meta.llama2-13b-chat-v1",
  46. credentials={
  47. "aws_region": os.getenv("AWS_REGION"),
  48. "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
  49. "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
  50. },
  51. prompt_messages=[
  52. SystemPromptMessage(
  53. content="You are a helpful AI assistant.",
  54. ),
  55. UserPromptMessage(content="Hello World!"),
  56. ],
  57. model_parameters={"temperature": 0.0, "max_tokens_to_sample": 100},
  58. stream=True,
  59. user="abc-123",
  60. )
  61. assert isinstance(response, Generator)
  62. for chunk in response:
  63. print(chunk)
  64. assert isinstance(chunk, LLMResultChunk)
  65. assert isinstance(chunk.delta, LLMResultChunkDelta)
  66. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  67. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  68. def test_get_num_tokens():
  69. model = BedrockLargeLanguageModel()
  70. num_tokens = model.get_num_tokens(
  71. model="meta.llama2-13b-chat-v1",
  72. credentials={
  73. "aws_region": os.getenv("AWS_REGION"),
  74. "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
  75. "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
  76. },
  77. messages=[
  78. SystemPromptMessage(
  79. content="You are a helpful AI assistant.",
  80. ),
  81. UserPromptMessage(content="Hello World!"),
  82. ],
  83. )
  84. assert num_tokens == 18