# agent_app_runner.py
  1. import json
  2. import logging
  3. from typing import cast
  4. from core.agent.agent.agent_llm_callback import AgentLLMCallback
  5. from core.app_runner.app_runner import AppRunner
  6. from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
  7. from core.entities.application_entities import ApplicationGenerateEntity, PromptTemplateEntity, ModelConfigEntity
  8. from core.application_queue_manager import ApplicationQueueManager
  9. from core.features.agent_runner import AgentRunnerFeature
  10. from core.memory.token_buffer_memory import TokenBufferMemory
  11. from core.model_manager import ModelInstance
  12. from core.model_runtime.entities.llm_entities import LLMUsage
  13. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  14. from extensions.ext_database import db
  15. from models.model import Conversation, Message, App, MessageChain, MessageAgentThought
  16. logger = logging.getLogger(__name__)
  17. class AgentApplicationRunner(AppRunner):
  18. """
  19. Agent Application Runner
  20. """
  21. def run(self, application_generate_entity: ApplicationGenerateEntity,
  22. queue_manager: ApplicationQueueManager,
  23. conversation: Conversation,
  24. message: Message) -> None:
  25. """
  26. Run agent application
  27. :param application_generate_entity: application generate entity
  28. :param queue_manager: application queue manager
  29. :param conversation: conversation
  30. :param message: message
  31. :return:
  32. """
  33. app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
  34. if not app_record:
  35. raise ValueError(f"App not found")
  36. app_orchestration_config = application_generate_entity.app_orchestration_config_entity
  37. inputs = application_generate_entity.inputs
  38. query = application_generate_entity.query
  39. files = application_generate_entity.files
  40. # Pre-calculate the number of tokens of the prompt messages,
  41. # and return the rest number of tokens by model context token size limit and max token size limit.
  42. # If the rest number of tokens is not enough, raise exception.
  43. # Include: prompt template, inputs, query(optional), files(optional)
  44. # Not Include: memory, external data, dataset context
  45. self.get_pre_calculate_rest_tokens(
  46. app_record=app_record,
  47. model_config=app_orchestration_config.model_config,
  48. prompt_template_entity=app_orchestration_config.prompt_template,
  49. inputs=inputs,
  50. files=files,
  51. query=query
  52. )
  53. memory = None
  54. if application_generate_entity.conversation_id:
  55. # get memory of conversation (read-only)
  56. model_instance = ModelInstance(
  57. provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
  58. model=app_orchestration_config.model_config.model
  59. )
  60. memory = TokenBufferMemory(
  61. conversation=conversation,
  62. model_instance=model_instance
  63. )
  64. # reorganize all inputs and template to prompt messages
  65. # Include: prompt template, inputs, query(optional), files(optional)
  66. # memory(optional)
  67. prompt_messages, stop = self.organize_prompt_messages(
  68. app_record=app_record,
  69. model_config=app_orchestration_config.model_config,
  70. prompt_template_entity=app_orchestration_config.prompt_template,
  71. inputs=inputs,
  72. files=files,
  73. query=query,
  74. context=None,
  75. memory=memory
  76. )
  77. # Create MessageChain
  78. message_chain = self._init_message_chain(
  79. message=message,
  80. query=query
  81. )
  82. # add agent callback to record agent thoughts
  83. agent_callback = AgentLoopGatherCallbackHandler(
  84. model_config=app_orchestration_config.model_config,
  85. message=message,
  86. queue_manager=queue_manager,
  87. message_chain=message_chain
  88. )
  89. # init LLM Callback
  90. agent_llm_callback = AgentLLMCallback(
  91. agent_callback=agent_callback
  92. )
  93. agent_runner = AgentRunnerFeature(
  94. tenant_id=application_generate_entity.tenant_id,
  95. app_orchestration_config=app_orchestration_config,
  96. model_config=app_orchestration_config.model_config,
  97. config=app_orchestration_config.agent,
  98. queue_manager=queue_manager,
  99. message=message,
  100. user_id=application_generate_entity.user_id,
  101. agent_llm_callback=agent_llm_callback,
  102. callback=agent_callback,
  103. memory=memory
  104. )
  105. # agent run
  106. result = agent_runner.run(
  107. query=query,
  108. invoke_from=application_generate_entity.invoke_from
  109. )
  110. if result:
  111. self._save_message_chain(
  112. message_chain=message_chain,
  113. output_text=result
  114. )
  115. if (result
  116. and app_orchestration_config.prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE
  117. and app_orchestration_config.prompt_template.simple_prompt_template
  118. ):
  119. # Direct output if agent result exists and has pre prompt
  120. self.direct_output(
  121. queue_manager=queue_manager,
  122. app_orchestration_config=app_orchestration_config,
  123. prompt_messages=prompt_messages,
  124. stream=application_generate_entity.stream,
  125. text=result,
  126. usage=self._get_usage_of_all_agent_thoughts(
  127. model_config=app_orchestration_config.model_config,
  128. message=message
  129. )
  130. )
  131. else:
  132. # As normal LLM run, agent result as context
  133. context = result
  134. # reorganize all inputs and template to prompt messages
  135. # Include: prompt template, inputs, query(optional), files(optional)
  136. # memory(optional), external data, dataset context(optional)
  137. prompt_messages, stop = self.organize_prompt_messages(
  138. app_record=app_record,
  139. model_config=app_orchestration_config.model_config,
  140. prompt_template_entity=app_orchestration_config.prompt_template,
  141. inputs=inputs,
  142. files=files,
  143. query=query,
  144. context=context,
  145. memory=memory
  146. )
  147. # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
  148. self.recale_llm_max_tokens(
  149. model_config=app_orchestration_config.model_config,
  150. prompt_messages=prompt_messages
  151. )
  152. # Invoke model
  153. model_instance = ModelInstance(
  154. provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
  155. model=app_orchestration_config.model_config.model
  156. )
  157. invoke_result = model_instance.invoke_llm(
  158. prompt_messages=prompt_messages,
  159. model_parameters=app_orchestration_config.model_config.parameters,
  160. stop=stop,
  161. stream=application_generate_entity.stream,
  162. user=application_generate_entity.user_id,
  163. )
  164. # handle invoke result
  165. self._handle_invoke_result(
  166. invoke_result=invoke_result,
  167. queue_manager=queue_manager,
  168. stream=application_generate_entity.stream
  169. )
  170. def _init_message_chain(self, message: Message, query: str) -> MessageChain:
  171. """
  172. Init MessageChain
  173. :param message: message
  174. :param query: query
  175. :return:
  176. """
  177. message_chain = MessageChain(
  178. message_id=message.id,
  179. type="AgentExecutor",
  180. input=json.dumps({
  181. "input": query
  182. })
  183. )
  184. db.session.add(message_chain)
  185. db.session.commit()
  186. return message_chain
  187. def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
  188. """
  189. Save MessageChain
  190. :param message_chain: message chain
  191. :param output_text: output text
  192. :return:
  193. """
  194. message_chain.output = json.dumps({
  195. "output": output_text
  196. })
  197. db.session.commit()
  198. def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
  199. message: Message) -> LLMUsage:
  200. """
  201. Get usage of all agent thoughts
  202. :param model_config: model config
  203. :param message: message
  204. :return:
  205. """
  206. agent_thoughts = (db.session.query(MessageAgentThought)
  207. .filter(MessageAgentThought.message_id == message.id).all())
  208. all_message_tokens = 0
  209. all_answer_tokens = 0
  210. for agent_thought in agent_thoughts:
  211. all_message_tokens += agent_thought.message_token
  212. all_answer_tokens += agent_thought.answer_token
  213. model_type_instance = model_config.provider_model_bundle.model_type_instance
  214. model_type_instance = cast(LargeLanguageModel, model_type_instance)
  215. return model_type_instance._calc_response_usage(
  216. model_config.model,
  217. model_config.credentials,
  218. all_message_tokens,
  219. all_answer_tokens
  220. )