# Integration tests for the Hunyuan large-language-model provider.
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
  6. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  7. from core.model_runtime.model_providers.hunyuan.llm.llm import HunyuanLargeLanguageModel
  8. def test_validate_credentials():
  9. model = HunyuanLargeLanguageModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(
  12. model="hunyuan-standard", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}
  13. )
  14. model.validate_credentials(
  15. model="hunyuan-standard",
  16. credentials={
  17. "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
  18. "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
  19. },
  20. )
  21. def test_invoke_model():
  22. model = HunyuanLargeLanguageModel()
  23. response = model.invoke(
  24. model="hunyuan-standard",
  25. credentials={
  26. "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
  27. "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
  28. },
  29. prompt_messages=[UserPromptMessage(content="Hi")],
  30. model_parameters={"temperature": 0.5, "max_tokens": 10},
  31. stop=["How"],
  32. stream=False,
  33. user="abc-123",
  34. )
  35. assert isinstance(response, LLMResult)
  36. assert len(response.message.content) > 0
  37. def test_invoke_stream_model():
  38. model = HunyuanLargeLanguageModel()
  39. response = model.invoke(
  40. model="hunyuan-standard",
  41. credentials={
  42. "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
  43. "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
  44. },
  45. prompt_messages=[UserPromptMessage(content="Hi")],
  46. model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
  47. stream=True,
  48. user="abc-123",
  49. )
  50. assert isinstance(response, Generator)
  51. for chunk in response:
  52. assert isinstance(chunk, LLMResultChunk)
  53. assert isinstance(chunk.delta, LLMResultChunkDelta)
  54. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  55. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  56. def test_get_num_tokens():
  57. model = HunyuanLargeLanguageModel()
  58. num_tokens = model.get_num_tokens(
  59. model="hunyuan-standard",
  60. credentials={
  61. "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
  62. "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
  63. },
  64. prompt_messages=[
  65. SystemPromptMessage(
  66. content="You are a helpful AI assistant.",
  67. ),
  68. UserPromptMessage(content="Hello World!"),
  69. ],
  70. )
  71. assert num_tokens == 14