test_llm.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import (
  6. AssistantPromptMessage,
  7. PromptMessageTool,
  8. SystemPromptMessage,
  9. UserPromptMessage,
  10. )
  11. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  12. from core.model_runtime.model_providers.zhipuai.llm.llm import ZhipuAILargeLanguageModel
  def test_validate_credentials():
      """Check credential validation with a bogus key and with the configured key.

      A made-up API key is expected to raise ``CredentialsValidateFailedError``.
      The key read from the ``ZHIPUAI_API_KEY`` environment variable is then
      validated outside the ``raises`` block, so any exception fails the test.
      """
      llm = ZhipuAILargeLanguageModel()

      rejected_credentials = {'api_key': 'invalid_key'}
      with pytest.raises(CredentialsValidateFailedError):
          llm.validate_credentials(
              model='chatglm_turbo',
              credentials=rejected_credentials,
          )

      configured_credentials = {'api_key': os.environ.get('ZHIPUAI_API_KEY')}
      llm.validate_credentials(
          model='chatglm_turbo',
          credentials=configured_credentials,
      )
  def test_invoke_model():
      """Invoke the model with ``stream=False`` and check the returned result.

      The call must return an ``LLMResult`` whose message content is non-empty.
      """
      llm = ZhipuAILargeLanguageModel()
      credentials = {'api_key': os.environ.get('ZHIPUAI_API_KEY')}
      question = UserPromptMessage(content='Who are you?')
      sampling = {'temperature': 0.9, 'top_p': 0.7}

      result = llm.invoke(
          model='chatglm_turbo',
          credentials=credentials,
          prompt_messages=[question],
          model_parameters=sampling,
          stop=['How'],
          stream=False,
          user="abc-123",
      )

      assert isinstance(result, LLMResult)
      assert len(result.message.content) > 0
  def test_invoke_stream_model():
      """Invoke the model with ``stream=True`` and check every yielded chunk.

      The call must return a generator. Each chunk must be an ``LLMResultChunk``
      with an ``LLMResultChunkDelta`` wrapping an ``AssistantPromptMessage``.
      """
      llm = ZhipuAILargeLanguageModel()
      credentials = {'api_key': os.environ.get('ZHIPUAI_API_KEY')}
      greeting = UserPromptMessage(content='Hello World!')
      sampling = {'temperature': 0.9, 'top_p': 0.7}

      stream = llm.invoke(
          model='chatglm_turbo',
          credentials=credentials,
          prompt_messages=[greeting],
          model_parameters=sampling,
          stream=True,
          user="abc-123",
      )

      assert isinstance(stream, Generator)

      for piece in stream:
          assert isinstance(piece, LLMResultChunk)
          assert isinstance(piece.delta, LLMResultChunkDelta)
          assert isinstance(piece.delta.message, AssistantPromptMessage)
          # Content is only required to be non-empty while no finish reason is set.
          if piece.delta.finish_reason is None:
              assert len(piece.delta.message.content) > 0
  def test_get_num_tokens():
      """Check the token count reported for a system prompt plus a user prompt.

      The expected count for this fixed pair of messages is 14.
      """
      llm = ZhipuAILargeLanguageModel()
      credentials = {'api_key': os.environ.get('ZHIPUAI_API_KEY')}
      conversation = [
          SystemPromptMessage(content='You are a helpful AI assistant.'),
          UserPromptMessage(content='Hello World!'),
      ]

      token_count = llm.get_num_tokens(
          model='chatglm_turbo',
          credentials=credentials,
          prompt_messages=conversation,
      )

      assert token_count == 14
  def test_get_tools_num_tokens():
      """Check the token count when a tool definition accompanies the prompts.

      Uses the same system and user prompts together with a single
      ``get_current_weather`` tool; the expected count is 108.
      """
      llm = ZhipuAILargeLanguageModel()
      credentials = {'api_key': os.environ.get('ZHIPUAI_API_KEY')}

      # JSON-schema style description of the tool's arguments.
      weather_parameters = {
          "type": "object",
          "properties": {
              "location": {
                  "type": "string",
                  "description": "The city and state e.g. San Francisco, CA",
              },
              "unit": {
                  "type": "string",
                  "enum": ["c", "f"],
              },
          },
          "required": ["location"],
      }
      weather_tool = PromptMessageTool(
          name='get_current_weather',
          description='Get the current weather in a given location',
          parameters=weather_parameters,
      )
      conversation = [
          SystemPromptMessage(content='You are a helpful AI assistant.'),
          UserPromptMessage(content='Hello World!'),
      ]

      token_count = llm.get_num_tokens(
          model='tools',
          credentials=credentials,
          tools=[weather_tool],
          prompt_messages=conversation,
      )

      assert token_count == 108