fc_agent_runner.py

import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Optional, Union

from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
    AssistantPromptMessage,
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
    LLMUsage,
    PromptMessage,
    PromptMessageContent,
    PromptMessageContentType,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message

logger = logging.getLogger(__name__)


class FunctionCallAgentRunner(BaseAgentRunner):
    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        self.query = query
        app_generate_entity = self.application_generate_entity
        app_config = self.app_config

        # convert tools into ModelRuntime Tool format
        tool_instances, prompt_messages_tools = self._init_prompt_tools()

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

        # continue to run until there is not any tool call
        function_call_state = True
        llm_usage = {"usage": None}
        final_answer = ""

        # get tracing instance
        trace_manager = app_generate_entity.trace_manager

        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict["usage"]:
                final_llm_usage_dict["usage"] = usage
            else:
                llm_usage = final_llm_usage_dict["usage"]
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price
                llm_usage.total_price += usage.total_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids = []
            agent_thought = self.create_agent_thought(
                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
            )

            # recalc llm max tokens
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)

            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_conf.parameters,
                tools=prompt_messages_tools,
                stop=app_generate_entity.model_conf.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )

            tool_calls: list[tuple[str, str, dict[str, Any]]] = []

            # save full response
            response = ""

            # save tool call names and inputs
            tool_call_names = ""
            tool_call_inputs = ""

            current_llm_usage = None

            if self.stream_tool_call:
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish(
                            QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                        )
                        is_first_chunk = False

                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk))
                        tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps(
                                {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                            )
                        except json.JSONDecodeError as e:
                            # ensure ascii to avoid encoding error
                            tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += chunk.delta.message.content

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            else:
                result: LLMResult = chunks
                # check if there is any tool call
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result))
                    tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps(
                            {tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
                        )
                    except json.JSONDecodeError as e:
                        # ensure ascii to avoid encoding error
                        tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += result.message.content

                if not result.message.content:
                    result.message.content = ""

                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    ),
                )

            assistant_message = AssistantPromptMessage(content="", tool_calls=[])
            if tool_calls:
                assistant_message.tool_calls = [
                    AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type="function",
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        ),
                    )
                    for tool_call in tool_calls
                ]
            else:
                assistant_message.content = response

            self._current_thoughts.append(assistant_message)

            # save thought
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                tool_invoke_meta=None,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage,
            )
            self.queue_manager.publish(
                QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
            )

            final_answer += response + "\n"

            # call tools
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is not a tool named {tool_call_name}",
                        "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
                    }
                else:
                    # invoke tool
                    tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                        tool=tool_instance,
                        tool_parameters=tool_call_args,
                        user_id=self.user_id,
                        tenant_id=self.tenant_id,
                        message=self.message,
                        invoke_from=self.application_generate_entity.invoke_from,
                        agent_tool_callback=self.agent_callback,
                        trace_manager=trace_manager,
                    )

                    # publish files
                    for message_file_id, save_as in message_files:
                        if save_as:
                            self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)

                        # publish message file
                        self.queue_manager.publish(
                            QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
                        )
                        # add message file ids
                        message_file_ids.append(message_file_id)

                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": tool_invoke_response,
                        "meta": tool_invoke_meta.to_dict(),
                    }

                tool_responses.append(tool_response)
                if tool_response["tool_response"] is not None:
                    self._current_thoughts.append(
                        ToolPromptMessage(
                            content=tool_response["tool_response"],
                            tool_call_id=tool_call_id,
                            name=tool_call_name,
                        )
                    )

            if len(tool_responses) > 0:
                # save agent thought
                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name=None,
                    tool_input=None,
                    thought=None,
                    tool_invoke_meta={
                        tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
                    },
                    observation={
                        tool_response["tool_call_name"]: tool_response["tool_response"]
                        for tool_response in tool_responses
                    },
                    answer=None,
                    messages_ids=message_file_ids,
                )
                self.queue_manager.publish(
                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
                )

            # update prompt tool
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        self.update_db_variables(self.variables_pool, self.db_variables_pool)

        # publish end event
        self.queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=model_instance.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=final_answer),
                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
                    system_fingerprint="",
                )
            ),
            PublishFrom.APPLICATION_MANAGER,
        )

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any blocking tool call in llm result
        """
        if llm_result.message.tool_calls:
            return True
        return False

    def extract_tool_calls(
        self, llm_result_chunk: LLMResultChunk
    ) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )

        return tool_calls

    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
        """
        Extract blocking tool calls from llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != "":
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append(
                (
                    prompt_message.id,
                    prompt_message.function.name,
                    args,
                )
            )

        return tool_calls

    def _init_system_message(
        self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None
    ) -> list[PromptMessage]:
        """
        Initialize system message
        """
        if not prompt_messages and prompt_template:
            return [
                SystemPromptMessage(content=prompt_template),
            ]

        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

        return prompt_messages

    def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize user query
        """
        if self.files:
            prompt_message_contents: list[PromptMessageContent] = []
            prompt_message_contents.append(TextPromptMessageContent(data=query))
            for file_obj in self.files:
                prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        For now, GPT supports both function calling and vision at the first iteration,
        so we remove the image messages from the prompt messages after the first iteration.
        """
        prompt_messages = deepcopy(prompt_messages)

        for prompt_message in prompt_messages:
            if isinstance(prompt_message, UserPromptMessage):
                if isinstance(prompt_message.content, list):
                    prompt_message.content = "\n".join(
                        [
                            content.data
                            if content.type == PromptMessageContentType.TEXT
                            else "[image]"
                            if content.type == PromptMessageContentType.IMAGE
                            else "[file]"
                            for content in prompt_message.content
                        ]
                    )

        return prompt_messages

    def _organize_prompt_messages(self):
        """
        Assemble the prompt for the next model call: system prompt, trimmed history,
        the current user query, and the thoughts accumulated in this run.
        """
        prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
        query_prompt_messages = self._organize_user_query(self.query, [])

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
            history_messages=self.history_prompt_messages,
            memory=self.memory,
        ).get_prompt()

        prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]

        if len(self._current_thoughts) != 0:
            # clear messages after the first iteration
            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)

        return prompt_messages
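

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the runner: how the
# (tool_call_id, tool_call_name, tool_call_args) tuples returned by
# extract_tool_calls() / extract_blocking_tool_calls() are folded into the
# tool_call_names / tool_call_inputs strings that run() persists via
# save_agent_thought(). The sample ids, names, and arguments are hypothetical,
# and the block only runs in an environment where the imports above resolve.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sample_tool_calls: list[tuple[str, str, dict[str, Any]]] = [
        ("call_0", "get_weather", {"city": "Tokyo"}),
        ("call_1", "web_search", {"query": "dify function call agent"}),
    ]
    # names are joined with ";", inputs are serialized as a {name: args} JSON object
    sample_names = ";".join(call[1] for call in sample_tool_calls)
    sample_inputs = json.dumps({call[1]: call[2] for call in sample_tool_calls}, ensure_ascii=False)
    print(sample_names)   # get_weather;web_search
    print(sample_inputs)  # {"get_weather": {"city": "Tokyo"}, "web_search": {"query": "dify function call agent"}}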