# cot_chat_agent_runner.py
import json
from typing import Optional

from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
  11. class CotChatAgentRunner(CotAgentRunner):
  12. def _organize_system_prompt(self) -> SystemPromptMessage:
  13. """
  14. Organize system prompt
  15. """
  16. prompt_entity = self.app_config.agent.prompt
  17. first_prompt = prompt_entity.first_prompt
  18. system_prompt = first_prompt \
  19. .replace("{{instruction}}", self._instruction) \
  20. .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
  21. .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
  22. return SystemPromptMessage(content=system_prompt)
  23. def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
  24. """
  25. Organize user query
  26. """
  27. if self.files:
  28. prompt_message_contents = [TextPromptMessageContent(data=query)]
  29. for file_obj in self.files:
  30. prompt_message_contents.append(file_obj.prompt_message_content)
  31. prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
  32. else:
  33. prompt_messages.append(UserPromptMessage(content=query))
  34. return prompt_messages
  35. def _organize_prompt_messages(self) -> list[PromptMessage]:
  36. """
  37. Organize
  38. """
  39. # organize system prompt
  40. system_message = self._organize_system_prompt()
  41. # organize current assistant messages
  42. agent_scratchpad = self._agent_scratchpad
  43. if not agent_scratchpad:
  44. assistant_messages = []
  45. else:
  46. assistant_message = AssistantPromptMessage(content='')
  47. for unit in agent_scratchpad:
  48. if unit.is_final():
  49. assistant_message.content += f"Final Answer: {unit.agent_response}"
  50. else:
  51. assistant_message.content += f"Thought: {unit.thought}\n\n"
  52. if unit.action_str:
  53. assistant_message.content += f"Action: {unit.action_str}\n\n"
  54. if unit.observation:
  55. assistant_message.content += f"Observation: {unit.observation}\n\n"
  56. assistant_messages = [assistant_message]
  57. # query messages
  58. query_messages = self._organize_user_query(self._query, [])
  59. if assistant_messages:
  60. # organize historic prompt messages
  61. historic_messages = self._organize_historic_prompt_messages([
  62. system_message,
  63. *query_messages,
  64. *assistant_messages,
  65. UserPromptMessage(content='continue')
  66. ])
  67. messages = [
  68. system_message,
  69. *historic_messages,
  70. *query_messages,
  71. *assistant_messages,
  72. UserPromptMessage(content='continue')
  73. ]
  74. else:
  75. # organize historic prompt messages
  76. historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages])
  77. messages = [system_message, *historic_messages, *query_messages]
  78. # join all messages
  79. return messages