assistant_cot_runner.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642
  1. import json
  2. import re
  3. from collections.abc import Generator
  4. from typing import Literal, Union
  5. from core.application_queue_manager import PublishFrom
  6. from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
  7. from core.features.assistant_base_runner import BaseAssistantApplicationRunner
  8. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
  9. from core.model_runtime.entities.message_entities import (
  10. AssistantPromptMessage,
  11. PromptMessage,
  12. PromptMessageTool,
  13. SystemPromptMessage,
  14. ToolPromptMessage,
  15. UserPromptMessage,
  16. )
  17. from core.model_runtime.utils.encoders import jsonable_encoder
  18. from core.tools.errors import (
  19. ToolInvokeError,
  20. ToolNotFoundError,
  21. ToolNotSupportedError,
  22. ToolParameterValidationError,
  23. ToolProviderCredentialValidationError,
  24. ToolProviderNotFoundError,
  25. )
  26. from models.model import Conversation, Message
  27. class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
  28. def run(self, conversation: Conversation,
  29. message: Message,
  30. query: str,
  31. inputs: dict[str, str],
  32. ) -> Union[Generator, LLMResult]:
  33. """
  34. Run Cot agent application
  35. """
  36. app_orchestration_config = self.app_orchestration_config
  37. self._repack_app_orchestration_config(app_orchestration_config)
  38. agent_scratchpad: list[AgentScratchpadUnit] = []
  39. self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)
  40. # check model mode
  41. if self.app_orchestration_config.model_config.mode == "completion":
  42. # TODO: stop words
  43. if 'Observation' not in app_orchestration_config.model_config.stop:
  44. app_orchestration_config.model_config.stop.append('Observation')
  45. # override inputs
  46. inputs = inputs or {}
  47. instruction = self.app_orchestration_config.prompt_template.simple_prompt_template
  48. instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
  49. iteration_step = 1
  50. max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1
  51. prompt_messages = self.history_prompt_messages
  52. # convert tools into ModelRuntime Tool format
  53. prompt_messages_tools: list[PromptMessageTool] = []
  54. tool_instances = {}
  55. for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
  56. try:
  57. prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
  58. except Exception:
  59. # api tool may be deleted
  60. continue
  61. # save tool entity
  62. tool_instances[tool.tool_name] = tool_entity
  63. # save prompt tool
  64. prompt_messages_tools.append(prompt_tool)
  65. # convert dataset tools into ModelRuntime Tool format
  66. for dataset_tool in self.dataset_tools:
  67. prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
  68. # save prompt tool
  69. prompt_messages_tools.append(prompt_tool)
  70. # save tool entity
  71. tool_instances[dataset_tool.identity.name] = dataset_tool
  72. function_call_state = True
  73. llm_usage = {
  74. 'usage': None
  75. }
  76. final_answer = ''
  77. def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
  78. if not final_llm_usage_dict['usage']:
  79. final_llm_usage_dict['usage'] = usage
  80. else:
  81. llm_usage = final_llm_usage_dict['usage']
  82. llm_usage.prompt_tokens += usage.prompt_tokens
  83. llm_usage.completion_tokens += usage.completion_tokens
  84. llm_usage.prompt_price += usage.prompt_price
  85. llm_usage.completion_price += usage.completion_price
  86. model_instance = self.model_instance
  87. while function_call_state and iteration_step <= max_iteration_steps:
  88. # continue to run until there is not any tool call
  89. function_call_state = False
  90. if iteration_step == max_iteration_steps:
  91. # the last iteration, remove all tools
  92. prompt_messages_tools = []
  93. message_file_ids = []
  94. agent_thought = self.create_agent_thought(
  95. message_id=message.id,
  96. message='',
  97. tool_name='',
  98. tool_input='',
  99. messages_ids=message_file_ids
  100. )
  101. if iteration_step > 1:
  102. self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
  103. # update prompt messages
  104. prompt_messages = self._organize_cot_prompt_messages(
  105. mode=app_orchestration_config.model_config.mode,
  106. prompt_messages=prompt_messages,
  107. tools=prompt_messages_tools,
  108. agent_scratchpad=agent_scratchpad,
  109. agent_prompt_message=app_orchestration_config.agent.prompt,
  110. instruction=instruction,
  111. input=query
  112. )
  113. # recale llm max tokens
  114. self.recale_llm_max_tokens(self.model_config, prompt_messages)
  115. # invoke model
  116. llm_result: LLMResult = model_instance.invoke_llm(
  117. prompt_messages=prompt_messages,
  118. model_parameters=app_orchestration_config.model_config.parameters,
  119. tools=[],
  120. stop=app_orchestration_config.model_config.stop,
  121. stream=False,
  122. user=self.user_id,
  123. callbacks=[],
  124. )
  125. # check llm result
  126. if not llm_result:
  127. raise ValueError("failed to invoke llm")
  128. # get scratchpad
  129. scratchpad = self._extract_response_scratchpad(llm_result.message.content)
  130. agent_scratchpad.append(scratchpad)
  131. # get llm usage
  132. if llm_result.usage:
  133. increase_usage(llm_usage, llm_result.usage)
  134. # publish agent thought if it's first iteration
  135. if iteration_step == 1:
  136. self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
  137. self.save_agent_thought(agent_thought=agent_thought,
  138. tool_name=scratchpad.action.action_name if scratchpad.action else '',
  139. tool_input=scratchpad.action.action_input if scratchpad.action else '',
  140. thought=scratchpad.thought,
  141. observation='',
  142. answer=llm_result.message.content,
  143. messages_ids=[],
  144. llm_usage=llm_result.usage)
  145. if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
  146. self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
  147. # publish agent thought if it's not empty and there is a action
  148. if scratchpad.thought and scratchpad.action:
  149. # check if final answer
  150. if not scratchpad.action.action_name.lower() == "final answer":
  151. yield LLMResultChunk(
  152. model=model_instance.model,
  153. prompt_messages=prompt_messages,
  154. delta=LLMResultChunkDelta(
  155. index=0,
  156. message=AssistantPromptMessage(
  157. content=scratchpad.thought
  158. ),
  159. usage=llm_result.usage,
  160. ),
  161. system_fingerprint=''
  162. )
  163. if not scratchpad.action:
  164. # failed to extract action, return final answer directly
  165. final_answer = scratchpad.agent_response or ''
  166. else:
  167. if scratchpad.action.action_name.lower() == "final answer":
  168. # action is final answer, return final answer directly
  169. try:
  170. final_answer = scratchpad.action.action_input if \
  171. isinstance(scratchpad.action.action_input, str) else \
  172. json.dumps(scratchpad.action.action_input)
  173. except json.JSONDecodeError:
  174. final_answer = f'{scratchpad.action.action_input}'
  175. else:
  176. function_call_state = True
  177. # action is tool call, invoke tool
  178. tool_call_name = scratchpad.action.action_name
  179. tool_call_args = scratchpad.action.action_input
  180. tool_instance = tool_instances.get(tool_call_name)
  181. if not tool_instance:
  182. answer = f"there is not a tool named {tool_call_name}"
  183. self.save_agent_thought(agent_thought=agent_thought,
  184. tool_name='',
  185. tool_input='',
  186. thought=None,
  187. observation=answer,
  188. answer=answer,
  189. messages_ids=[])
  190. self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
  191. else:
  192. # invoke tool
  193. error_response = None
  194. try:
  195. tool_response = tool_instance.invoke(
  196. user_id=self.user_id,
  197. tool_parameters=tool_call_args if isinstance(tool_call_args, dict) else json.loads(tool_call_args)
  198. )
  199. # transform tool response to llm friendly response
  200. tool_response = self.transform_tool_invoke_messages(tool_response)
  201. # extract binary data from tool invoke message
  202. binary_files = self.extract_tool_response_binary(tool_response)
  203. # create message file
  204. message_files = self.create_message_files(binary_files)
  205. # publish files
  206. for message_file, save_as in message_files:
  207. if save_as:
  208. self.variables_pool.set_file(tool_name=tool_call_name,
  209. value=message_file.id,
  210. name=save_as)
  211. self.queue_manager.publish_message_file(message_file, PublishFrom.APPLICATION_MANAGER)
  212. message_file_ids = [message_file.id for message_file, _ in message_files]
  213. except ToolProviderCredentialValidationError as e:
  214. error_response = "Please check your tool provider credentials"
  215. except (
  216. ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
  217. ) as e:
  218. error_response = f"there is not a tool named {tool_call_name}"
  219. except (
  220. ToolParameterValidationError
  221. ) as e:
  222. error_response = f"tool parameters validation error: {e}, please check your tool parameters"
  223. except ToolInvokeError as e:
  224. error_response = f"tool invoke error: {e}"
  225. except Exception as e:
  226. error_response = f"unknown error: {e}"
  227. if error_response:
  228. observation = error_response
  229. else:
  230. observation = self._convert_tool_response_to_str(tool_response)
  231. # save scratchpad
  232. scratchpad.observation = observation
  233. scratchpad.agent_response = llm_result.message.content
  234. # save agent thought
  235. self.save_agent_thought(
  236. agent_thought=agent_thought,
  237. tool_name=tool_call_name,
  238. tool_input=tool_call_args,
  239. thought=None,
  240. observation=observation,
  241. answer=llm_result.message.content,
  242. messages_ids=message_file_ids,
  243. )
  244. self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
  245. # update prompt tool message
  246. for prompt_tool in prompt_messages_tools:
  247. self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
  248. iteration_step += 1
  249. yield LLMResultChunk(
  250. model=model_instance.model,
  251. prompt_messages=prompt_messages,
  252. delta=LLMResultChunkDelta(
  253. index=0,
  254. message=AssistantPromptMessage(
  255. content=final_answer
  256. ),
  257. usage=llm_usage['usage']
  258. ),
  259. system_fingerprint=''
  260. )
  261. # save agent thought
  262. self.save_agent_thought(
  263. agent_thought=agent_thought,
  264. tool_name='',
  265. tool_input='',
  266. thought=final_answer,
  267. observation='',
  268. answer=final_answer,
  269. messages_ids=[]
  270. )
  271. self.update_db_variables(self.variables_pool, self.db_variables_pool)
  272. # publish end event
  273. self.queue_manager.publish_message_end(LLMResult(
  274. model=model_instance.model,
  275. prompt_messages=prompt_messages,
  276. message=AssistantPromptMessage(
  277. content=final_answer
  278. ),
  279. usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
  280. system_fingerprint=''
  281. ), PublishFrom.APPLICATION_MANAGER)
  282. def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
  283. """
  284. fill in inputs from external data tools
  285. """
  286. for key, value in inputs.items():
  287. try:
  288. instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
  289. except Exception as e:
  290. continue
  291. return instruction
  292. def _init_agent_scratchpad(self,
  293. agent_scratchpad: list[AgentScratchpadUnit],
  294. messages: list[PromptMessage]
  295. ) -> list[AgentScratchpadUnit]:
  296. """
  297. init agent scratchpad
  298. """
  299. current_scratchpad: AgentScratchpadUnit = None
  300. for message in messages:
  301. if isinstance(message, AssistantPromptMessage):
  302. current_scratchpad = AgentScratchpadUnit(
  303. agent_response=message.content,
  304. thought=message.content,
  305. action_str='',
  306. action=None,
  307. observation=None
  308. )
  309. if message.tool_calls:
  310. try:
  311. current_scratchpad.action = AgentScratchpadUnit.Action(
  312. action_name=message.tool_calls[0].function.name,
  313. action_input=json.loads(message.tool_calls[0].function.arguments)
  314. )
  315. except:
  316. pass
  317. agent_scratchpad.append(current_scratchpad)
  318. elif isinstance(message, ToolPromptMessage):
  319. if current_scratchpad:
  320. current_scratchpad.observation = message.content
  321. return agent_scratchpad
  322. def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
  323. """
  324. extract response from llm response
  325. """
  326. def extra_quotes() -> AgentScratchpadUnit:
  327. agent_response = content
  328. # try to extract all quotes
  329. pattern = re.compile(r'```(.*?)```', re.DOTALL)
  330. quotes = pattern.findall(content)
  331. # try to extract action from end to start
  332. for i in range(len(quotes) - 1, 0, -1):
  333. """
  334. 1. use json load to parse action
  335. 2. use plain text `Action: xxx` to parse action
  336. """
  337. try:
  338. action = json.loads(quotes[i].replace('```', ''))
  339. action_name = action.get("action")
  340. action_input = action.get("action_input")
  341. agent_thought = agent_response.replace(quotes[i], '')
  342. if action_name and action_input:
  343. return AgentScratchpadUnit(
  344. agent_response=content,
  345. thought=agent_thought,
  346. action_str=quotes[i],
  347. action=AgentScratchpadUnit.Action(
  348. action_name=action_name,
  349. action_input=action_input,
  350. )
  351. )
  352. except:
  353. # try to parse action from plain text
  354. action_name = re.findall(r'action: (.*)', quotes[i], re.IGNORECASE)
  355. action_input = re.findall(r'action input: (.*)', quotes[i], re.IGNORECASE)
  356. # delete action from agent response
  357. agent_thought = agent_response.replace(quotes[i], '')
  358. # remove extra quotes
  359. agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
  360. # remove Action: xxx from agent thought
  361. agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
  362. if action_name and action_input:
  363. return AgentScratchpadUnit(
  364. agent_response=content,
  365. thought=agent_thought,
  366. action_str=quotes[i],
  367. action=AgentScratchpadUnit.Action(
  368. action_name=action_name[0],
  369. action_input=action_input[0],
  370. )
  371. )
  372. def extra_json():
  373. agent_response = content
  374. # try to extract all json
  375. structures, pair_match_stack = [], []
  376. started_at, end_at = 0, 0
  377. for i in range(len(content)):
  378. if content[i] == '{':
  379. pair_match_stack.append(i)
  380. if len(pair_match_stack) == 1:
  381. started_at = i
  382. elif content[i] == '}':
  383. begin = pair_match_stack.pop()
  384. if not pair_match_stack:
  385. end_at = i + 1
  386. structures.append((content[begin:i+1], (started_at, end_at)))
  387. # handle the last character
  388. if pair_match_stack:
  389. end_at = len(content)
  390. structures.append((content[pair_match_stack[0]:], (started_at, end_at)))
  391. for i in range(len(structures), 0, -1):
  392. try:
  393. json_content, (started_at, end_at) = structures[i - 1]
  394. action = json.loads(json_content)
  395. action_name = action.get("action")
  396. action_input = action.get("action_input")
  397. # delete json content from agent response
  398. agent_thought = agent_response[:started_at] + agent_response[end_at:]
  399. # remove extra quotes like ```(json)*\n\n```
  400. agent_thought = re.sub(r'```(json)*\n*```', '', agent_thought, flags=re.DOTALL)
  401. # remove Action: xxx from agent thought
  402. agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE)
  403. if action_name and action_input is not None:
  404. return AgentScratchpadUnit(
  405. agent_response=content,
  406. thought=agent_thought,
  407. action_str=json_content,
  408. action=AgentScratchpadUnit.Action(
  409. action_name=action_name,
  410. action_input=action_input,
  411. )
  412. )
  413. except:
  414. pass
  415. agent_scratchpad = extra_quotes()
  416. if agent_scratchpad:
  417. return agent_scratchpad
  418. agent_scratchpad = extra_json()
  419. if agent_scratchpad:
  420. return agent_scratchpad
  421. return AgentScratchpadUnit(
  422. agent_response=content,
  423. thought=content,
  424. action_str='',
  425. action=None
  426. )
  427. def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
  428. agent_prompt_message: AgentPromptEntity,
  429. ):
  430. """
  431. check chain of thought prompt messages, a standard prompt message is like:
  432. Respond to the human as helpfully and accurately as possible.
  433. {{instruction}}
  434. You have access to the following tools:
  435. {{tools}}
  436. Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
  437. Valid action values: "Final Answer" or {{tool_names}}
  438. Provide only ONE action per $JSON_BLOB, as shown:
  439. ```
  440. {
  441. "action": $TOOL_NAME,
  442. "action_input": $ACTION_INPUT
  443. }
  444. ```
  445. """
  446. # parse agent prompt message
  447. first_prompt = agent_prompt_message.first_prompt
  448. next_iteration = agent_prompt_message.next_iteration
  449. if not isinstance(first_prompt, str) or not isinstance(next_iteration, str):
  450. raise ValueError("first_prompt or next_iteration is required in CoT agent mode")
  451. # check instruction, tools, and tool_names slots
  452. if not first_prompt.find("{{instruction}}") >= 0:
  453. raise ValueError("{{instruction}} is required in first_prompt")
  454. if not first_prompt.find("{{tools}}") >= 0:
  455. raise ValueError("{{tools}} is required in first_prompt")
  456. if not first_prompt.find("{{tool_names}}") >= 0:
  457. raise ValueError("{{tool_names}} is required in first_prompt")
  458. if mode == "completion":
  459. if not first_prompt.find("{{query}}") >= 0:
  460. raise ValueError("{{query}} is required in first_prompt")
  461. if not first_prompt.find("{{agent_scratchpad}}") >= 0:
  462. raise ValueError("{{agent_scratchpad}} is required in first_prompt")
  463. if mode == "completion":
  464. if not next_iteration.find("{{observation}}") >= 0:
  465. raise ValueError("{{observation}} is required in next_iteration")
  466. def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
  467. """
  468. convert agent scratchpad list to str
  469. """
  470. next_iteration = self.app_orchestration_config.agent.prompt.next_iteration
  471. result = ''
  472. for scratchpad in agent_scratchpad:
  473. result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation or '') + "\n"
  474. return result
  475. def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
  476. prompt_messages: list[PromptMessage],
  477. tools: list[PromptMessageTool],
  478. agent_scratchpad: list[AgentScratchpadUnit],
  479. agent_prompt_message: AgentPromptEntity,
  480. instruction: str,
  481. input: str,
  482. ) -> list[PromptMessage]:
  483. """
  484. organize chain of thought prompt messages, a standard prompt message is like:
  485. Respond to the human as helpfully and accurately as possible.
  486. {{instruction}}
  487. You have access to the following tools:
  488. {{tools}}
  489. Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
  490. Valid action values: "Final Answer" or {{tool_names}}
  491. Provide only ONE action per $JSON_BLOB, as shown:
  492. ```
  493. {{{{
  494. "action": $TOOL_NAME,
  495. "action_input": $ACTION_INPUT
  496. }}}}
  497. ```
  498. """
  499. self._check_cot_prompt_messages(mode, agent_prompt_message)
  500. # parse agent prompt message
  501. first_prompt = agent_prompt_message.first_prompt
  502. # parse tools
  503. tools_str = self._jsonify_tool_prompt_messages(tools)
  504. # parse tools name
  505. tool_names = '"' + '","'.join([tool.name for tool in tools]) + '"'
  506. # get system message
  507. system_message = first_prompt.replace("{{instruction}}", instruction) \
  508. .replace("{{tools}}", tools_str) \
  509. .replace("{{tool_names}}", tool_names)
  510. # organize prompt messages
  511. if mode == "chat":
  512. # override system message
  513. overrided = False
  514. prompt_messages = prompt_messages.copy()
  515. for prompt_message in prompt_messages:
  516. if isinstance(prompt_message, SystemPromptMessage):
  517. prompt_message.content = system_message
  518. overrided = True
  519. break
  520. if not overrided:
  521. prompt_messages.insert(0, SystemPromptMessage(
  522. content=system_message,
  523. ))
  524. # add assistant message
  525. if len(agent_scratchpad) > 0:
  526. prompt_messages.append(AssistantPromptMessage(
  527. content=(agent_scratchpad[-1].thought or '')
  528. ))
  529. # add user message
  530. if len(agent_scratchpad) > 0:
  531. prompt_messages.append(UserPromptMessage(
  532. content=(agent_scratchpad[-1].observation or ''),
  533. ))
  534. return prompt_messages
  535. elif mode == "completion":
  536. # parse agent scratchpad
  537. agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
  538. # parse prompt messages
  539. return [UserPromptMessage(
  540. content=first_prompt.replace("{{instruction}}", instruction)
  541. .replace("{{tools}}", tools_str)
  542. .replace("{{tool_names}}", tool_names)
  543. .replace("{{query}}", input)
  544. .replace("{{agent_scratchpad}}", agent_scratchpad_str),
  545. )]
  546. else:
  547. raise ValueError(f"mode {mode} is not supported")
  548. def _jsonify_tool_prompt_messages(self, tools: list[PromptMessageTool]) -> str:
  549. """
  550. jsonify tool prompt messages
  551. """
  552. tools = jsonable_encoder(tools)
  553. try:
  554. return json.dumps(tools, ensure_ascii=False)
  555. except json.JSONDecodeError:
  556. return json.dumps(tools)