test_llm.py

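# Integration tests for the Baichuan LLM provider. They call the live API using the
# BAICHUAN_API_KEY and BAICHUAN_SECRET_KEY environment variables; the sleep(3) calls
# space out requests, presumably to stay under the provider's rate limits.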
import os
from collections.abc import Generator
from time import sleep

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.baichuan.llm.llm import BaichuanLarguageModel


def test_predefined_models():
    model = BaichuanLarguageModel()
    model_schemas = model.predefined_models()

    assert len(model_schemas) >= 1
    assert isinstance(model_schemas[0], AIModelEntity)


def test_validate_credentials_for_chat_model():
    sleep(3)
    model = BaichuanLarguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='baichuan2-turbo',
            credentials={
                'api_key': 'invalid_key',
                'secret_key': 'invalid_key'
            }
        )

    model.validate_credentials(
        model='baichuan2-turbo',
        credentials={
            'api_key': os.environ.get('BAICHUAN_API_KEY'),
            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
        }
    )


def test_invoke_model():
    sleep(3)
    model = BaichuanLarguageModel()

    response = model.invoke(
        model='baichuan2-turbo',
        credentials={
            'api_key': os.environ.get('BAICHUAN_API_KEY'),
            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'top_k': 1,
        },
        stop=['you'],
        user="abc-123",
        stream=False
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0


def test_invoke_model_with_system_message():
    sleep(3)
    model = BaichuanLarguageModel()

    response = model.invoke(
        model='baichuan2-turbo',
        credentials={
            'api_key': os.environ.get('BAICHUAN_API_KEY'),
            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
        },
        prompt_messages=[
            SystemPromptMessage(
                # "Remember, you are Kasumi."
                content='请记住你是Kasumi。'
            ),
            UserPromptMessage(
                # "Now tell me, who are you?"
                content='现在告诉我你是谁?'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'top_k': 1,
        },
        stop=['you'],
        user="abc-123",
        stream=False
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0


def test_invoke_stream_model():
    sleep(3)
    model = BaichuanLarguageModel()

    response = model.invoke(
        model='baichuan2-turbo',
        credentials={
            'api_key': os.environ.get('BAICHUAN_API_KEY'),
            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'top_k': 1,
        },
        stop=['you'],
        stream=True,
        user="abc-123"
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


def test_invoke_with_search():
    sleep(3)
    model = BaichuanLarguageModel()

    response = model.invoke(
        model='baichuan2-turbo',
        credentials={
            'api_key': os.environ.get('BAICHUAN_API_KEY'),
            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                # "What is the weather like in Beijing today?"
                content='北京今天的天气怎么样'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'top_k': 1,
            'with_search_enhance': True,
        },
        stop=['you'],
        stream=True,
        user="abc-123"
    )

    assert isinstance(response, Generator)

    total_message = ''
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
        total_message += chunk.delta.message.content

    # The reply should not contain '不' ("no"/"not"), presumably to check that the
    # model does not say it cannot answer the weather question.
    assert '不' not in total_message


def test_get_num_tokens():
    sleep(3)
    model = BaichuanLarguageModel()

    response = model.get_num_tokens(
        model='baichuan2-turbo',
        credentials={
            'api_key': os.environ.get('BAICHUAN_API_KEY'),
            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        tools=[]
    )

    assert isinstance(response, int)
    assert response == 9