# fc_agent_runner.py
import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Optional, Union

from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageContentType,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message

logger = logging.getLogger(__name__)


class FunctionCallAgentRunner(BaseAgentRunner):
    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        self.query = query
        app_generate_entity = self.application_generate_entity
        app_config = self.app_config

        # convert tools into ModelRuntime Tool format
        tool_instances, prompt_messages_tools = self._init_prompt_tools()

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

        # loop until the model makes no more tool calls
        function_call_state = True
        llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
        final_answer = ""

        # get tracing instance
        trace_manager = app_generate_entity.trace_manager

        def increase_usage(final_llm_usage_dict: dict[str, Optional[LLMUsage]], usage: LLMUsage):
            # accumulate token and price counters across iterations
            if not final_llm_usage_dict["usage"]:
                final_llm_usage_dict["usage"] = usage
            else:
                accumulated = final_llm_usage_dict["usage"]
                accumulated.prompt_tokens += usage.prompt_tokens
                accumulated.completion_tokens += usage.completion_tokens
                accumulated.prompt_price += usage.prompt_price
                accumulated.completion_price += usage.completion_price
                accumulated.total_price += usage.total_price

        model_instance = self.model_instance
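
        # Agent loop: each pass sends the conversation plus tool definitions to
        # the model, collects its reply (streamed or blocking), executes any
        # requested tool calls, and feeds the observations back in as
        # ToolPromptMessages until the model stops calling tools or the
        # iteration cap is reached.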
        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids = []
            agent_thought = self.create_agent_thought(
                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
            )

            # recalc llm max tokens
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)

            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_conf.parameters,
                tools=prompt_messages_tools,
                stop=app_generate_entity.model_conf.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )
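
            # stream_tool_call determines whether invoke_llm returned a chunk
            # generator or a single blocking LLMResult; both branches below
            # normalize into the same tool_calls / response / usage state.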
            tool_calls: list[tuple[str, str, dict[str, Any]]] = []

            # save full response
            response = ""

            # save tool call names and inputs
            tool_call_names = ""
            tool_call_inputs = ""

            current_llm_usage = None

            if self.stream_tool_call:
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish(
                            QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                        )
                        is_first_chunk = False

                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk))
                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps(
                                {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                            )
                        except (TypeError, ValueError):
                            # fall back to ensure_ascii=True to avoid encoding errors
                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += chunk.delta.message.content

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            else:
                result: LLMResult = chunks

                # check if there is any tool call
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result))
                    tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps(
                            {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                        )
                    except (TypeError, ValueError):
                        # fall back to ensure_ascii=True to avoid encoding errors
                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += result.message.content

                if not result.message.content:
                    result.message.content = ""

                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    ),
                )
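
            # Record the assistant turn for the next iteration's prompt: either
            # the tool calls it requested or its plain-text answer.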
            assistant_message = AssistantPromptMessage(content="", tool_calls=[])
            if tool_calls:
                assistant_message.tool_calls = [
                    AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type="function",
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        ),
                    )
                    for tool_call in tool_calls
                ]
            else:
                assistant_message.content = response

            self._current_thoughts.append(assistant_message)

            # save thought
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                tool_invoke_meta=None,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage,
            )
            self.queue_manager.publish(
                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
            )

            final_answer += response + "\n"
            # call tools
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is no tool named {tool_call_name}",
                        "meta": ToolInvokeMeta.error_instance(f"there is no tool named {tool_call_name}").to_dict(),
                    }
                else:
                    # invoke tool
                    tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                        tool=tool_instance,
                        tool_parameters=tool_call_args,
                        user_id=self.user_id,
                        tenant_id=self.tenant_id,
                        message=self.message,
                        invoke_from=self.application_generate_entity.invoke_from,
                        agent_tool_callback=self.agent_callback,
                        trace_manager=trace_manager,
                    )

                    # publish files produced by the tool
                    for message_file_id, save_as in message_files:
                        if save_as:
                            self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)

                        # publish message file
                        self.queue_manager.publish(
                            QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
                        )
                        # add message file ids
                        message_file_ids.append(message_file_id)

                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": tool_invoke_response,
                        "meta": tool_invoke_meta.to_dict(),
                    }

                tool_responses.append(tool_response)
                if tool_response["tool_response"] is not None:
                    self._current_thoughts.append(
                        ToolPromptMessage(
                            content=tool_response["tool_response"],
                            tool_call_id=tool_call_id,
                            name=tool_call_name,
                        )
                    )
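
            # Persist all tool observations from this step on the same agent
            # thought, keyed by tool name, and attach any produced files.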
            if tool_responses:
                # save agent thought
                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name=None,
                    tool_input=None,
                    thought=None,
                    tool_invoke_meta={
                        tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
                    },
                    observation={
                        tool_response["tool_call_name"]: tool_response["tool_response"]
                        for tool_response in tool_responses
                    },
                    answer=None,
                    messages_ids=message_file_ids,
                )
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

            # update prompt tool
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        self.update_db_variables(self.variables_pool, self.db_variables_pool)

        # publish end event
        self.queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=final_answer),
                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                    system_fingerprint="",
                )
            ),
            PublishFrom.APPLICATION_MANAGER,
        )

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any tool call in the blocking llm result
        """
        if llm_result.message.tool_calls:
            return True
        return False

    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )

        return tool_calls

    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
        """
        Extract tool calls from the blocking llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )

        return tool_calls
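
    # Illustrative example (values are made up, not from the source): a message
    # whose tool_calls carry id="call_0", name="get_weather",
    # arguments='{"city": "Berlin"}' is extracted by either method above as
    # [("call_0", "get_weather", {"city": "Berlin"})].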

    def _init_system_message(
        self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None
    ) -> list[PromptMessage]:
        """
        Initialize system message
        """
        if not prompt_messages and prompt_template:
            return [
                SystemPromptMessage(content=prompt_template),
            ]

        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

        return prompt_messages or []

    def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize user query
        """
        if self.files:
            prompt_message_contents = [TextPromptMessageContent(data=query)]
            for file_obj in self.files:
                prompt_message_contents.append(file_obj.prompt_message_content)

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        As of now, GPT supports both function calling and vision on the first
        iteration only, so image contents are replaced with text placeholders
        on subsequent iterations.
        """
        prompt_messages = deepcopy(prompt_messages)

        for prompt_message in prompt_messages:
            if isinstance(prompt_message, UserPromptMessage):
                if isinstance(prompt_message.content, list):
                    prompt_message.content = "\n".join(
                        [
                            content.data
                            if content.type == PromptMessageContentType.TEXT
                            else "[image]"
                            if content.type == PromptMessageContentType.IMAGE
                            else "[file]"
                            for content in prompt_message.content
                        ]
                    )

        return prompt_messages
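
    # Illustrative example (made-up content): a user message whose content is
    # [TextPromptMessageContent(data="describe this"), <image content>] is
    # flattened to the plain string "describe this\n[image]".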

    def _organize_prompt_messages(self) -> list[PromptMessage]:
        prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
        query_prompt_messages = self._organize_user_query(self.query, [])

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
            history_messages=self.history_prompt_messages,
            memory=self.memory,
        ).get_prompt()
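
        # The transform is expected to trim history to the model's context
        # window while leaving room for the current query and accumulated
        # thoughts (an assumption based on its name and inputs).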
        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
        if len(self._current_thoughts) != 0:
            # clear image messages after the first iteration
            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
        return prompt_messages