cot_chat_agent_runner.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import json
  2. from core.agent.cot_agent_runner import CotAgentRunner
  3. from core.model_runtime.entities.message_entities import (
  4. AssistantPromptMessage,
  5. PromptMessage,
  6. SystemPromptMessage,
  7. UserPromptMessage,
  8. )
  9. from core.model_runtime.utils.encoders import jsonable_encoder
  10. class CotChatAgentRunner(CotAgentRunner):
  11. def _organize_system_prompt(self) -> SystemPromptMessage:
  12. """
  13. Organize system prompt
  14. """
  15. prompt_entity = self.app_config.agent.prompt
  16. first_prompt = prompt_entity.first_prompt
  17. system_prompt = first_prompt \
  18. .replace("{{instruction}}", self._instruction) \
  19. .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
  20. .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
  21. return SystemPromptMessage(content=system_prompt)
  22. def _organize_prompt_messages(self) -> list[PromptMessage]:
  23. """
  24. Organize
  25. """
  26. # organize system prompt
  27. system_message = self._organize_system_prompt()
  28. # organize current assistant messages
  29. agent_scratchpad = self._agent_scratchpad
  30. if not agent_scratchpad:
  31. assistant_messages = []
  32. else:
  33. assistant_message = AssistantPromptMessage(content='')
  34. for unit in agent_scratchpad:
  35. if unit.is_final():
  36. assistant_message.content += f"Final Answer: {unit.agent_response}"
  37. else:
  38. assistant_message.content += f"Thought: {unit.thought}\n\n"
  39. if unit.action_str:
  40. assistant_message.content += f"Action: {unit.action_str}\n\n"
  41. if unit.observation:
  42. assistant_message.content += f"Observation: {unit.observation}\n\n"
  43. assistant_messages = [assistant_message]
  44. # query messages
  45. query_messages = UserPromptMessage(content=self._query)
  46. if assistant_messages:
  47. # organize historic prompt messages
  48. historic_messages = self._organize_historic_prompt_messages([
  49. system_message,
  50. query_messages,
  51. *assistant_messages,
  52. UserPromptMessage(content='continue')
  53. ])
  54. messages = [
  55. system_message,
  56. *historic_messages,
  57. query_messages,
  58. *assistant_messages,
  59. UserPromptMessage(content='continue')
  60. ]
  61. else:
  62. # organize historic prompt messages
  63. historic_messages = self._organize_historic_prompt_messages([system_message, query_messages])
  64. messages = [system_message, *historic_messages, query_messages]
  65. # join all messages
  66. return messages