# test_llm.py — integration tests for the LocalAI large-language-model provider.
# Requires a reachable LocalAI server; its URL is read from LOCALAI_SERVER_URL.
  1. import os
  2. from typing import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool,
  6. SystemPromptMessage, TextPromptMessageContent,
  7. UserPromptMessage)
  8. from core.model_runtime.entities.model_entities import ParameterRule
  9. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  10. from core.model_runtime.model_providers.localai.llm.llm import LocalAILarguageModel
  11. def test_validate_credentials_for_chat_model():
  12. model = LocalAILarguageModel()
  13. with pytest.raises(CredentialsValidateFailedError):
  14. model.validate_credentials(
  15. model='chinese-llama-2-7b',
  16. credentials={
  17. 'server_url': 'hahahaha',
  18. 'completion_type': 'completion',
  19. }
  20. )
  21. model.validate_credentials(
  22. model='chinese-llama-2-7b',
  23. credentials={
  24. 'server_url': os.environ.get('LOCALAI_SERVER_URL'),
  25. 'completion_type': 'completion',
  26. }
  27. )
  28. def test_invoke_completion_model():
  29. model = LocalAILarguageModel()
  30. response = model.invoke(
  31. model='chinese-llama-2-7b',
  32. credentials={
  33. 'server_url': os.environ.get('LOCALAI_SERVER_URL'),
  34. 'completion_type': 'completion',
  35. },
  36. prompt_messages=[
  37. UserPromptMessage(
  38. content='ping'
  39. )
  40. ],
  41. model_parameters={
  42. 'temperature': 0.7,
  43. 'top_p': 1.0,
  44. 'max_tokens': 10
  45. },
  46. stop=[],
  47. user="abc-123",
  48. stream=False
  49. )
  50. assert isinstance(response, LLMResult)
  51. assert len(response.message.content) > 0
  52. assert response.usage.total_tokens > 0
  53. def test_invoke_chat_model():
  54. model = LocalAILarguageModel()
  55. response = model.invoke(
  56. model='chinese-llama-2-7b',
  57. credentials={
  58. 'server_url': os.environ.get('LOCALAI_SERVER_URL'),
  59. 'completion_type': 'chat_completion',
  60. },
  61. prompt_messages=[
  62. UserPromptMessage(
  63. content='ping'
  64. )
  65. ],
  66. model_parameters={
  67. 'temperature': 0.7,
  68. 'top_p': 1.0,
  69. 'max_tokens': 10
  70. },
  71. stop=[],
  72. user="abc-123",
  73. stream=False
  74. )
  75. assert isinstance(response, LLMResult)
  76. assert len(response.message.content) > 0
  77. assert response.usage.total_tokens > 0
  78. def test_invoke_stream_completion_model():
  79. model = LocalAILarguageModel()
  80. response = model.invoke(
  81. model='chinese-llama-2-7b',
  82. credentials={
  83. 'server_url': os.environ.get('LOCALAI_SERVER_URL'),
  84. 'completion_type': 'completion',
  85. },
  86. prompt_messages=[
  87. UserPromptMessage(
  88. content='Hello World!'
  89. )
  90. ],
  91. model_parameters={
  92. 'temperature': 0.7,
  93. 'top_p': 1.0,
  94. 'max_tokens': 10
  95. },
  96. stop=['you'],
  97. stream=True,
  98. user="abc-123"
  99. )
  100. assert isinstance(response, Generator)
  101. for chunk in response:
  102. assert isinstance(chunk, LLMResultChunk)
  103. assert isinstance(chunk.delta, LLMResultChunkDelta)
  104. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  105. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  106. def test_invoke_stream_chat_model():
  107. model = LocalAILarguageModel()
  108. response = model.invoke(
  109. model='chinese-llama-2-7b',
  110. credentials={
  111. 'server_url': os.environ.get('LOCALAI_SERVER_URL'),
  112. 'completion_type': 'chat_completion',
  113. },
  114. prompt_messages=[
  115. UserPromptMessage(
  116. content='Hello World!'
  117. )
  118. ],
  119. model_parameters={
  120. 'temperature': 0.7,
  121. 'top_p': 1.0,
  122. 'max_tokens': 10
  123. },
  124. stop=['you'],
  125. stream=True,
  126. user="abc-123"
  127. )
  128. assert isinstance(response, Generator)
  129. for chunk in response:
  130. assert isinstance(chunk, LLMResultChunk)
  131. assert isinstance(chunk.delta, LLMResultChunkDelta)
  132. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  133. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  134. def test_get_num_tokens():
  135. model = LocalAILarguageModel()
  136. num_tokens = model.get_num_tokens(
  137. model='????',
  138. credentials={
  139. 'server_url': os.environ.get('LOCALAI_SERVER_URL'),
  140. 'completion_type': 'chat_completion',
  141. },
  142. prompt_messages=[
  143. SystemPromptMessage(
  144. content='You are a helpful AI assistant.',
  145. ),
  146. UserPromptMessage(
  147. content='Hello World!'
  148. )
  149. ],
  150. tools=[
  151. PromptMessageTool(
  152. name='get_current_weather',
  153. description='Get the current weather in a given location',
  154. parameters={
  155. "type": "object",
  156. "properties": {
  157. "location": {
  158. "type": "string",
  159. "description": "The city and state e.g. San Francisco, CA"
  160. },
  161. "unit": {
  162. "type": "string",
  163. "enum": [
  164. "c",
  165. "f"
  166. ]
  167. }
  168. },
  169. "required": [
  170. "location"
  171. ]
  172. }
  173. )
  174. ]
  175. )
  176. assert isinstance(num_tokens, int)
  177. assert num_tokens == 77
  178. num_tokens = model.get_num_tokens(
  179. model='????',
  180. credentials={
  181. 'server_url': os.environ.get('LOCALAI_SERVER_URL'),
  182. 'completion_type': 'chat_completion',
  183. },
  184. prompt_messages=[
  185. UserPromptMessage(
  186. content='Hello World!'
  187. )
  188. ],
  189. )
  190. assert isinstance(num_tokens, int)
  191. assert num_tokens == 10