cot_chat_agent_runner.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
import json
from typing import Optional

from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
  11. class CotChatAgentRunner(CotAgentRunner):
  12. def _organize_system_prompt(self) -> SystemPromptMessage:
  13. """
  14. Organize system prompt
  15. """
  16. prompt_entity = self.app_config.agent.prompt
  17. first_prompt = prompt_entity.first_prompt
  18. system_prompt = (
  19. first_prompt.replace("{{instruction}}", self._instruction)
  20. .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
  21. .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
  22. )
  23. return SystemPromptMessage(content=system_prompt)
  24. def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
  25. """
  26. Organize user query
  27. """
  28. if self.files:
  29. prompt_message_contents = [TextPromptMessageContent(data=query)]
  30. for file_obj in self.files:
  31. prompt_message_contents.append(file_obj.prompt_message_content)
  32. prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
  33. else:
  34. prompt_messages.append(UserPromptMessage(content=query))
  35. return prompt_messages
  36. def _organize_prompt_messages(self) -> list[PromptMessage]:
  37. """
  38. Organize
  39. """
  40. # organize system prompt
  41. system_message = self._organize_system_prompt()
  42. # organize current assistant messages
  43. agent_scratchpad = self._agent_scratchpad
  44. if not agent_scratchpad:
  45. assistant_messages = []
  46. else:
  47. assistant_message = AssistantPromptMessage(content="")
  48. for unit in agent_scratchpad:
  49. if unit.is_final():
  50. assistant_message.content += f"Final Answer: {unit.agent_response}"
  51. else:
  52. assistant_message.content += f"Thought: {unit.thought}\n\n"
  53. if unit.action_str:
  54. assistant_message.content += f"Action: {unit.action_str}\n\n"
  55. if unit.observation:
  56. assistant_message.content += f"Observation: {unit.observation}\n\n"
  57. assistant_messages = [assistant_message]
  58. # query messages
  59. query_messages = self._organize_user_query(self._query, [])
  60. if assistant_messages:
  61. # organize historic prompt messages
  62. historic_messages = self._organize_historic_prompt_messages(
  63. [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
  64. )
  65. messages = [
  66. system_message,
  67. *historic_messages,
  68. *query_messages,
  69. *assistant_messages,
  70. UserPromptMessage(content="continue"),
  71. ]
  72. else:
  73. # organize historic prompt messages
  74. historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages])
  75. messages = [system_message, *historic_messages, *query_messages]
  76. # join all messages
  77. return messages