|
@@ -5,6 +5,7 @@ from core.model_runtime.entities.message_entities import (
|
|
|
AssistantPromptMessage,
|
|
|
PromptMessage,
|
|
|
SystemPromptMessage,
|
|
|
+ TextPromptMessageContent,
|
|
|
UserPromptMessage,
|
|
|
)
|
|
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
|
@@ -25,6 +26,21 @@ class CotChatAgentRunner(CotAgentRunner):
|
|
|
|
|
|
return SystemPromptMessage(content=system_prompt)
|
|
|
|
|
|
+ def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
|
|
|
+ """
|
|
|
+ Organize user query
|
|
|
+ """
|
|
|
+ if self.files:
|
|
|
+ prompt_message_contents = [TextPromptMessageContent(data=query)]
|
|
|
+ for file_obj in self.files:
|
|
|
+ prompt_message_contents.append(file_obj.prompt_message_content)
|
|
|
+
|
|
|
+ prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
|
|
+ else:
|
|
|
+ prompt_messages.append(UserPromptMessage(content=query))
|
|
|
+
|
|
|
+ return prompt_messages
|
|
|
+
|
|
|
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
|
|
"""
|
|
|
Organize
|
|
@@ -51,27 +67,27 @@ class CotChatAgentRunner(CotAgentRunner):
|
|
|
assistant_messages = [assistant_message]
|
|
|
|
|
|
# query messages
|
|
|
- query_messages = UserPromptMessage(content=self._query)
|
|
|
+ query_messages = self._organize_user_query(self._query, [])
|
|
|
|
|
|
if assistant_messages:
|
|
|
# organize historic prompt messages
|
|
|
historic_messages = self._organize_historic_prompt_messages([
|
|
|
system_message,
|
|
|
- query_messages,
|
|
|
+ *query_messages,
|
|
|
*assistant_messages,
|
|
|
UserPromptMessage(content='continue')
|
|
|
- ])
|
|
|
+ ])
|
|
|
messages = [
|
|
|
system_message,
|
|
|
*historic_messages,
|
|
|
- query_messages,
|
|
|
+ *query_messages,
|
|
|
*assistant_messages,
|
|
|
UserPromptMessage(content='continue')
|
|
|
]
|
|
|
else:
|
|
|
# organize historic prompt messages
|
|
|
- historic_messages = self._organize_historic_prompt_messages([system_message, query_messages])
|
|
|
- messages = [system_message, *historic_messages, query_messages]
|
|
|
+ historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages])
|
|
|
+ messages = [system_message, *historic_messages, *query_messages]
|
|
|
|
|
|
# join all messages
|
|
|
- return messages
|
|
|
+ return messages
|