# fc_agent_runner.py

import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Optional, Union

from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageContentType,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message

logger = logging.getLogger(__name__)
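
# High-level flow of FunctionCallAgentRunner.run (a summary of the code below):
#   1. build the prompt (system prompt + history + current user query)
#   2. invoke the LLM with the tool definitions attached
#   3. if the model emits tool calls, run them through ToolEngine and feed the
#      results back as ToolPromptMessage entries, then invoke the model again
#   4. stop once the model answers without tool calls or the iteration budget
#      is exhausted, then publish a QueueMessageEndEvent with the final answer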


class FunctionCallAgentRunner(BaseAgentRunner):
    def run(self, message: Message, query: str, **kwargs: Any
            ) -> Generator[LLMResultChunk, None, None]:
        """
        Run the function-call agent application.
        """
        app_generate_entity = self.application_generate_entity
        app_config = self.app_config

        prompt_template = app_config.prompt_template.simple_prompt_template or ''
        prompt_messages = self.history_prompt_messages
        prompt_messages = self._init_system_message(prompt_template, prompt_messages)
        prompt_messages = self._organize_user_query(query, prompt_messages)

        # convert tools into ModelRuntime Tool format
        tool_instances, prompt_messages_tools = self._init_prompt_tools()

        iteration_step = 1
        # cap tool-using iterations at 5; the +1 grants one final, tool-free
        # round in which the model must produce its answer
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

        # continue to run until there are no further tool calls
        function_call_state = True
        llm_usage = {
            'usage': None
        }
        final_answer = ''
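
        # wrapping the running total in a dict lets the nested helper below
        # mutate it in place without a `nonlocal` declaration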
        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
            if not final_llm_usage_dict['usage']:
                final_llm_usage_dict['usage'] = usage
            else:
                llm_usage = final_llm_usage_dict['usage']
                llm_usage.prompt_tokens += usage.prompt_tokens
                llm_usage.completion_tokens += usage.completion_tokens
                llm_usage.prompt_price += usage.prompt_price
                llm_usage.completion_price += usage.completion_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids = []
            agent_thought = self.create_agent_thought(
                message_id=message.id,
                message='',
                tool_name='',
                tool_input='',
                messages_ids=message_file_ids
            )

            # recalc llm max tokens
            self.recalc_llm_max_tokens(self.model_config, prompt_messages)

            # invoke model
            chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_generate_entity.model_config.parameters,
                tools=prompt_messages_tools,
                stop=app_generate_entity.model_config.stop,
                stream=self.stream_tool_call,
                user=self.user_id,
                callbacks=[],
            )
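
            # invoke_llm returns a chunk generator when stream=True and a single
            # LLMResult when stream=False, hence the Union annotation above and
            # the two branches below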

            tool_calls: list[tuple[str, str, dict[str, Any]]] = []

            # save full response
            response = ''

            # save tool call names and inputs
            tool_call_names = ''
            tool_call_inputs = ''

            current_llm_usage = None

            if self.stream_tool_call:
                is_first_chunk = True
                for chunk in chunks:
                    if is_first_chunk:
                        self.queue_manager.publish(QueueAgentThoughtEvent(
                            agent_thought_id=agent_thought.id
                        ), PublishFrom.APPLICATION_MANAGER)
                        is_first_chunk = False

                    # check if there is any tool call
                    if self.check_tool_calls(chunk):
                        function_call_state = True
                        tool_calls.extend(self.extract_tool_calls(chunk))
                        tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                        try:
                            tool_call_inputs = json.dumps({
                                tool_call[1]: tool_call[2] for tool_call in tool_calls
                            }, ensure_ascii=False)
                        except (TypeError, ValueError):
                            # json.dumps raises TypeError/ValueError (not JSONDecodeError);
                            # fall back to ASCII-escaped output to avoid encoding issues
                            tool_call_inputs = json.dumps({
                                tool_call[1]: tool_call[2] for tool_call in tool_calls
                            })

                    if chunk.delta.message and chunk.delta.message.content:
                        if isinstance(chunk.delta.message.content, list):
                            for content in chunk.delta.message.content:
                                response += content.data
                        else:
                            response += chunk.delta.message.content

                    if chunk.delta.usage:
                        increase_usage(llm_usage, chunk.delta.usage)
                        current_llm_usage = chunk.delta.usage

                    yield chunk
            else:
                result: LLMResult = chunks
                # check if there is any tool call
                if self.check_blocking_tool_calls(result):
                    function_call_state = True
                    tool_calls.extend(self.extract_blocking_tool_calls(result))
                    tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
                    try:
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        }, ensure_ascii=False)
                    except (TypeError, ValueError):
                        # json.dumps raises TypeError/ValueError (not JSONDecodeError);
                        # fall back to ASCII-escaped output to avoid encoding issues
                        tool_call_inputs = json.dumps({
                            tool_call[1]: tool_call[2] for tool_call in tool_calls
                        })

                if result.usage:
                    increase_usage(llm_usage, result.usage)
                    current_llm_usage = result.usage

                if result.message and result.message.content:
                    if isinstance(result.message.content, list):
                        for content in result.message.content:
                            response += content.data
                    else:
                        response += result.message.content

                if not result.message.content:
                    result.message.content = ''

                self.queue_manager.publish(QueueAgentThoughtEvent(
                    agent_thought_id=agent_thought.id
                ), PublishFrom.APPLICATION_MANAGER)

                # re-wrap the blocking result as a single streamed chunk
                yield LLMResultChunk(
                    model=model_instance.model,
                    prompt_messages=result.prompt_messages,
                    system_fingerprint=result.system_fingerprint,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=result.message,
                        usage=result.usage,
                    )
                )
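
            # Both branches above leave `tool_calls`, `response`, and the usage
            # totals populated; record the assistant turn so the model sees its
            # own tool calls (or plain answer) on the next iteration.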
            assistant_message = AssistantPromptMessage(
                content='',
                tool_calls=[]
            )
            if tool_calls:
                assistant_message.tool_calls = [
                    AssistantPromptMessage.ToolCall(
                        id=tool_call[0],
                        type='function',
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tool_call[1],
                            arguments=json.dumps(tool_call[2], ensure_ascii=False)
                        )
                    ) for tool_call in tool_calls
                ]
            else:
                assistant_message.content = response

            prompt_messages.append(assistant_message)

            # save thought
            self.save_agent_thought(
                agent_thought=agent_thought,
                tool_name=tool_call_names,
                tool_input=tool_call_inputs,
                thought=response,
                tool_invoke_meta=None,
                observation=None,
                answer=response,
                messages_ids=[],
                llm_usage=current_llm_usage
            )
            self.queue_manager.publish(QueueAgentThoughtEvent(
                agent_thought_id=agent_thought.id
            ), PublishFrom.APPLICATION_MANAGER)

            final_answer += response + '\n'
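
            # Dispatch each extracted call through ToolEngine; an unknown tool
            # name produces an error response instead of raising, so the model
            # can see the failure and recover on the next iteration.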
            # call tools
            tool_responses = []
            for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                tool_instance = tool_instances.get(tool_call_name)
                if not tool_instance:
                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": f"there is no tool named {tool_call_name}",
                        "meta": ToolInvokeMeta.error_instance(f"there is no tool named {tool_call_name}").to_dict()
                    }
                else:
                    # invoke tool
                    tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
                        tool=tool_instance,
                        tool_parameters=tool_call_args,
                        user_id=self.user_id,
                        tenant_id=self.tenant_id,
                        message=self.message,
                        invoke_from=self.application_generate_entity.invoke_from,
                        agent_tool_callback=self.agent_callback,
                    )

                    # publish files
                    for message_file, save_as in message_files:
                        if save_as:
                            self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)

                        # publish message file
                        self.queue_manager.publish(QueueMessageFileEvent(
                            message_file_id=message_file.id
                        ), PublishFrom.APPLICATION_MANAGER)
                        # add message file ids
                        message_file_ids.append(message_file.id)

                    tool_response = {
                        "tool_call_id": tool_call_id,
                        "tool_call_name": tool_call_name,
                        "tool_response": tool_invoke_response,
                        "meta": tool_invoke_meta.to_dict()
                    }

                tool_responses.append(tool_response)
                prompt_messages = self._organize_assistant_message(
                    tool_call_id=tool_call_id,
                    tool_call_name=tool_call_name,
                    tool_response=tool_response['tool_response'],
                    prompt_messages=prompt_messages,
                )
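
            # Aggregate this round's tool outputs onto the same agent thought
            # row, keyed by tool name, and notify listeners once more.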
            if len(tool_responses) > 0:
                # save agent thought
                self.save_agent_thought(
                    agent_thought=agent_thought,
                    tool_name=None,
                    tool_input=None,
                    thought=None,
                    tool_invoke_meta={
                        tool_response['tool_call_name']: tool_response['meta']
                        for tool_response in tool_responses
                    },
                    observation={
                        tool_response['tool_call_name']: tool_response['tool_response']
                        for tool_response in tool_responses
                    },
                    answer=None,
                    messages_ids=message_file_ids
                )
                self.queue_manager.publish(QueueAgentThoughtEvent(
                    agent_thought_id=agent_thought.id
                ), PublishFrom.APPLICATION_MANAGER)

            # update prompt tool
            for prompt_tool in prompt_messages_tools:
                self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

            prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)

        self.update_db_variables(self.variables_pool, self.db_variables_pool)

        # publish end event
        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=final_answer
            ),
            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
            system_fingerprint=''
        )), PublishFrom.APPLICATION_MANAGER)
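
    # --- Usage sketch (illustrative; not part of the original file) ---
    # The runner is normally constructed and driven by the app layer, so the
    # wiring below is hypothetical. Assuming a fully configured `runner` and a
    # persisted `message` row, streaming the agent's answer might look like:
    #
    #     for chunk in runner.run(message=message, query="What's the weather in Berlin?"):
    #         if chunk.delta.message and chunk.delta.message.content:
    #             print(chunk.delta.message.content, end='')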

    def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
        """
        Check if there is any tool call in the llm result chunk
        """
        if llm_result_chunk.delta.message.tool_calls:
            return True
        return False

    def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
        """
        Check if there is any tool call in the blocking llm result
        """
        if llm_result.message.tool_calls:
            return True
        return False

    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
        """
        Extract tool calls from the llm result chunk

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result_chunk.delta.message.tool_calls:
            tool_calls.append((
                prompt_message.id,
                prompt_message.function.name,
                json.loads(prompt_message.function.arguments),
            ))
        return tool_calls
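
    # Example (hypothetical values): a delta carrying
    #     ToolCall(id='call_abc123',
    #              function=ToolCallFunction(name='get_weather',
    #                                        arguments='{"city": "Berlin"}'))
    # is extracted as [('call_abc123', 'get_weather', {'city': 'Berlin'})].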

    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
        """
        Extract tool calls from the blocking llm result

        Returns:
            List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
        """
        tool_calls = []
        for prompt_message in llm_result.message.tool_calls:
            tool_calls.append((
                prompt_message.id,
                prompt_message.function.name,
                json.loads(prompt_message.function.arguments),
            ))
        return tool_calls

    def _init_system_message(self, prompt_template: str,
                             prompt_messages: Optional[list[PromptMessage]] = None) -> list[PromptMessage]:
        """
        Initialize the system message
        """
        if not prompt_messages and prompt_template:
            return [
                SystemPromptMessage(content=prompt_template),
            ]

        if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
            prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

        # never return None, even when both arguments are empty
        return prompt_messages or []
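
    # Example: _init_system_message('You are a helper', []) returns
    # [SystemPromptMessage(content='You are a helper')]; with a non-empty
    # history, the template is prepended only if no system message leads it.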

    def _organize_user_query(self, query: str,
                             prompt_messages: Optional[list[PromptMessage]] = None) -> list[PromptMessage]:
        """
        Organize the user query
        """
        if self.files:
            prompt_message_contents = [TextPromptMessageContent(data=query)]
            for file_obj in self.files:
                prompt_message_contents.append(file_obj.prompt_message_content)

            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
        else:
            prompt_messages.append(UserPromptMessage(content=query))

        return prompt_messages

    def _organize_assistant_message(self, tool_call_id: Optional[str] = None, tool_call_name: Optional[str] = None,
                                    tool_response: Optional[str] = None,
                                    prompt_messages: Optional[list[PromptMessage]] = None) -> list[PromptMessage]:
        """
        Organize the assistant message
        """
        prompt_messages = deepcopy(prompt_messages)

        if tool_response is not None:
            prompt_messages.append(
                ToolPromptMessage(
                    content=tool_response,
                    tool_call_id=tool_call_id,
                    name=tool_call_name,
                )
            )

        return prompt_messages

    def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        For now, GPT supports function calling together with vision only on the
        first iteration, so image contents are replaced with plain-text
        placeholders after that first round.
        """
        prompt_messages = deepcopy(prompt_messages)

        for prompt_message in prompt_messages:
            if isinstance(prompt_message, UserPromptMessage):
                if isinstance(prompt_message.content, list):
                    prompt_message.content = '\n'.join([
                        content.data if content.type == PromptMessageContentType.TEXT else
                        '[image]' if content.type == PromptMessageContentType.IMAGE else
                        '[file]'
                        for content in prompt_message.content
                    ])

        return prompt_messages
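
    # Example: a UserPromptMessage whose content is
    #     [TextPromptMessageContent(data='describe this'), <image content>]
    # is flattened into the plain string 'describe this\n[image]'.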