assistant_cot_runner.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597
  1. import json
  2. import logging
  3. import re
  4. from typing import Dict, Generator, List, Literal, Union
  5. from core.application_queue_manager import PublishFrom
  6. from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
  7. from core.features.assistant_base_runner import BaseAssistantApplicationRunner
  8. from core.model_manager import ModelInstance
  9. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
  10. from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool,
  11. SystemPromptMessage, UserPromptMessage)
  12. from core.model_runtime.utils.encoders import jsonable_encoder
  13. from core.tools.errors import (ToolInvokeError, ToolNotFoundError, ToolNotSupportedError, ToolParameterValidationError,
  14. ToolProviderCredentialValidationError, ToolProviderNotFoundError)
  15. from models.model import Conversation, Message
  16. class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
  def run(self, conversation: Conversation,
          message: Message,
          query: str,
          inputs: Dict[str, str],
          ) -> Union[Generator, LLMResult]:
      """
      Run Cot agent application

      Drives the chain-of-thought loop: on every step the CoT prompt is rebuilt, the model is
      invoked without streaming, and the reply is parsed into a thought plus an optional action.
      A tool action is executed and its observation is stored on the scratchpad for the next
      step; a "Final Answer" action, or a reply with no parsable action, ends the loop.

      :param conversation: conversation entity (not read by this method)
      :param message: message being answered; its id is attached to each agent thought
      :param query: user query substituted into the prompt
      :param inputs: values for the ``{{key}}`` placeholders of the instruction template
      :return: this method is a generator; it yields one LLMResultChunk per intermediate
          thought that precedes a tool call and a last chunk carrying the final answer
      :raises ValueError: if the model invocation returns a falsy result
      """
      app_orchestration_config = self.app_orchestration_config
      self._repack_app_orchestration_config(app_orchestration_config)

      agent_scratchpad: List[AgentScratchpadUnit] = []

      # check model mode
      if self.app_orchestration_config.model_config.mode == "completion":
          # TODO: stop words
          if 'Observation' not in app_orchestration_config.model_config.stop:
              app_orchestration_config.model_config.stop.append('Observation')

      # override inputs
      inputs = inputs or {}
      instruction = self.app_orchestration_config.prompt_template.simple_prompt_template
      instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)

      iteration_step = 1
      # the configured limit is capped at 5; the extra step runs with the tool list emptied
      max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1

      prompt_messages = self.history_prompt_messages

      # convert tools into ModelRuntime Tool format
      prompt_messages_tools: List[PromptMessageTool] = []
      tool_instances = {}
      for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
          try:
              prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
          except Exception:
              # api tool may be deleted
              continue
          # save tool entity
          tool_instances[tool.tool_name] = tool_entity
          # save prompt tool
          prompt_messages_tools.append(prompt_tool)

      # convert dataset tools into ModelRuntime Tool format
      for dataset_tool in self.dataset_tools:
          prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
          # save prompt tool
          prompt_messages_tools.append(prompt_tool)
          # save tool entity
          tool_instances[dataset_tool.identity.name] = dataset_tool

      function_call_state = True
      # kept in a dict so the nested helper can replace the accumulated value
      llm_usage = {
          'usage': None
      }
      final_answer = ''

      def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
          """Add one invocation's token counts and prices to the running total."""
          if not final_llm_usage_dict['usage']:
              final_llm_usage_dict['usage'] = usage
          else:
              total = final_llm_usage_dict['usage']
              total.prompt_tokens += usage.prompt_tokens
              total.completion_tokens += usage.completion_tokens
              total.prompt_price += usage.prompt_price
              total.completion_price += usage.completion_price

      model_instance = self.model_instance

      while function_call_state and iteration_step <= max_iteration_steps:
          # continue to run until there is not any tool call
          function_call_state = False

          if iteration_step == max_iteration_steps:
              # the last iteration, remove all tools
              prompt_messages_tools = []

          message_file_ids = []

          agent_thought = self.create_agent_thought(
              message_id=message.id,
              message='',
              tool_name='',
              tool_input='',
              messages_ids=message_file_ids
          )

          if iteration_step > 1:
              self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

          # update prompt messages
          prompt_messages = self._organize_cot_prompt_messages(
              mode=app_orchestration_config.model_config.mode,
              prompt_messages=prompt_messages,
              tools=prompt_messages_tools,
              agent_scratchpad=agent_scratchpad,
              agent_prompt_message=app_orchestration_config.agent.prompt,
              instruction=instruction,
              input=query
          )

          # recale llm max tokens
          self.recale_llm_max_tokens(self.model_config, prompt_messages)
          # invoke model
          llm_result: LLMResult = model_instance.invoke_llm(
              prompt_messages=prompt_messages,
              model_parameters=app_orchestration_config.model_config.parameters,
              tools=[],
              stop=app_orchestration_config.model_config.stop,
              stream=False,
              user=self.user_id,
              callbacks=[],
          )

          # check llm result
          if not llm_result:
              raise ValueError("failed to invoke llm")

          # get scratchpad
          scratchpad = self._extract_response_scratchpad(llm_result.message.content)
          agent_scratchpad.append(scratchpad)

          # get llm usage
          if llm_result.usage:
              increase_usage(llm_usage, llm_result.usage)

          # publish agent thought if it's first iteration
          if iteration_step == 1:
              self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

          action = scratchpad.action
          self.save_agent_thought(agent_thought=agent_thought,
                                  tool_name=action.action_name if action else '',
                                  tool_input=action.action_input if action else '',
                                  thought=scratchpad.thought,
                                  observation='',
                                  answer=llm_result.message.content,
                                  messages_ids=[],
                                  llm_usage=llm_result.usage)

          # any action other than "final answer" is treated as a tool call
          is_tool_call = bool(action) and action.action_name.lower() != "final answer"

          if is_tool_call:
              self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

              # stream the intermediate thought if it's not empty
              if scratchpad.thought:
                  yield LLMResultChunk(
                      model=model_instance.model,
                      prompt_messages=prompt_messages,
                      delta=LLMResultChunkDelta(
                          index=0,
                          message=AssistantPromptMessage(
                              content=scratchpad.thought
                          ),
                          usage=llm_result.usage,
                      ),
                      system_fingerprint=''
                  )

          if not action:
              # failed to extract action, return final answer directly
              final_answer = scratchpad.agent_response or ''
          else:
              if not is_tool_call:
                  # action is final answer, return final answer directly
                  try:
                      final_answer = action.action_input if \
                          isinstance(action.action_input, str) else \
                          json.dumps(action.action_input)
                  except (TypeError, ValueError):
                      # json.dumps signals unserializable input with TypeError/ValueError,
                      # never JSONDecodeError; fall back to the plain string form
                      final_answer = f'{action.action_input}'
              else:
                  function_call_state = True

                  # action is tool call, invoke tool
                  tool_call_name = action.action_name
                  tool_call_args = action.action_input
                  tool_instance = tool_instances.get(tool_call_name)
                  if not tool_instance:
                      answer = f"there is not a tool named {tool_call_name}"
                      # feed the error back to the model; otherwise the next prompt would be
                      # built with an empty observation for this step
                      scratchpad.observation = answer
                      self.save_agent_thought(agent_thought=agent_thought,
                                              tool_name='',
                                              tool_input='',
                                              thought=None,
                                              observation=answer,
                                              answer=answer,
                                              messages_ids=[])
                      self.queue_manager.publish_agent_thought(agent_thought,
                                                               PublishFrom.APPLICATION_MANAGER)
                  else:
                      # invoke tool
                      error_response = None
                      try:
                          tool_response = tool_instance.invoke(
                              user_id=self.user_id,
                              tool_parameters=tool_call_args if isinstance(tool_call_args, dict)
                              else json.loads(tool_call_args)
                          )
                          # transform tool response to llm friendly response
                          tool_response = self.transform_tool_invoke_messages(tool_response)
                          # extract binary data from tool invoke message
                          binary_files = self.extract_tool_response_binary(tool_response)
                          # create message file
                          message_files = self.create_message_files(binary_files)
                          # publish files
                          for message_file, save_as in message_files:
                              if save_as:
                                  self.variables_pool.set_file(tool_name=tool_call_name,
                                                               value=message_file.id,
                                                               name=save_as)
                              self.queue_manager.publish_message_file(message_file,
                                                                      PublishFrom.APPLICATION_MANAGER)

                          message_file_ids = [message_file.id for message_file, _ in message_files]
                      except ToolProviderCredentialValidationError:
                          error_response = "Please check your tool provider credentials"
                      except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError):
                          error_response = f"there is not a tool named {tool_call_name}"
                      except ToolParameterValidationError as e:
                          error_response = f"tool parameters validation error: {e}, please check your tool parameters"
                      except ToolInvokeError as e:
                          error_response = f"tool invoke error: {e}"
                      except Exception as e:
                          # any other failure (including unparsable tool arguments) is reported
                          # to the model as an observation instead of aborting the run
                          error_response = f"unknown error: {e}"

                      if error_response:
                          observation = error_response
                      else:
                          observation = self._convert_tool_response_to_str(tool_response)

                      # save scratchpad
                      scratchpad.observation = observation
                      scratchpad.agent_response = llm_result.message.content

                      # save agent thought
                      self.save_agent_thought(
                          agent_thought=agent_thought,
                          tool_name=tool_call_name,
                          tool_input=tool_call_args,
                          thought=None,
                          observation=observation,
                          answer=llm_result.message.content,
                          messages_ids=message_file_ids,
                      )
                      self.queue_manager.publish_agent_thought(agent_thought,
                                                               PublishFrom.APPLICATION_MANAGER)

              # update prompt tool message
              for prompt_tool in prompt_messages_tools:
                  self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

          iteration_step += 1

      yield LLMResultChunk(
          model=model_instance.model,
          prompt_messages=prompt_messages,
          delta=LLMResultChunkDelta(
              index=0,
              message=AssistantPromptMessage(
                  content=final_answer
              ),
              usage=llm_usage['usage']
          ),
          system_fingerprint=''
      )

      # save agent thought
      self.save_agent_thought(
          agent_thought=agent_thought,
          tool_name='',
          tool_input='',
          thought=final_answer,
          observation='',
          answer=final_answer,
          messages_ids=[]
      )

      self.update_db_variables(self.variables_pool, self.db_variables_pool)
      # publish end event
      self.queue_manager.publish_message_end(LLMResult(
          model=model_instance.model,
          prompt_messages=prompt_messages,
          message=AssistantPromptMessage(
              content=final_answer
          ),
          usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
          system_fingerprint=''
      ), PublishFrom.APPLICATION_MANAGER)
  270. def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
  271. """
  272. fill in inputs from external data tools
  273. """
  274. for key, value in inputs.items():
  275. try:
  276. instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
  277. except Exception as e:
  278. continue
  279. return instruction
  def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit:
      """
      extract response from llm response

      Two strategies are tried in order, each scanning from the end of the response backwards:

      1. fenced blocks (text between triple backticks), parsed as JSON and, failing that,
         as plain ``Action: ...`` / ``Action Input: ...`` text;
      2. brace-balanced ``{...}`` spans anywhere in the response, parsed as JSON.

      If neither yields an action, the whole response is returned as the thought with
      ``action=None``.

      :param content: raw text returned by the model
      :return: scratchpad unit holding the cleaned thought and the parsed action, if any
      """
      def clean_thought(text: str) -> str:
          """Strip what is left around an action once the action itself was cut out."""
          # remove fences left empty, e.g. ```json\n\n```
          text = re.sub(r'```(json)*\n*```', '', text, flags=re.DOTALL)
          # remove Action: xxx (up to the end of that line)
          return re.sub(r'Action:.*', '', text, flags=re.IGNORECASE)

      def extra_quotes() -> Union[AgentScratchpadUnit, None]:
          """Look for an action inside fenced blocks, last block first."""
          quotes = re.findall(r'```(.*?)```', content, re.DOTALL)

          # reversed() also visits the first block; the previous index loop stopped before
          # index 0, so a response with a single fenced block was never inspected here
          for quote in reversed(quotes):
              agent_thought = clean_thought(content.replace(quote, ''))
              try:
                  # 1. use json load to parse action
                  action = json.loads(quote)
                  action_name = action.get("action")
                  action_input = action.get("action_input")
                  if action_name and action_input:
                      return AgentScratchpadUnit(
                          agent_response=content,
                          thought=agent_thought,
                          action_str=quote,
                          action=AgentScratchpadUnit.Action(
                              action_name=action_name,
                              action_input=action_input,
                          )
                      )
              except Exception:
                  # 2. use plain text `Action: xxx` to parse action
                  action_name = re.findall(r'action: (.*)', quote, re.IGNORECASE)
                  action_input = re.findall(r'action input: (.*)', quote, re.IGNORECASE)
                  if action_name and action_input:
                      return AgentScratchpadUnit(
                          agent_response=content,
                          thought=agent_thought,
                          action_str=quote,
                          action=AgentScratchpadUnit.Action(
                              action_name=action_name[0],
                              action_input=action_input[0],
                          )
                      )
          return None

      def extra_json() -> Union[AgentScratchpadUnit, None]:
          """Look for an action in top-level brace-balanced spans, last span first."""
          # each entry: (span text, (start index, end index)) of an outermost {...}
          structures, open_positions = [], []
          started_at = 0
          for i, char in enumerate(content):
              if char == '{':
                  open_positions.append(i)
                  if len(open_positions) == 1:
                      started_at = i
              elif char == '}':
                  if not open_positions:
                      # stray closing brace without an opener: skip it instead of
                      # failing with IndexError on an empty stack
                      continue
                  begin = open_positions.pop()
                  if not open_positions:
                      structures.append((content[begin:i + 1], (started_at, i + 1)))

          # an unclosed outermost brace runs to the end of the response
          if open_positions:
              structures.append((content[open_positions[0]:], (started_at, len(content))))

          for json_content, (span_start, span_end) in reversed(structures):
              try:
                  action = json.loads(json_content)
                  action_name = action.get("action")
                  action_input = action.get("action_input")
                  # delete json content from agent response
                  agent_thought = clean_thought(content[:span_start] + content[span_end:])
                  if action_name and action_input is not None:
                      return AgentScratchpadUnit(
                          agent_response=content,
                          thought=agent_thought,
                          action_str=json_content,
                          action=AgentScratchpadUnit.Action(
                              action_name=action_name,
                              action_input=action_input,
                          )
                      )
              except Exception:
                  # not valid JSON, or not a JSON object: try the previous span
                  continue
          return None

      agent_scratchpad = extra_quotes()
      if agent_scratchpad:
          return agent_scratchpad

      agent_scratchpad = extra_json()
      if agent_scratchpad:
          return agent_scratchpad

      # no action found: the whole response is the thought
      return AgentScratchpadUnit(
          agent_response=content,
          thought=content,
          action_str='',
          action=None
      )
  385. def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
  386. agent_prompt_message: AgentPromptEntity,
  387. ):
  388. """
  389. check chain of thought prompt messages, a standard prompt message is like:
  390. Respond to the human as helpfully and accurately as possible.
  391. {{instruction}}
  392. You have access to the following tools:
  393. {{tools}}
  394. Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
  395. Valid action values: "Final Answer" or {{tool_names}}
  396. Provide only ONE action per $JSON_BLOB, as shown:
  397. ```
  398. {
  399. "action": $TOOL_NAME,
  400. "action_input": $ACTION_INPUT
  401. }
  402. ```
  403. """
  404. # parse agent prompt message
  405. first_prompt = agent_prompt_message.first_prompt
  406. next_iteration = agent_prompt_message.next_iteration
  407. if not isinstance(first_prompt, str) or not isinstance(next_iteration, str):
  408. raise ValueError(f"first_prompt or next_iteration is required in CoT agent mode")
  409. # check instruction, tools, and tool_names slots
  410. if not first_prompt.find("{{instruction}}") >= 0:
  411. raise ValueError("{{instruction}} is required in first_prompt")
  412. if not first_prompt.find("{{tools}}") >= 0:
  413. raise ValueError("{{tools}} is required in first_prompt")
  414. if not first_prompt.find("{{tool_names}}") >= 0:
  415. raise ValueError("{{tool_names}} is required in first_prompt")
  416. if mode == "completion":
  417. if not first_prompt.find("{{query}}") >= 0:
  418. raise ValueError("{{query}} is required in first_prompt")
  419. if not first_prompt.find("{{agent_scratchpad}}") >= 0:
  420. raise ValueError("{{agent_scratchpad}} is required in first_prompt")
  421. if mode == "completion":
  422. if not next_iteration.find("{{observation}}") >= 0:
  423. raise ValueError("{{observation}} is required in next_iteration")
  424. def _convert_scratchpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str:
  425. """
  426. convert agent scratchpad list to str
  427. """
  428. next_iteration = self.app_orchestration_config.agent.prompt.next_iteration
  429. result = ''
  430. for scratchpad in agent_scratchpad:
  431. result += scratchpad.thought + next_iteration.replace("{{observation}}", scratchpad.observation or '') + "\n"
  432. return result
  def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
                                    prompt_messages: List[PromptMessage],
                                    tools: List[PromptMessageTool],
                                    agent_scratchpad: List[AgentScratchpadUnit],
                                    agent_prompt_message: AgentPromptEntity,
                                    instruction: str,
                                    input: str,
                                    ) -> List[PromptMessage]:
      """
      Build the prompt messages for one chain-of-thought step.

      The ``first_prompt`` template is rendered by filling its instruction, tools and
      tool_names slots. In "chat" mode the rendered text becomes the system message of a
      copy of ``prompt_messages`` and the latest scratchpad unit is appended as an
      assistant (thought) / user (observation) pair. In "completion" mode the query and
      the whole scratchpad are filled in as well and a single user message is returned.

      :param mode: model mode, "completion" or "chat"
      :param prompt_messages: messages accumulated so far (used in chat mode only)
      :param tools: tools advertised to the model in this step
      :param agent_scratchpad: scratchpad units produced by earlier steps
      :param agent_prompt_message: entity providing the prompt templates
      :param instruction: instruction text for the instruction slot
      :param input: user query for the query slot (completion mode only)
      :return: prompt messages to send to the model
      :raises ValueError: if the templates are invalid or the mode is not supported
      """
      self._check_cot_prompt_messages(mode, agent_prompt_message)

      template = agent_prompt_message.first_prompt
      tools_json = self._jsonify_tool_prompt_messages(tools)
      quoted_tool_names = '"' + '","'.join([tool.name for tool in tools]) + '"'

      # slots shared by both modes, filled in a fixed order
      rendered = template.replace("{{instruction}}", instruction)
      rendered = rendered.replace("{{tools}}", tools_json)
      rendered = rendered.replace("{{tool_names}}", quoted_tool_names)

      if mode == "completion":
          scratchpad_text = self._convert_scratchpad_list_to_str(agent_scratchpad)
          completion_prompt = rendered.replace("{{query}}", input)
          completion_prompt = completion_prompt.replace("{{agent_scratchpad}}", scratchpad_text)
          return [UserPromptMessage(
              content=completion_prompt,
          )]

      if mode != "chat":
          raise ValueError(f"mode {mode} is not supported")

      # chat mode: work on a shallow copy of the list
      messages = prompt_messages.copy()

      # override the first system message in place, or prepend one if there is none
      current_system = next(
          (item for item in messages if isinstance(item, SystemPromptMessage)),
          None,
      )
      if current_system is None:
          messages.insert(0, SystemPromptMessage(
              content=rendered,
          ))
      else:
          current_system.content = rendered

      # append the latest thought / observation pair
      if len(agent_scratchpad) > 0:
          latest = agent_scratchpad[-1]
          messages.append(AssistantPromptMessage(
              content=(latest.thought or '')
          ))
          messages.append(UserPromptMessage(
              content=(latest.observation or ''),
          ))

      return messages
  def _jsonify_tool_prompt_messages(self, tools: list[PromptMessageTool]) -> str:
      """
      jsonify tool prompt messages

      :param tools: prompt tools to describe to the model
      :return: JSON text of the encoded tools, with non-ASCII characters left unescaped
      """
      # json.dumps never raises json.JSONDecodeError (that error belongs to decoding), so the
      # former try/except fallback to ASCII-escaped output could not be reached
      return json.dumps(jsonable_encoder(tools), ensure_ascii=False)