# test_llm.py — integration tests for the Cohere large-language-model provider.
# Requires a live COHERE_API_KEY in the environment; these tests hit the real API.
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
  6. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  7. from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel
  8. def test_validate_credentials_for_chat_model():
  9. model = CohereLargeLanguageModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(
  12. model='command-light-chat',
  13. credentials={
  14. 'api_key': 'invalid_key'
  15. }
  16. )
  17. model.validate_credentials(
  18. model='command-light-chat',
  19. credentials={
  20. 'api_key': os.environ.get('COHERE_API_KEY')
  21. }
  22. )
  23. def test_validate_credentials_for_completion_model():
  24. model = CohereLargeLanguageModel()
  25. with pytest.raises(CredentialsValidateFailedError):
  26. model.validate_credentials(
  27. model='command-light',
  28. credentials={
  29. 'api_key': 'invalid_key'
  30. }
  31. )
  32. model.validate_credentials(
  33. model='command-light',
  34. credentials={
  35. 'api_key': os.environ.get('COHERE_API_KEY')
  36. }
  37. )
  38. def test_invoke_completion_model():
  39. model = CohereLargeLanguageModel()
  40. credentials = {
  41. 'api_key': os.environ.get('COHERE_API_KEY')
  42. }
  43. result = model.invoke(
  44. model='command-light',
  45. credentials=credentials,
  46. prompt_messages=[
  47. UserPromptMessage(
  48. content='Hello World!'
  49. )
  50. ],
  51. model_parameters={
  52. 'temperature': 0.0,
  53. 'max_tokens': 1
  54. },
  55. stream=False,
  56. user="abc-123"
  57. )
  58. assert isinstance(result, LLMResult)
  59. assert len(result.message.content) > 0
  60. assert model._num_tokens_from_string('command-light', credentials, result.message.content) == 1
  61. def test_invoke_stream_completion_model():
  62. model = CohereLargeLanguageModel()
  63. result = model.invoke(
  64. model='command-light',
  65. credentials={
  66. 'api_key': os.environ.get('COHERE_API_KEY')
  67. },
  68. prompt_messages=[
  69. UserPromptMessage(
  70. content='Hello World!'
  71. )
  72. ],
  73. model_parameters={
  74. 'temperature': 0.0,
  75. 'max_tokens': 100
  76. },
  77. stream=True,
  78. user="abc-123"
  79. )
  80. assert isinstance(result, Generator)
  81. for chunk in result:
  82. assert isinstance(chunk, LLMResultChunk)
  83. assert isinstance(chunk.delta, LLMResultChunkDelta)
  84. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  85. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  86. def test_invoke_chat_model():
  87. model = CohereLargeLanguageModel()
  88. result = model.invoke(
  89. model='command-light-chat',
  90. credentials={
  91. 'api_key': os.environ.get('COHERE_API_KEY')
  92. },
  93. prompt_messages=[
  94. SystemPromptMessage(
  95. content='You are a helpful AI assistant.',
  96. ),
  97. UserPromptMessage(
  98. content='Hello World!'
  99. )
  100. ],
  101. model_parameters={
  102. 'temperature': 0.0,
  103. 'p': 0.99,
  104. 'presence_penalty': 0.0,
  105. 'frequency_penalty': 0.0,
  106. 'max_tokens': 10
  107. },
  108. stop=['How'],
  109. stream=False,
  110. user="abc-123"
  111. )
  112. assert isinstance(result, LLMResult)
  113. assert len(result.message.content) > 0
  114. for chunk in model._llm_result_to_stream(result):
  115. assert isinstance(chunk, LLMResultChunk)
  116. assert isinstance(chunk.delta, LLMResultChunkDelta)
  117. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  118. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  119. def test_invoke_stream_chat_model():
  120. model = CohereLargeLanguageModel()
  121. result = model.invoke(
  122. model='command-light-chat',
  123. credentials={
  124. 'api_key': os.environ.get('COHERE_API_KEY')
  125. },
  126. prompt_messages=[
  127. SystemPromptMessage(
  128. content='You are a helpful AI assistant.',
  129. ),
  130. UserPromptMessage(
  131. content='Hello World!'
  132. )
  133. ],
  134. model_parameters={
  135. 'temperature': 0.0,
  136. 'max_tokens': 100
  137. },
  138. stream=True,
  139. user="abc-123"
  140. )
  141. assert isinstance(result, Generator)
  142. for chunk in result:
  143. assert isinstance(chunk, LLMResultChunk)
  144. assert isinstance(chunk.delta, LLMResultChunkDelta)
  145. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  146. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  147. if chunk.delta.finish_reason is not None:
  148. assert chunk.delta.usage is not None
  149. assert chunk.delta.usage.completion_tokens > 0
  150. def test_get_num_tokens():
  151. model = CohereLargeLanguageModel()
  152. num_tokens = model.get_num_tokens(
  153. model='command-light',
  154. credentials={
  155. 'api_key': os.environ.get('COHERE_API_KEY')
  156. },
  157. prompt_messages=[
  158. UserPromptMessage(
  159. content='Hello World!'
  160. )
  161. ]
  162. )
  163. assert num_tokens == 3
  164. num_tokens = model.get_num_tokens(
  165. model='command-light-chat',
  166. credentials={
  167. 'api_key': os.environ.get('COHERE_API_KEY')
  168. },
  169. prompt_messages=[
  170. SystemPromptMessage(
  171. content='You are a helpful AI assistant.',
  172. ),
  173. UserPromptMessage(
  174. content='Hello World!'
  175. )
  176. ]
  177. )
  178. assert num_tokens == 15
  179. def test_fine_tuned_model():
  180. model = CohereLargeLanguageModel()
  181. # test invoke
  182. result = model.invoke(
  183. model='85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft',
  184. credentials={
  185. 'api_key': os.environ.get('COHERE_API_KEY'),
  186. 'mode': 'completion'
  187. },
  188. prompt_messages=[
  189. SystemPromptMessage(
  190. content='You are a helpful AI assistant.',
  191. ),
  192. UserPromptMessage(
  193. content='Hello World!'
  194. )
  195. ],
  196. model_parameters={
  197. 'temperature': 0.0,
  198. 'max_tokens': 100
  199. },
  200. stream=False,
  201. user="abc-123"
  202. )
  203. assert isinstance(result, LLMResult)
  204. def test_fine_tuned_chat_model():
  205. model = CohereLargeLanguageModel()
  206. # test invoke
  207. result = model.invoke(
  208. model='94f2d55a-4c79-4c00-bde4-23962e74b170-ft',
  209. credentials={
  210. 'api_key': os.environ.get('COHERE_API_KEY'),
  211. 'mode': 'chat'
  212. },
  213. prompt_messages=[
  214. SystemPromptMessage(
  215. content='You are a helpful AI assistant.',
  216. ),
  217. UserPromptMessage(
  218. content='Hello World!'
  219. )
  220. ],
  221. model_parameters={
  222. 'temperature': 0.0,
  223. 'max_tokens': 100
  224. },
  225. stream=False,
  226. user="abc-123"
  227. )
  228. assert isinstance(result, LLMResult)