fc_agent_runner.py

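"""
FunctionCall agent runner.

Drives the iterative function-calling loop: the LLM is invoked with the
available tools, any tool calls it returns are executed through ToolEngine,
and the observations are appended to the prompt until the model stops
requesting tools or the iteration limit is reached.
"""
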
import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Optional, Union

from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
    AssistantPromptMessage,
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
    LLMUsage,
    PromptMessage,
    PromptMessageContent,
    PromptMessageContentType,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message

logger = logging.getLogger(__name__)


class FunctionCallAgentRunner(BaseAgentRunner):
    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        self.query = query
        app_generate_entity = self.application_generate_entity
        app_config = self.app_config

        # convert tools into ModelRuntime Tool format
        tool_instances, prompt_messages_tools = self._init_prompt_tools()

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

        # continue to run until there are no more tool calls
        function_call_state = True
        llm_usage = {"usage": None}
        final_answer = ""

        # get tracing instance
        trace_manager = app_generate_entity.trace_manager
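
        # accumulate token counts and prices from each model call into the
        # shared llm_usage dict (mutated in place across iterations)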
        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict["usage"]:
                final_llm_usage_dict["usage"] = usage
            else:
                llm_usage = final_llm_usage_dict["usage"]
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price
                llm_usage.total_price += usage.total_price

        model_instance = self.model_instance
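
        # each pass organizes the prompt, invokes the model, executes any
        # requested tool calls, and feeds the observations into the next pass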
        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration: remove all tools so the model must answer
                prompt_messages_tools = []

            message_file_ids = []
            agent_thought = self.create_agent_thought(
                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
            )

            # recalc llm max tokens
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)

            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_conf.parameters,
                tools=prompt_messages_tools,
                stop=app_generate_entity.model_conf.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )

            tool_calls: list[tuple[str, str, dict[str, Any]]] = []

            # save full response
            response = ""

            # save tool call names and inputs
            tool_call_names = ""
            tool_call_inputs = ""

            current_llm_usage = None
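
            # streaming models yield chunks that are re-emitted as they arrive;
            # blocking models return a single LLMResult that is wrapped into one
            # synthetic chunk below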
            if self.stream_tool_call:
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish(
                            QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                        )
                        is_first_chunk = False
                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk))
                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps(
                                {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                            )
                        except json.JSONDecodeError:
                            # fall back to the default ensure_ascii=True to avoid encoding errors
                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += chunk.delta.message.content

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            else:
                result: LLMResult = chunks
                # check if there is any tool call
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result))
                    tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps(
                            {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                        )
                    except json.JSONDecodeError:
                        # fall back to the default ensure_ascii=True to avoid encoding errors
                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += result.message.content

                if not result.message.content:
                    result.message.content = ""

                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    ),
                )
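
            # record the assistant turn: structured tool calls if any were
            # requested, otherwise the plain-text response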
            assistant_message = AssistantPromptMessage(content="", tool_calls=[])
            if tool_calls:
                assistant_message.tool_calls = [
                    AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type="function",
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        ),
                    )
                    for tool_call in tool_calls
                ]
            else:
                assistant_message.content = response

            self._current_thoughts.append(assistant_message)

            # save thought
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                tool_invoke_meta=None,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage,
            )
            self.queue_manager.publish(
                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
            )

            final_answer += response + "\n"

            # call tools
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is no tool named {tool_call_name}",
                        "meta": ToolInvokeMeta.error_instance(f"there is no tool named {tool_call_name}").to_dict(),
                    }
                else:
                    # invoke tool
                    tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                        tool=tool_instance,
                        tool_parameters=tool_call_args,
                        user_id=self.user_id,
                        tenant_id=self.tenant_id,
                        message=self.message,
                        invoke_from=self.application_generate_entity.invoke_from,
                        agent_tool_callback=self.agent_callback,
                        trace_manager=trace_manager,
                    )
                    # publish files
                    for message_file_id, save_as in message_files:
                        if save_as:
                            self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)

                        # publish message file
                        self.queue_manager.publish(
                            QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
                        )
                        # add message file ids
                        message_file_ids.append(message_file_id)

                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": tool_invoke_response,
                        "meta": tool_invoke_meta.to_dict(),
                    }

                tool_responses.append(tool_response)
                if tool_response["tool_response"] is not None:
                    self._current_thoughts.append(
                        ToolPromptMessage(
                            content=tool_response["tool_response"],
                            tool_call_id=tool_call_id,
                            name=tool_call_name,
                        )
                    )

            if len(tool_responses) > 0:
                # save agent thought
                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name=None,
                    tool_input=None,
                    thought=None,
                    tool_invoke_meta={
                        tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
                    },
                    observation={
                        tool_response["tool_call_name"]: tool_response["tool_response"]
                        for tool_response in tool_responses
                    },
                    answer=None,
                    messages_ids=message_file_ids,
                )
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

            # update prompt tool
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1
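
        # persist variables saved by tools (via variables_pool) to the database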
        self.update_db_variables(self.variables_pool, self.db_variables_pool)

        # publish end event
        self.queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=final_answer),
                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                    system_fingerprint="",
                )
            ),
            PublishFrom.APPLICATION_MANAGER,
        )

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any blocking tool call in llm result
        """
        if llm_result.message.tool_calls:
            return True
        return False

    def extract_tool_calls(
        self, llm_result_chunk: LLMResultChunk
    ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )

        return tool_calls

    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
        """
        Extract blocking tool calls from llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )

        return tool_calls

    def _init_system_message(
        self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None
    ) -> list[PromptMessage]:
        """
        Initialize system message
        """
        if not prompt_messages and prompt_template:
            return [
                SystemPromptMessage(content=prompt_template),
            ]

        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

        # return an empty list rather than None when there is nothing to add,
        # matching the declared return type
        return prompt_messages or []

    def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize user query
        """
        if self.files:
            prompt_message_contents: list[PromptMessageContent] = []
            prompt_message_contents.append(TextPromptMessageContent(data=query))

            # get image detail config
            image_detail_config = (
                self.application_generate_entity.file_upload_config.image_config.detail
                if (
                    self.application_generate_entity.file_upload_config
                    and self.application_generate_entity.file_upload_config.image_config
                )
                else None
            )
            image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

            for file in self.files:
                prompt_message_contents.append(
                    file_manager.to_prompt_message_content(
                        file,
                        image_detail_config=image_detail_config,
                    )
                )

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        For now, GPT supports both function calling and vision only on the first
        iteration, so image contents are stripped from user messages after the
        first iteration and replaced with text placeholders.
        """
        prompt_messages = deepcopy(prompt_messages)

        for prompt_message in prompt_messages:
            if isinstance(prompt_message, UserPromptMessage):
                if isinstance(prompt_message.content, list):
                    prompt_message.content = "\n".join(
                        [
                            content.data
                            if content.type == PromptMessageContentType.TEXT
                            else "[image]"
                            if content.type == PromptMessageContentType.IMAGE
                            else "[file]"
                            for content in prompt_message.content
                        ]
                    )

        return prompt_messages

    def _organize_prompt_messages(self):
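        """
        Assemble the prompt for the next model call: system message, trimmed
        history, the current user query, and this iteration's thoughts so far.
        """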
        prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
        query_prompt_messages = self._organize_user_query(self.query, [])

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
            history_messages=self.history_prompt_messages,
            memory=self.memory,
        ).get_prompt()

        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
        if len(self._current_thoughts) != 0:
            # strip image contents from user messages after the first iteration
            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)

        return prompt_messages