fc_agent_runner.py

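"""Function-calling agent runner.

Drives a tool-using agent loop on top of a model's native function-calling
support: on each iteration the LLM is invoked with the available tools, any
requested tool calls are executed through ToolEngine, and the results are fed
back into the next prompt until the model answers without calling a tool (or
the iteration budget runs out).
"""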
import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Optional, Union

from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageContentType,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message

logger = logging.getLogger(__name__)


class FunctionCallAgentRunner(BaseAgentRunner):
    def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
        """
        Run FunctionCall agent application
        """
        self.query = query
        app_generate_entity = self.application_generate_entity
        app_config = self.app_config

        # convert tools into ModelRuntime Tool format
        tool_instances, prompt_messages_tools = self._init_prompt_tools()

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

        # continue to run until there are no more tool calls
        function_call_state = True
        llm_usage = {'usage': None}
        final_answer = ''

        # get tracing instance
        trace_manager = app_generate_entity.trace_manager

        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
            # accumulate token counts and prices across model invocations
            if not final_llm_usage_dict['usage']:
                final_llm_usage_dict['usage'] = usage
            else:
                accumulated = final_llm_usage_dict['usage']
                accumulated.prompt_tokens += usage.prompt_tokens
                accumulated.completion_tokens += usage.completion_tokens
                accumulated.prompt_price += usage.prompt_price
                accumulated.completion_price += usage.completion_price
                accumulated.total_price += usage.total_price
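
        # Agent loop: each pass invokes the model with the current tool set,
        # executes any tool calls it makes, and feeds the results back into
        # the next prompt. The loop ends when the model stops calling tools
        # or the iteration budget is exhausted; on the final allowed pass the
        # tool list is emptied so the model must produce a plain answer.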
        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids = []
            agent_thought = self.create_agent_thought(
                message_id=message.id,
                message='',
                tool_name='',
                tool_input='',
                messages_ids=message_file_ids
            )

            # recalc llm max tokens
            prompt_messages = self._organize_prompt_messages()
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)

            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_conf.parameters,
                tools=prompt_messages_tools,
                stop=app_generate_entity.model_conf.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )

            tool_calls: list[tuple[str, str, dict[str, Any]]] = []

            # save full response
            response = ''

            # save tool call names and inputs
            tool_call_names = ''
            tool_call_inputs = ''

            current_llm_usage = None
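
            # Two consumption paths: when the provider can stream tool calls,
            # chunks are relayed to the caller as they arrive; otherwise the
            # blocking LLMResult is converted into a single synthetic chunk
            # after the fact.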
            if self.stream_tool_call:
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish(QueueAgentThoughtEvent(
                            agent_thought_id=agent_thought.id
                        ), PublishFrom.APPLICATION_MANAGER)
                        is_first_chunk = False

                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk))
                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps({
                                tool_call[1]: tool_call[2] for tool_call in tool_calls
                            }, ensure_ascii=False)
                        except (TypeError, ValueError):
                            # fall back to ascii-escaped output to avoid encoding errors
                            # (json.dumps raises TypeError/ValueError, never JSONDecodeError)
                            tool_call_inputs = json.dumps({
                                tool_call[1]: tool_call[2] for tool_call in tool_calls
                            })

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += chunk.delta.message.content

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            else:
                result: LLMResult = chunks

                # check if there is any tool call
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result))
                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        }, ensure_ascii=False)
                    except (TypeError, ValueError):
                        # fall back to ascii-escaped output to avoid encoding errors
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        })

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += result.message.content

                if not result.message.content:
                    result.message.content = ''

                self.queue_manager.publish(QueueAgentThoughtEvent(
                    agent_thought_id=agent_thought.id
                ), PublishFrom.APPLICATION_MANAGER)

                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    )
                )
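
            # Record what the model did this round: either the set of tool
            # calls it requested (arguments re-serialized to JSON) or its
            # plain-text answer. Appending it to _current_thoughts keeps it
            # in the prompt for the next iteration.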
            assistant_message = AssistantPromptMessage(
                content='',
                tool_calls=[]
            )
            if tool_calls:
                assistant_message.tool_calls = [
                    AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type='function',
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1],
                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        )
                    ) for tool_call in tool_calls
                ]
            else:
                assistant_message.content = response

            self._current_thoughts.append(assistant_message)

            # save thought
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                tool_invoke_meta=None,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage
            )
            self.queue_manager.publish(QueueAgentThoughtEvent(
                agent_thought_id=agent_thought.id
            ), PublishFrom.APPLICATION_MANAGER)

            final_answer += response + '\n'
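
            # Call tools: dispatch each requested tool call. An unknown tool
            # name produces an error observation that is fed back to the model
            # instead of raising, so the agent can correct itself next pass.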
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is no tool named {tool_call_name}",
                        "meta": ToolInvokeMeta.error_instance(f"there is no tool named {tool_call_name}").to_dict()
                    }
                else:
                    # invoke tool
                    tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                        tool=tool_instance,
                        tool_parameters=tool_call_args,
                        user_id=self.user_id,
                        tenant_id=self.tenant_id,
                        message=self.message,
                        invoke_from=self.application_generate_entity.invoke_from,
                        agent_tool_callback=self.agent_callback,
                        trace_manager=trace_manager,
                    )

                    # publish files
                    for message_file_id, save_as in message_files:
                        if save_as:
                            self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)

                        # publish message file
                        self.queue_manager.publish(QueueMessageFileEvent(
                            message_file_id=message_file_id
                        ), PublishFrom.APPLICATION_MANAGER)
                        # add message file ids
                        message_file_ids.append(message_file_id)

                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": tool_invoke_response,
                        "meta": tool_invoke_meta.to_dict()
                    }

                tool_responses.append(tool_response)
                if tool_response['tool_response'] is not None:
                    self._current_thoughts.append(
                        ToolPromptMessage(
                            content=tool_response['tool_response'],
                            tool_call_id=tool_call_id,
                            name=tool_call_name,
                        )
                    )

            if len(tool_responses) > 0:
                # save agent thought
                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name=None,
                    tool_input=None,
                    thought=None,
                    tool_invoke_meta={
                        tool_response['tool_call_name']: tool_response['meta']
                        for tool_response in tool_responses
                    },
                    observation={
                        tool_response['tool_call_name']: tool_response['tool_response']
                        for tool_response in tool_responses
                    },
                    answer=None,
                    messages_ids=message_file_ids
                )
                self.queue_manager.publish(QueueAgentThoughtEvent(
                    agent_thought_id=agent_thought.id
                ), PublishFrom.APPLICATION_MANAGER)

            # update prompt tool
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        self.update_db_variables(self.variables_pool, self.db_variables_pool)

        # publish end event
        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=final_answer
            ),
            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
            system_fingerprint=''
        )), PublishFrom.APPLICATION_MANAGER)

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any blocking tool call in llm result
        """
        if llm_result.message.tool_calls:
            return True
        return False

    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
        """
        Extract tool calls from llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != '':
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append((
                prompt_message.id,
                prompt_message.function.name,
                args,
            ))

        return tool_calls

    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
        """
        Extract blocking tool calls from llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            args = {}
            if prompt_message.function.arguments != '':
                args = json.loads(prompt_message.function.arguments)

            tool_calls.append((
                prompt_message.id,
                prompt_message.function.name,
                args,
            ))

        return tool_calls

    def _init_system_message(self, prompt_template: str, prompt_messages: Optional[list[PromptMessage]] = None) -> list[PromptMessage]:
        """
        Initialize system message
        """
        if not prompt_messages and prompt_template:
            return [
                SystemPromptMessage(content=prompt_template),
            ]

        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

        # may be empty when neither a template nor prior messages exist
        return prompt_messages or []

    def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize user query
        """
        if self.files:
            prompt_message_contents = [TextPromptMessageContent(data=query)]
            for file_obj in self.files:
                prompt_message_contents.append(file_obj.prompt_message_content)

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        As of now, GPT supports both function calling and vision only on the first iteration,
        so image contents are replaced with text placeholders on subsequent iterations.
        """
        prompt_messages = deepcopy(prompt_messages)

        for prompt_message in prompt_messages:
            if isinstance(prompt_message, UserPromptMessage):
                if isinstance(prompt_message.content, list):
                    prompt_message.content = '\n'.join([
                        content.data if content.type == PromptMessageContentType.TEXT else
                        '[image]' if content.type == PromptMessageContentType.IMAGE else
                        '[file]'
                        for content in prompt_message.content
                    ])

        return prompt_messages
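
    # Prompt assembly for one iteration: system prompt, then conversation
    # history (passed through AgentHistoryPromptTransform, presumably to fit
    # the model's context budget), then the current user query, then the
    # tool-call thoughts accumulated in this run.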
    def _organize_prompt_messages(self):
        prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
        self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
        query_prompt_messages = self._organize_user_query(self.query, [])

        self.history_prompt_messages = AgentHistoryPromptTransform(
            model_config=self.model_config,
            prompt_messages=[*query_prompt_messages, *self._current_thoughts],
            history_messages=self.history_prompt_messages,
            memory=self.memory
        ).get_prompt()

        prompt_messages = [
            *self.history_prompt_messages,
            *query_prompt_messages,
            *self._current_thoughts
        ]
        if len(self._current_thoughts) != 0:
            # clear image messages after the first iteration
            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)

        return prompt_messages