Browse Source

fix: resolve issue with cot_agent_runner not analyzing user-uploaded images correctly (#5360)

Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Xiao Ley 9 months ago
parent
commit
369a395ee9
2 changed files with 23 additions and 9 deletions
  1. 0 2
      api/core/agent/cot_agent_runner.py
  2. 23 7
      api/core/agent/cot_chat_agent_runner.py

+ 0 - 2
api/core/agent/cot_agent_runner.py

@@ -61,8 +61,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         # convert tools into ModelRuntime Tool format
         tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
 
-        prompt_messages = self._organize_prompt_messages()
-
         function_call_state = True
         llm_usage = {
             'usage': None

+ 23 - 7
api/core/agent/cot_chat_agent_runner.py

@@ -5,6 +5,7 @@ from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
     SystemPromptMessage,
+    TextPromptMessageContent,
     UserPromptMessage,
 )
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -25,6 +26,21 @@ class CotChatAgentRunner(CotAgentRunner):
 
         return SystemPromptMessage(content=system_prompt)
 
+    def _organize_user_query(self, query,  prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+        """
+        Organize user query
+        """
+        if self.files:
+            prompt_message_contents = [TextPromptMessageContent(data=query)]
+            for file_obj in self.files:
+                prompt_message_contents.append(file_obj.prompt_message_content)
+
+            prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
+        else:
+            prompt_messages.append(UserPromptMessage(content=query))
+
+        return prompt_messages
+
     def _organize_prompt_messages(self) -> list[PromptMessage]:
         """
         Organize 
@@ -51,27 +67,27 @@ class CotChatAgentRunner(CotAgentRunner):
             assistant_messages = [assistant_message]
 
         # query messages
-        query_messages = UserPromptMessage(content=self._query)
+        query_messages = self._organize_user_query(self._query, [])
 
         if assistant_messages:
             # organize historic prompt messages
             historic_messages = self._organize_historic_prompt_messages([
                 system_message,
-                query_messages,
+                *query_messages,
                 *assistant_messages,
                 UserPromptMessage(content='continue')
-            ])            
+            ])
             messages = [
                 system_message,
                 *historic_messages,
-                query_messages,
+                *query_messages,
                 *assistant_messages,
                 UserPromptMessage(content='continue')
             ]
         else:
             # organize historic prompt messages
-            historic_messages = self._organize_historic_prompt_messages([system_message, query_messages])
-            messages = [system_message, *historic_messages, query_messages]
+            historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages])
+            messages = [system_message, *historic_messages, *query_messages]
 
         # join all messages
-        return messages
+        return messages