test_llm.py

import os
from collections.abc import Generator

import pytest

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_hub.llm.llm import HuggingfaceHubLargeLanguageModel
from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock
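

# Credential validation against the hosted Inference API: an invalid token or a
# non-existent model must raise CredentialsValidateFailedError, while a real
# token from the environment must pass.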
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_hosted_inference_api_validate_credentials(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='HuggingFaceH4/zephyr-7b-beta',
            credentials={
                'huggingfacehub_api_type': 'hosted_inference_api',
                'huggingfacehub_api_token': 'invalid_key'
            }
        )

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='fake-model',
            credentials={
                'huggingfacehub_api_type': 'hosted_inference_api',
                'huggingfacehub_api_token': 'invalid_key'
            }
        )

    model.validate_credentials(
        model='HuggingFaceH4/zephyr-7b-beta',
        credentials={
            'huggingfacehub_api_type': 'hosted_inference_api',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
        }
    )
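

# Non-streaming invocation against the hosted Inference API should return a
# single LLMResult with non-empty content.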
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_hosted_inference_api_invoke_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model='HuggingFaceH4/zephyr-7b-beta',
        credentials={
            'huggingfacehub_api_type': 'hosted_inference_api',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
        },
        stop=['How'],
        stream=False,
        user='abc-123'
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
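

# Streaming invocation should yield a Generator of LLMResultChunk objects;
# every chunk before the final one (finish_reason is None) carries content.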
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model='HuggingFaceH4/zephyr-7b-beta',
        credentials={
            'huggingfacehub_api_type': 'hosted_inference_api',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
        },
        stop=['How'],
        stream=True,
        user='abc-123'
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
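

# The same credential checks for a dedicated Inference Endpoint running a
# text-generation task: a bad token must fail, real credentials must pass.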
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='openchat/openchat_3.5',
            credentials={
                'huggingfacehub_api_type': 'inference_endpoints',
                'huggingfacehub_api_token': 'invalid_key',
                'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
                'task_type': 'text-generation'
            }
        )

    model.validate_credentials(
        model='openchat/openchat_3.5',
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text-generation'
        }
    )
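

# Non-streaming invocation through an Inference Endpoint (text-generation task).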
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model='openchat/openchat_3.5',
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text-generation'
        },
        prompt_messages=[
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
        },
        stop=['How'],
        stream=False,
        user='abc-123'
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
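

# Streaming invocation through the same text-generation endpoint.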
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model='openchat/openchat_3.5',
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text-generation'
        },
        prompt_messages=[
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
        },
        stop=['How'],
        stream=True,
        user='abc-123'
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
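

# Credential validation for an endpoint running a text2text-generation task
# (an encoder-decoder model such as mt5).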
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='google/mt5-base',
            credentials={
                'huggingfacehub_api_type': 'inference_endpoints',
                'huggingfacehub_api_token': 'invalid_key',
                'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
                'task_type': 'text2text-generation'
            }
        )

    model.validate_credentials(
        model='google/mt5-base',
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text2text-generation'
        }
    )
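

# Non-streaming invocation for the text2text-generation task.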
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model='google/mt5-base',
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text2text-generation'
        },
        prompt_messages=[
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
        },
        stop=['How'],
        stream=False,
        user='abc-123'
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
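

# Streaming invocation for the text2text-generation task.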
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model='google/mt5-base',
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text2text-generation'
        },
        prompt_messages=[
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
        },
        stop=['How'],
        stream=True,
        user='abc-123'
    )

    assert isinstance(response, Generator)

    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        if chunk.delta.finish_reason is None:
            assert len(chunk.delta.message.content) > 0
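

# Token counting does not go through the mocked HTTP client, so this test has
# no mock parametrization; it only checks the count for a short prompt.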
def test_get_num_tokens():
    model = HuggingfaceHubLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model='google/mt5-base',
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text2text-generation'
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ]
    )

    assert num_tokens == 7