test_llm.py

import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.openllm.llm.llm import OpenLLMLargeLanguageModel
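
# NOTE: These are integration tests that talk to a live OpenLLM server. They
# assume the OPENLLM_SERVER_URL environment variable points at a reachable
# instance; without it, both the credential check and the invocations below
# will fail.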


def test_validate_credentials_for_chat_model():
    """An invalid server URL should raise; a real OPENLLM_SERVER_URL should validate."""
    model = OpenLLMLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="NOT IMPORTANT",
            credentials={
                "server_url": "invalid_key",
            },
        )

    model.validate_credentials(
        model="NOT IMPORTANT",
        credentials={
            "server_url": os.environ.get("OPENLLM_SERVER_URL"),
        },
    )


def test_invoke_model():
    """A blocking (stream=False) invocation should return a complete LLMResult."""
    model = OpenLLMLargeLanguageModel()

    response = model.invoke(
        model="NOT IMPORTANT",
        credentials={
            "server_url": os.environ.get("OPENLLM_SERVER_URL"),
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={
            "temperature": 0.7,
            "top_p": 1.0,
            "top_k": 1,
        },
        stop=["you"],
        user="abc-123",
        stream=False,
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0


def test_invoke_stream_model():
    """A streaming (stream=True) invocation should yield LLMResultChunk objects."""
    model = OpenLLMLargeLanguageModel()

    response = model.invoke(
        model="NOT IMPORTANT",
        credentials={
            "server_url": os.environ.get("OPENLLM_SERVER_URL"),
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={
            "temperature": 0.7,
            "top_p": 1.0,
            "top_k": 1,
        },
        stop=["you"],
        stream=True,
        user="abc-123",
    )

    assert isinstance(response, Generator)
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # Every chunk except the final one (which carries a finish_reason)
        # should contain non-empty content.
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0


def test_get_num_tokens():
    """Token counting should return an int; "Hello World!" is expected to count as 3 tokens."""
    model = OpenLLMLargeLanguageModel()

    response = model.get_num_tokens(
        model="NOT IMPORTANT",
        credentials={
            "server_url": os.environ.get("OPENLLM_SERVER_URL"),
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        tools=[],
    )

    assert isinstance(response, int)
    assert response == 3
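

# To run these tests against a local OpenLLM instance (the URL below is only an
# example; point it at wherever your server actually listens):
#   OPENLLM_SERVER_URL=http://localhost:3000 pytest test_llm.py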