# test_llm.py — integration tests for the LocalAI large language model provider.
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import (
  6. AssistantPromptMessage,
  7. PromptMessageTool,
  8. SystemPromptMessage,
  9. TextPromptMessageContent,
  10. UserPromptMessage,
  11. )
  12. from core.model_runtime.entities.model_entities import ParameterRule
  13. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  14. from core.model_runtime.model_providers.localai.llm.llm import LocalAILarguageModel
  15. def test_validate_credentials_for_chat_model():
  16. model = LocalAILarguageModel()
  17. with pytest.raises(CredentialsValidateFailedError):
  18. model.validate_credentials(
  19. model='chinese-llama-2-7b',
  20. credentials={
  21. 'server_url': 'hahahaha',
  22. 'completion_type': 'completion',
  23. }
  24. )
  25. model.validate_credentials(
  26. model='chinese-llama-2-7b',
  27. credentials={
  28. 'server_url': os.environ.get('LOCALAI_SERVER_URL'),
  29. 'completion_type': 'completion',
  30. }
  31. )
  32. def test_invoke_completion_model():
  33. model = LocalAILarguageModel()
  34. response = model.invoke(
  35. model='chinese-llama-2-7b',
  36. credentials={
  37. 'server_url': os.environ.get('LOCALAI_SERVER_URL'),
  38. 'completion_type': 'completion',
  39. },
  40. prompt_messages=[
  41. UserPromptMessage(
  42. content='ping'
  43. )
  44. ],
  45. model_parameters={
  46. 'temperature': 0.7,
  47. 'top_p': 1.0,
  48. 'max_tokens': 10
  49. },
  50. stop=[],
  51. user="abc-123",
  52. stream=False
  53. )
  54. assert isinstance(response, LLMResult)
  55. assert len(response.message.content) > 0
  56. assert response.usage.total_tokens > 0
  57. def test_invoke_chat_model():
  58. model = LocalAILarguageModel()
  59. response = model.invoke(
  60. model='chinese-llama-2-7b',
  61. credentials={
  62. 'server_url': os.environ.get('LOCALAI_SERVER_URL'),
  63. 'completion_type': 'chat_completion',
  64. },
  65. prompt_messages=[
  66. UserPromptMessage(
  67. content='ping'
  68. )
  69. ],
  70. model_parameters={
  71. 'temperature': 0.7,
  72. 'top_p': 1.0,
  73. 'max_tokens': 10
  74. },
  75. stop=[],
  76. user="abc-123",
  77. stream=False
  78. )
  79. assert isinstance(response, LLMResult)
  80. assert len(response.message.content) > 0
  81. assert response.usage.total_tokens > 0
  82. def test_invoke_stream_completion_model():
  83. model = LocalAILarguageModel()
  84. response = model.invoke(
  85. model='chinese-llama-2-7b',
  86. credentials={
  87. 'server_url': os.environ.get('LOCALAI_SERVER_URL'),
  88. 'completion_type': 'completion',
  89. },
  90. prompt_messages=[
  91. UserPromptMessage(
  92. content='Hello World!'
  93. )
  94. ],
  95. model_parameters={
  96. 'temperature': 0.7,
  97. 'top_p': 1.0,
  98. 'max_tokens': 10
  99. },
  100. stop=['you'],
  101. stream=True,
  102. user="abc-123"
  103. )
  104. assert isinstance(response, Generator)
  105. for chunk in response:
  106. assert isinstance(chunk, LLMResultChunk)
  107. assert isinstance(chunk.delta, LLMResultChunkDelta)
  108. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  109. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  110. def test_invoke_stream_chat_model():
  111. model = LocalAILarguageModel()
  112. response = model.invoke(
  113. model='chinese-llama-2-7b',
  114. credentials={
  115. 'server_url': os.environ.get('LOCALAI_SERVER_URL'),
  116. 'completion_type': 'chat_completion',
  117. },
  118. prompt_messages=[
  119. UserPromptMessage(
  120. content='Hello World!'
  121. )
  122. ],
  123. model_parameters={
  124. 'temperature': 0.7,
  125. 'top_p': 1.0,
  126. 'max_tokens': 10
  127. },
  128. stop=['you'],
  129. stream=True,
  130. user="abc-123"
  131. )
  132. assert isinstance(response, Generator)
  133. for chunk in response:
  134. assert isinstance(chunk, LLMResultChunk)
  135. assert isinstance(chunk.delta, LLMResultChunkDelta)
  136. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  137. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  138. def test_get_num_tokens():
  139. model = LocalAILarguageModel()
  140. num_tokens = model.get_num_tokens(
  141. model='????',
  142. credentials={
  143. 'server_url': os.environ.get('LOCALAI_SERVER_URL'),
  144. 'completion_type': 'chat_completion',
  145. },
  146. prompt_messages=[
  147. SystemPromptMessage(
  148. content='You are a helpful AI assistant.',
  149. ),
  150. UserPromptMessage(
  151. content='Hello World!'
  152. )
  153. ],
  154. tools=[
  155. PromptMessageTool(
  156. name='get_current_weather',
  157. description='Get the current weather in a given location',
  158. parameters={
  159. "type": "object",
  160. "properties": {
  161. "location": {
  162. "type": "string",
  163. "description": "The city and state e.g. San Francisco, CA"
  164. },
  165. "unit": {
  166. "type": "string",
  167. "enum": [
  168. "c",
  169. "f"
  170. ]
  171. }
  172. },
  173. "required": [
  174. "location"
  175. ]
  176. }
  177. )
  178. ]
  179. )
  180. assert isinstance(num_tokens, int)
  181. assert num_tokens == 77
  182. num_tokens = model.get_num_tokens(
  183. model='????',
  184. credentials={
  185. 'server_url': os.environ.get('LOCALAI_SERVER_URL'),
  186. 'completion_type': 'chat_completion',
  187. },
  188. prompt_messages=[
  189. UserPromptMessage(
  190. content='Hello World!'
  191. )
  192. ],
  193. )
  194. assert isinstance(num_tokens, int)
  195. assert num_tokens == 10