test_llm.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import (
  6. AssistantPromptMessage,
  7. ImagePromptMessageContent,
  8. PromptMessageTool,
  9. SystemPromptMessage,
  10. TextPromptMessageContent,
  11. UserPromptMessage,
  12. )
  13. from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
  14. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  15. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  16. from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel
  17. """FOR MOCK FIXTURES, DO NOT REMOVE"""
  18. from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
  19. def test_predefined_models():
  20. model = OpenAILargeLanguageModel()
  21. model_schemas = model.predefined_models()
  22. assert len(model_schemas) >= 1
  23. assert isinstance(model_schemas[0], AIModelEntity)
  24. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  25. def test_validate_credentials_for_chat_model(setup_openai_mock):
  26. model = OpenAILargeLanguageModel()
  27. with pytest.raises(CredentialsValidateFailedError):
  28. model.validate_credentials(
  29. model='gpt-3.5-turbo',
  30. credentials={
  31. 'openai_api_key': 'invalid_key'
  32. }
  33. )
  34. model.validate_credentials(
  35. model='gpt-3.5-turbo',
  36. credentials={
  37. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  38. }
  39. )
  40. @pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True)
  41. def test_validate_credentials_for_completion_model(setup_openai_mock):
  42. model = OpenAILargeLanguageModel()
  43. with pytest.raises(CredentialsValidateFailedError):
  44. model.validate_credentials(
  45. model='text-davinci-003',
  46. credentials={
  47. 'openai_api_key': 'invalid_key'
  48. }
  49. )
  50. model.validate_credentials(
  51. model='text-davinci-003',
  52. credentials={
  53. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  54. }
  55. )
  56. @pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True)
  57. def test_invoke_completion_model(setup_openai_mock):
  58. model = OpenAILargeLanguageModel()
  59. result = model.invoke(
  60. model='gpt-3.5-turbo-instruct',
  61. credentials={
  62. 'openai_api_key': os.environ.get('OPENAI_API_KEY'),
  63. 'openai_api_base': 'https://api.openai.com'
  64. },
  65. prompt_messages=[
  66. UserPromptMessage(
  67. content='Hello World!'
  68. )
  69. ],
  70. model_parameters={
  71. 'temperature': 0.0,
  72. 'max_tokens': 1
  73. },
  74. stream=False,
  75. user="abc-123"
  76. )
  77. assert isinstance(result, LLMResult)
  78. assert len(result.message.content) > 0
  79. assert model._num_tokens_from_string('gpt-3.5-turbo-instruct', result.message.content) == 1
  80. @pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True)
  81. def test_invoke_stream_completion_model(setup_openai_mock):
  82. model = OpenAILargeLanguageModel()
  83. result = model.invoke(
  84. model='gpt-3.5-turbo-instruct',
  85. credentials={
  86. 'openai_api_key': os.environ.get('OPENAI_API_KEY'),
  87. 'openai_organization': os.environ.get('OPENAI_ORGANIZATION'),
  88. },
  89. prompt_messages=[
  90. UserPromptMessage(
  91. content='Hello World!'
  92. )
  93. ],
  94. model_parameters={
  95. 'temperature': 0.0,
  96. 'max_tokens': 100
  97. },
  98. stream=True,
  99. user="abc-123"
  100. )
  101. assert isinstance(result, Generator)
  102. for chunk in result:
  103. assert isinstance(chunk, LLMResultChunk)
  104. assert isinstance(chunk.delta, LLMResultChunkDelta)
  105. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  106. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  107. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  108. def test_invoke_chat_model(setup_openai_mock):
  109. model = OpenAILargeLanguageModel()
  110. result = model.invoke(
  111. model='gpt-3.5-turbo',
  112. credentials={
  113. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  114. },
  115. prompt_messages=[
  116. SystemPromptMessage(
  117. content='You are a helpful AI assistant.',
  118. ),
  119. UserPromptMessage(
  120. content='Hello World!'
  121. )
  122. ],
  123. model_parameters={
  124. 'temperature': 0.0,
  125. 'top_p': 1.0,
  126. 'presence_penalty': 0.0,
  127. 'frequency_penalty': 0.0,
  128. 'max_tokens': 10
  129. },
  130. stop=['How'],
  131. stream=False,
  132. user="abc-123"
  133. )
  134. assert isinstance(result, LLMResult)
  135. assert len(result.message.content) > 0
  136. for chunk in model._llm_result_to_stream(result):
  137. assert isinstance(chunk, LLMResultChunk)
  138. assert isinstance(chunk.delta, LLMResultChunkDelta)
  139. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  140. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  141. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  142. def test_invoke_chat_model_with_vision(setup_openai_mock):
  143. model = OpenAILargeLanguageModel()
  144. result = model.invoke(
  145. model='gpt-4-vision-preview',
  146. credentials={
  147. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  148. },
  149. prompt_messages=[
  150. SystemPromptMessage(
  151. content='You are a helpful AI assistant.',
  152. ),
  153. UserPromptMessage(
  154. content=[
  155. TextPromptMessageContent(
  156. data='Hello World!',
  157. ),
  158. ImagePromptMessageContent(
  159. data=''
  160. )
  161. ]
  162. )
  163. ],
  164. model_parameters={
  165. 'temperature': 0.0,
  166. 'max_tokens': 100
  167. },
  168. stream=False,
  169. user="abc-123"
  170. )
  171. assert isinstance(result, LLMResult)
  172. assert len(result.message.content) > 0
  173. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  174. def test_invoke_chat_model_with_tools(setup_openai_mock):
  175. model = OpenAILargeLanguageModel()
  176. result = model.invoke(
  177. model='gpt-3.5-turbo',
  178. credentials={
  179. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  180. },
  181. prompt_messages=[
  182. SystemPromptMessage(
  183. content='You are a helpful AI assistant.',
  184. ),
  185. UserPromptMessage(
  186. content="what's the weather today in London?",
  187. )
  188. ],
  189. model_parameters={
  190. 'temperature': 0.0,
  191. 'max_tokens': 100
  192. },
  193. tools=[
  194. PromptMessageTool(
  195. name='get_weather',
  196. description='Determine weather in my location',
  197. parameters={
  198. "type": "object",
  199. "properties": {
  200. "location": {
  201. "type": "string",
  202. "description": "The city and state e.g. San Francisco, CA"
  203. },
  204. "unit": {
  205. "type": "string",
  206. "enum": [
  207. "c",
  208. "f"
  209. ]
  210. }
  211. },
  212. "required": [
  213. "location"
  214. ]
  215. }
  216. ),
  217. PromptMessageTool(
  218. name='get_stock_price',
  219. description='Get the current stock price',
  220. parameters={
  221. "type": "object",
  222. "properties": {
  223. "symbol": {
  224. "type": "string",
  225. "description": "The stock symbol"
  226. }
  227. },
  228. "required": [
  229. "symbol"
  230. ]
  231. }
  232. )
  233. ],
  234. stream=False,
  235. user="abc-123"
  236. )
  237. assert isinstance(result, LLMResult)
  238. assert isinstance(result.message, AssistantPromptMessage)
  239. assert len(result.message.tool_calls) > 0
  240. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  241. def test_invoke_stream_chat_model(setup_openai_mock):
  242. model = OpenAILargeLanguageModel()
  243. result = model.invoke(
  244. model='gpt-3.5-turbo',
  245. credentials={
  246. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  247. },
  248. prompt_messages=[
  249. SystemPromptMessage(
  250. content='You are a helpful AI assistant.',
  251. ),
  252. UserPromptMessage(
  253. content='Hello World!'
  254. )
  255. ],
  256. model_parameters={
  257. 'temperature': 0.0,
  258. 'max_tokens': 100
  259. },
  260. stream=True,
  261. user="abc-123"
  262. )
  263. assert isinstance(result, Generator)
  264. for chunk in result:
  265. assert isinstance(chunk, LLMResultChunk)
  266. assert isinstance(chunk.delta, LLMResultChunkDelta)
  267. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  268. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  269. if chunk.delta.finish_reason is not None:
  270. assert chunk.delta.usage is not None
  271. assert chunk.delta.usage.completion_tokens > 0
  272. def test_get_num_tokens():
  273. model = OpenAILargeLanguageModel()
  274. num_tokens = model.get_num_tokens(
  275. model='gpt-3.5-turbo-instruct',
  276. credentials={
  277. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  278. },
  279. prompt_messages=[
  280. UserPromptMessage(
  281. content='Hello World!'
  282. )
  283. ]
  284. )
  285. assert num_tokens == 3
  286. num_tokens = model.get_num_tokens(
  287. model='gpt-3.5-turbo',
  288. credentials={
  289. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  290. },
  291. prompt_messages=[
  292. SystemPromptMessage(
  293. content='You are a helpful AI assistant.',
  294. ),
  295. UserPromptMessage(
  296. content='Hello World!'
  297. )
  298. ],
  299. tools=[
  300. PromptMessageTool(
  301. name='get_weather',
  302. description='Determine weather in my location',
  303. parameters={
  304. "type": "object",
  305. "properties": {
  306. "location": {
  307. "type": "string",
  308. "description": "The city and state e.g. San Francisco, CA"
  309. },
  310. "unit": {
  311. "type": "string",
  312. "enum": [
  313. "c",
  314. "f"
  315. ]
  316. }
  317. },
  318. "required": [
  319. "location"
  320. ]
  321. }
  322. ),
  323. ]
  324. )
  325. assert num_tokens == 72
  326. @pytest.mark.parametrize('setup_openai_mock', [['chat', 'remote']], indirect=True)
  327. def test_fine_tuned_models(setup_openai_mock):
  328. model = OpenAILargeLanguageModel()
  329. remote_models = model.remote_models(credentials={
  330. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  331. })
  332. if not remote_models:
  333. assert isinstance(remote_models, list)
  334. else:
  335. assert isinstance(remote_models[0], AIModelEntity)
  336. for llm_model in remote_models:
  337. if llm_model.model_type == ModelType.LLM:
  338. break
  339. assert isinstance(llm_model, AIModelEntity)
  340. # test invoke
  341. result = model.invoke(
  342. model=llm_model.model,
  343. credentials={
  344. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  345. },
  346. prompt_messages=[
  347. SystemPromptMessage(
  348. content='You are a helpful AI assistant.',
  349. ),
  350. UserPromptMessage(
  351. content='Hello World!'
  352. )
  353. ],
  354. model_parameters={
  355. 'temperature': 0.0,
  356. 'max_tokens': 100
  357. },
  358. stream=False,
  359. user="abc-123"
  360. )
  361. assert isinstance(result, LLMResult)
  362. def test__get_num_tokens_by_gpt2():
  363. model = OpenAILargeLanguageModel()
  364. num_tokens = model._get_num_tokens_by_gpt2('Hello World!')
  365. assert num_tokens == 3