# assistant_app_runner.py
  1. import json
  2. import logging
  3. from typing import cast
  4. from core.app_runner.app_runner import AppRunner
  5. from core.features.assistant_cot_runner import AssistantCotApplicationRunner
  6. from core.features.assistant_fc_runner import AssistantFunctionCallApplicationRunner
  7. from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
  8. AgentEntity
  9. from core.application_queue_manager import ApplicationQueueManager, PublishFrom
  10. from core.memory.token_buffer_memory import TokenBufferMemory
  11. from core.model_manager import ModelInstance
  12. from core.model_runtime.entities.llm_entities import LLMUsage
  13. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  14. from core.moderation.base import ModerationException
  15. from core.tools.entities.tool_entities import ToolRuntimeVariablePool
  16. from extensions.ext_database import db
  17. from models.model import Conversation, Message, App, MessageChain, MessageAgentThought
  18. from models.tools import ToolConversationVariables
  19. logger = logging.getLogger(__name__)
  20. class AssistantApplicationRunner(AppRunner):
  21. """
  22. Assistant Application Runner
  23. """
  24. def run(self, application_generate_entity: ApplicationGenerateEntity,
  25. queue_manager: ApplicationQueueManager,
  26. conversation: Conversation,
  27. message: Message) -> None:
  28. """
  29. Run assistant application
  30. :param application_generate_entity: application generate entity
  31. :param queue_manager: application queue manager
  32. :param conversation: conversation
  33. :param message: message
  34. :return:
  35. """
  36. app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
  37. if not app_record:
  38. raise ValueError(f"App not found")
  39. app_orchestration_config = application_generate_entity.app_orchestration_config_entity
  40. inputs = application_generate_entity.inputs
  41. query = application_generate_entity.query
  42. files = application_generate_entity.files
  43. # Pre-calculate the number of tokens of the prompt messages,
  44. # and return the rest number of tokens by model context token size limit and max token size limit.
  45. # If the rest number of tokens is not enough, raise exception.
  46. # Include: prompt template, inputs, query(optional), files(optional)
  47. # Not Include: memory, external data, dataset context
  48. self.get_pre_calculate_rest_tokens(
  49. app_record=app_record,
  50. model_config=app_orchestration_config.model_config,
  51. prompt_template_entity=app_orchestration_config.prompt_template,
  52. inputs=inputs,
  53. files=files,
  54. query=query
  55. )
  56. memory = None
  57. if application_generate_entity.conversation_id:
  58. # get memory of conversation (read-only)
  59. model_instance = ModelInstance(
  60. provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
  61. model=app_orchestration_config.model_config.model
  62. )
  63. memory = TokenBufferMemory(
  64. conversation=conversation,
  65. model_instance=model_instance
  66. )
  67. # organize all inputs and template to prompt messages
  68. # Include: prompt template, inputs, query(optional), files(optional)
  69. # memory(optional)
  70. prompt_messages, _ = self.organize_prompt_messages(
  71. app_record=app_record,
  72. model_config=app_orchestration_config.model_config,
  73. prompt_template_entity=app_orchestration_config.prompt_template,
  74. inputs=inputs,
  75. files=files,
  76. query=query,
  77. memory=memory
  78. )
  79. # moderation
  80. try:
  81. # process sensitive_word_avoidance
  82. _, inputs, query = self.moderation_for_inputs(
  83. app_id=app_record.id,
  84. tenant_id=application_generate_entity.tenant_id,
  85. app_orchestration_config_entity=app_orchestration_config,
  86. inputs=inputs,
  87. query=query,
  88. )
  89. except ModerationException as e:
  90. self.direct_output(
  91. queue_manager=queue_manager,
  92. app_orchestration_config=app_orchestration_config,
  93. prompt_messages=prompt_messages,
  94. text=str(e),
  95. stream=application_generate_entity.stream
  96. )
  97. return
  98. if query:
  99. # annotation reply
  100. annotation_reply = self.query_app_annotations_to_reply(
  101. app_record=app_record,
  102. message=message,
  103. query=query,
  104. user_id=application_generate_entity.user_id,
  105. invoke_from=application_generate_entity.invoke_from
  106. )
  107. if annotation_reply:
  108. queue_manager.publish_annotation_reply(
  109. message_annotation_id=annotation_reply.id,
  110. pub_from=PublishFrom.APPLICATION_MANAGER
  111. )
  112. self.direct_output(
  113. queue_manager=queue_manager,
  114. app_orchestration_config=app_orchestration_config,
  115. prompt_messages=prompt_messages,
  116. text=annotation_reply.content,
  117. stream=application_generate_entity.stream
  118. )
  119. return
  120. # fill in variable inputs from external data tools if exists
  121. external_data_tools = app_orchestration_config.external_data_variables
  122. if external_data_tools:
  123. inputs = self.fill_in_inputs_from_external_data_tools(
  124. tenant_id=app_record.tenant_id,
  125. app_id=app_record.id,
  126. external_data_tools=external_data_tools,
  127. inputs=inputs,
  128. query=query
  129. )
  130. # reorganize all inputs and template to prompt messages
  131. # Include: prompt template, inputs, query(optional), files(optional)
  132. # memory(optional), external data, dataset context(optional)
  133. prompt_messages, _ = self.organize_prompt_messages(
  134. app_record=app_record,
  135. model_config=app_orchestration_config.model_config,
  136. prompt_template_entity=app_orchestration_config.prompt_template,
  137. inputs=inputs,
  138. files=files,
  139. query=query,
  140. memory=memory
  141. )
  142. # check hosting moderation
  143. hosting_moderation_result = self.check_hosting_moderation(
  144. application_generate_entity=application_generate_entity,
  145. queue_manager=queue_manager,
  146. prompt_messages=prompt_messages
  147. )
  148. if hosting_moderation_result:
  149. return
  150. agent_entity = app_orchestration_config.agent
  151. # load tool variables
  152. tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id,
  153. user_id=application_generate_entity.user_id,
  154. tanent_id=application_generate_entity.tenant_id)
  155. # convert db variables to tool variables
  156. tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
  157. message_chain = self._init_message_chain(
  158. message=message,
  159. query=query
  160. )
  161. # init model instance
  162. model_instance = ModelInstance(
  163. provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
  164. model=app_orchestration_config.model_config.model
  165. )
  166. prompt_message, _ = self.organize_prompt_messages(
  167. app_record=app_record,
  168. model_config=app_orchestration_config.model_config,
  169. prompt_template_entity=app_orchestration_config.prompt_template,
  170. inputs=inputs,
  171. files=files,
  172. query=query,
  173. memory=memory,
  174. )
  175. # start agent runner
  176. if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
  177. assistant_cot_runner = AssistantCotApplicationRunner(
  178. tenant_id=application_generate_entity.tenant_id,
  179. application_generate_entity=application_generate_entity,
  180. app_orchestration_config=app_orchestration_config,
  181. model_config=app_orchestration_config.model_config,
  182. config=agent_entity,
  183. queue_manager=queue_manager,
  184. message=message,
  185. user_id=application_generate_entity.user_id,
  186. memory=memory,
  187. prompt_messages=prompt_message,
  188. variables_pool=tool_variables,
  189. db_variables=tool_conversation_variables,
  190. )
  191. invoke_result = assistant_cot_runner.run(
  192. model_instance=model_instance,
  193. conversation=conversation,
  194. message=message,
  195. query=query,
  196. )
  197. elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
  198. assistant_fc_runner = AssistantFunctionCallApplicationRunner(
  199. tenant_id=application_generate_entity.tenant_id,
  200. application_generate_entity=application_generate_entity,
  201. app_orchestration_config=app_orchestration_config,
  202. model_config=app_orchestration_config.model_config,
  203. config=agent_entity,
  204. queue_manager=queue_manager,
  205. message=message,
  206. user_id=application_generate_entity.user_id,
  207. memory=memory,
  208. prompt_messages=prompt_message,
  209. variables_pool=tool_variables,
  210. db_variables=tool_conversation_variables
  211. )
  212. invoke_result = assistant_fc_runner.run(
  213. model_instance=model_instance,
  214. conversation=conversation,
  215. message=message,
  216. query=query,
  217. )
  218. # handle invoke result
  219. self._handle_invoke_result(
  220. invoke_result=invoke_result,
  221. queue_manager=queue_manager,
  222. stream=application_generate_entity.stream,
  223. agent=True
  224. )
  225. def _load_tool_variables(self, conversation_id: str, user_id: str, tanent_id: str) -> ToolConversationVariables:
  226. """
  227. load tool variables from database
  228. """
  229. tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter(
  230. ToolConversationVariables.conversation_id == conversation_id,
  231. ToolConversationVariables.tenant_id == tanent_id
  232. ).first()
  233. if tool_variables:
  234. # save tool variables to session, so that we can update it later
  235. db.session.add(tool_variables)
  236. else:
  237. # create new tool variables
  238. tool_variables = ToolConversationVariables(
  239. conversation_id=conversation_id,
  240. user_id=user_id,
  241. tenant_id=tanent_id,
  242. variables_str='[]',
  243. )
  244. db.session.add(tool_variables)
  245. db.session.commit()
  246. return tool_variables
  247. def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool:
  248. """
  249. convert db variables to tool variables
  250. """
  251. return ToolRuntimeVariablePool(**{
  252. 'conversation_id': db_variables.conversation_id,
  253. 'user_id': db_variables.user_id,
  254. 'tenant_id': db_variables.tenant_id,
  255. 'pool': db_variables.variables
  256. })
  257. def _init_message_chain(self, message: Message, query: str) -> MessageChain:
  258. """
  259. Init MessageChain
  260. :param message: message
  261. :param query: query
  262. :return:
  263. """
  264. message_chain = MessageChain(
  265. message_id=message.id,
  266. type="AgentExecutor",
  267. input=json.dumps({
  268. "input": query
  269. })
  270. )
  271. db.session.add(message_chain)
  272. db.session.commit()
  273. return message_chain
  274. def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
  275. """
  276. Save MessageChain
  277. :param message_chain: message chain
  278. :param output_text: output text
  279. :return:
  280. """
  281. message_chain.output = json.dumps({
  282. "output": output_text
  283. })
  284. db.session.commit()
  285. def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
  286. message: Message) -> LLMUsage:
  287. """
  288. Get usage of all agent thoughts
  289. :param model_config: model config
  290. :param message: message
  291. :return:
  292. """
  293. agent_thoughts = (db.session.query(MessageAgentThought)
  294. .filter(MessageAgentThought.message_id == message.id).all())
  295. all_message_tokens = 0
  296. all_answer_tokens = 0
  297. for agent_thought in agent_thoughts:
  298. all_message_tokens += agent_thought.message_tokens
  299. all_answer_tokens += agent_thought.answer_tokens
  300. model_type_instance = model_config.provider_model_bundle.model_type_instance
  301. model_type_instance = cast(LargeLanguageModel, model_type_instance)
  302. return model_type_instance._calc_response_usage(
  303. model_config.model,
  304. model_config.credentials,
  305. all_message_tokens,
  306. all_answer_tokens
  307. )