test_llm.py

import os
from time import sleep
from typing import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.baichuan.llm.llm import BaichuanLarguageModel
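
# Note: these are live-API integration tests. They assume valid credentials in the
# BAICHUAN_API_KEY and BAICHUAN_SECRET_KEY environment variables, and each test sleeps
# for 3 seconds before calling the API, presumably to stay under the provider's rate limits.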


def test_predefined_models():
    model = BaichuanLarguageModel()
    model_schemas = model.predefined_models()

    assert len(model_schemas) >= 1
    assert isinstance(model_schemas[0], AIModelEntity)


def test_validate_credentials_for_chat_model():
    sleep(3)
    model = BaichuanLarguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='baichuan2-turbo',
            credentials={
                'api_key': 'invalid_key',
                'secret_key': 'invalid_key'
            }
        )

    model.validate_credentials(
        model='baichuan2-turbo',
        credentials={
            'api_key': os.environ.get('BAICHUAN_API_KEY'),
            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
        }
    )


def test_invoke_model():
    sleep(3)
    model = BaichuanLarguageModel()

    response = model.invoke(
        model='baichuan2-turbo',
        credentials={
            'api_key': os.environ.get('BAICHUAN_API_KEY'),
            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'top_k': 1,
        },
        stop=['you'],
        user="abc-123",
        stream=False
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0


def test_invoke_model_with_system_message():
    sleep(3)
    model = BaichuanLarguageModel()

    response = model.invoke(
        model='baichuan2-turbo',
        credentials={
            'api_key': os.environ.get('BAICHUAN_API_KEY'),
            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
        },
        prompt_messages=[
            SystemPromptMessage(
                content='请记住你是Kasumi。'  # "Remember that you are Kasumi."
            ),
            UserPromptMessage(
                content='现在告诉我你是谁?'  # "Now tell me who you are?"
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'top_k': 1,
        },
        stop=['you'],
        user="abc-123",
        stream=False
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0


def test_invoke_stream_model():
    sleep(3)
    model = BaichuanLarguageModel()

    response = model.invoke(
        model='baichuan2-turbo',
        credentials={
            'api_key': os.environ.get('BAICHUAN_API_KEY'),
            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'top_k': 1,
        },
        stop=['you'],
        stream=True,
        user="abc-123"
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        # every chunk except the final one (which carries a finish_reason) should contain text
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0


def test_invoke_with_search():
    sleep(3)
    model = BaichuanLarguageModel()

    response = model.invoke(
        model='baichuan2-turbo',
        credentials={
            'api_key': os.environ.get('BAICHUAN_API_KEY'),
            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='北京今天的天气怎么样'  # "What is the weather like in Beijing today?"
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'top_k': 1,
            'with_search_enhance': True,
        },
        stop=['you'],
        stream=True,
        user="abc-123"
    )

    assert isinstance(response, Generator)

    total_message = ''
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if not chunk.delta.finish_reason:
            assert len(chunk.delta.message.content) > 0
        total_message += chunk.delta.message.content

    # with search enhancement enabled, the answer should not contain '不' (a negation,
    # e.g. "不知道" / "I don't know"), i.e. the model should not claim it cannot answer
    assert '不' not in total_message


def test_get_num_tokens():
    sleep(3)
    model = BaichuanLarguageModel()

    response = model.get_num_tokens(
        model='baichuan2-turbo',
        credentials={
            'api_key': os.environ.get('BAICHUAN_API_KEY'),
            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        tools=[]
    )

    assert isinstance(response, int)
    assert response == 9