# test_llm.py — integration tests for the workflow LLM node (mocked OpenAI backend)
  1. import json
  2. import os
  3. from unittest.mock import MagicMock
  4. import pytest
  5. from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
  6. from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
  7. from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration
  8. from core.model_manager import ModelInstance
  9. from core.model_runtime.entities.model_entities import ModelType
  10. from core.model_runtime.model_providers import ModelProviderFactory
  11. from core.workflow.entities.node_entities import SystemVariable
  12. from core.workflow.entities.variable_pool import VariablePool
  13. from core.workflow.nodes.base_node import UserFrom
  14. from core.workflow.nodes.llm.llm_node import LLMNode
  15. from extensions.ext_database import db
  16. from models.provider import ProviderType
  17. from models.workflow import WorkflowNodeExecutionStatus
  18. """FOR MOCK FIXTURES, DO NOT REMOVE"""
  19. from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
  20. from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
  21. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  22. def test_execute_llm(setup_openai_mock):
  23. node = LLMNode(
  24. tenant_id='1',
  25. app_id='1',
  26. workflow_id='1',
  27. user_id='1',
  28. invoke_from=InvokeFrom.WEB_APP,
  29. user_from=UserFrom.ACCOUNT,
  30. config={
  31. 'id': 'llm',
  32. 'data': {
  33. 'title': '123',
  34. 'type': 'llm',
  35. 'model': {
  36. 'provider': 'openai',
  37. 'name': 'gpt-3.5-turbo',
  38. 'mode': 'chat',
  39. 'completion_params': {}
  40. },
  41. 'prompt_template': [
  42. {
  43. 'role': 'system',
  44. 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}.'
  45. },
  46. {
  47. 'role': 'user',
  48. 'text': '{{#sys.query#}}'
  49. }
  50. ],
  51. 'memory': None,
  52. 'context': {
  53. 'enabled': False
  54. },
  55. 'vision': {
  56. 'enabled': False
  57. }
  58. }
  59. }
  60. )
  61. # construct variable pool
  62. pool = VariablePool(system_variables={
  63. SystemVariable.QUERY: 'what\'s the weather today?',
  64. SystemVariable.FILES: [],
  65. SystemVariable.CONVERSATION_ID: 'abababa',
  66. SystemVariable.USER_ID: 'aaa'
  67. }, user_inputs={})
  68. pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')
  69. credentials = {
  70. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  71. }
  72. provider_instance = ModelProviderFactory().get_provider_instance('openai')
  73. model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
  74. provider_model_bundle = ProviderModelBundle(
  75. configuration=ProviderConfiguration(
  76. tenant_id='1',
  77. provider=provider_instance.get_provider_schema(),
  78. preferred_provider_type=ProviderType.CUSTOM,
  79. using_provider_type=ProviderType.CUSTOM,
  80. system_configuration=SystemConfiguration(
  81. enabled=False
  82. ),
  83. custom_configuration=CustomConfiguration(
  84. provider=CustomProviderConfiguration(
  85. credentials=credentials
  86. )
  87. )
  88. ),
  89. provider_instance=provider_instance,
  90. model_type_instance=model_type_instance
  91. )
  92. model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')
  93. model_config = ModelConfigWithCredentialsEntity(
  94. model='gpt-3.5-turbo',
  95. provider='openai',
  96. mode='chat',
  97. credentials=credentials,
  98. parameters={},
  99. model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'),
  100. provider_model_bundle=provider_model_bundle
  101. )
  102. # Mock db.session.close()
  103. db.session.close = MagicMock()
  104. node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))
  105. # execute node
  106. result = node.run(pool)
  107. assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  108. assert result.outputs['text'] is not None
  109. assert result.outputs['usage']['total_tokens'] > 0
  110. @pytest.mark.parametrize('setup_code_executor_mock', [['none']], indirect=True)
  111. @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
  112. def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_openai_mock):
  113. """
  114. Test execute LLM node with jinja2
  115. """
  116. node = LLMNode(
  117. tenant_id='1',
  118. app_id='1',
  119. workflow_id='1',
  120. user_id='1',
  121. invoke_from=InvokeFrom.WEB_APP,
  122. user_from=UserFrom.ACCOUNT,
  123. config={
  124. 'id': 'llm',
  125. 'data': {
  126. 'title': '123',
  127. 'type': 'llm',
  128. 'model': {
  129. 'provider': 'openai',
  130. 'name': 'gpt-3.5-turbo',
  131. 'mode': 'chat',
  132. 'completion_params': {}
  133. },
  134. 'prompt_config': {
  135. 'jinja2_variables': [{
  136. 'variable': 'sys_query',
  137. 'value_selector': ['sys', 'query']
  138. }, {
  139. 'variable': 'output',
  140. 'value_selector': ['abc', 'output']
  141. }]
  142. },
  143. 'prompt_template': [
  144. {
  145. 'role': 'system',
  146. 'text': 'you are a helpful assistant.\ntoday\'s weather is {{#abc.output#}}',
  147. 'jinja2_text': 'you are a helpful assistant.\ntoday\'s weather is {{output}}.',
  148. 'edition_type': 'jinja2'
  149. },
  150. {
  151. 'role': 'user',
  152. 'text': '{{#sys.query#}}',
  153. 'jinja2_text': '{{sys_query}}',
  154. 'edition_type': 'basic'
  155. }
  156. ],
  157. 'memory': None,
  158. 'context': {
  159. 'enabled': False
  160. },
  161. 'vision': {
  162. 'enabled': False
  163. }
  164. }
  165. }
  166. )
  167. # construct variable pool
  168. pool = VariablePool(system_variables={
  169. SystemVariable.QUERY: 'what\'s the weather today?',
  170. SystemVariable.FILES: [],
  171. SystemVariable.CONVERSATION_ID: 'abababa',
  172. SystemVariable.USER_ID: 'aaa'
  173. }, user_inputs={})
  174. pool.append_variable(node_id='abc', variable_key_list=['output'], value='sunny')
  175. credentials = {
  176. 'openai_api_key': os.environ.get('OPENAI_API_KEY')
  177. }
  178. provider_instance = ModelProviderFactory().get_provider_instance('openai')
  179. model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
  180. provider_model_bundle = ProviderModelBundle(
  181. configuration=ProviderConfiguration(
  182. tenant_id='1',
  183. provider=provider_instance.get_provider_schema(),
  184. preferred_provider_type=ProviderType.CUSTOM,
  185. using_provider_type=ProviderType.CUSTOM,
  186. system_configuration=SystemConfiguration(
  187. enabled=False
  188. ),
  189. custom_configuration=CustomConfiguration(
  190. provider=CustomProviderConfiguration(
  191. credentials=credentials
  192. )
  193. )
  194. ),
  195. provider_instance=provider_instance,
  196. model_type_instance=model_type_instance
  197. )
  198. model_instance = ModelInstance(provider_model_bundle=provider_model_bundle, model='gpt-3.5-turbo')
  199. model_config = ModelConfigWithCredentialsEntity(
  200. model='gpt-3.5-turbo',
  201. provider='openai',
  202. mode='chat',
  203. credentials=credentials,
  204. parameters={},
  205. model_schema=model_type_instance.get_model_schema('gpt-3.5-turbo'),
  206. provider_model_bundle=provider_model_bundle
  207. )
  208. # Mock db.session.close()
  209. db.session.close = MagicMock()
  210. node._fetch_model_config = MagicMock(return_value=tuple([model_instance, model_config]))
  211. # execute node
  212. result = node.run(pool)
  213. assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
  214. assert 'sunny' in json.dumps(result.process_data)
  215. assert 'what\'s the weather today?' in json.dumps(result.process_data)