
Feat/optimize chat prompt (#158)

John Wang 2 years ago
parent
commit
90150a6ca9
1 changed file with 38 additions and 33 deletions

+ 38 - 33
api/core/completion.py

@@ -39,7 +39,8 @@ class Completion:
             memory = cls.get_memory_from_conversation(
                 tenant_id=app.tenant_id,
                 app_model_config=app_model_config,
-                conversation=conversation
+                conversation=conversation,
+                return_messages=False
             )

             inputs = conversation.inputs
@@ -119,7 +120,8 @@ class Completion:
         return response

     @classmethod
-    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict, chain_output: Optional[str],
+    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
+                            chain_output: Optional[str],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Union[str | List[BaseMessage]]:
         pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
@@ -161,11 +163,19 @@ And answer according to the language of the user's question.
                 "query": query
                 "query": query
             }
             }
 
 
-            human_message_prompt = "{query}"
+            human_message_prompt = ""
+
+            if pre_prompt:
+                pre_prompt_inputs = {k: inputs[k] for k in
+                                     OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
+                                     if k in inputs}
+
+                if pre_prompt_inputs:
+                    human_inputs.update(pre_prompt_inputs)

             if chain_output:
                 human_inputs['context'] = chain_output
-                human_message_instruction = """Use the following CONTEXT as your learned knowledge.
+                human_message_prompt += """Use the following CONTEXT as your learned knowledge.
 [CONTEXT]
 {context}
 [END CONTEXT]
@@ -176,39 +186,33 @@ When answer to user:
 Avoid mentioning that you obtained the information from the context.
 And answer according to the language of the user's question.
 """
-                if pre_prompt:
-                    extra_inputs = {k: inputs[k] for k in
-                                    OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                                    if k in inputs}
-                    if extra_inputs:
-                        human_inputs.update(extra_inputs)
-                    human_message_instruction += pre_prompt + "\n"
-
-                human_message_prompt = human_message_instruction + "Q:{query}\nA:"
-            else:
-                if pre_prompt:
-                    extra_inputs = {k: inputs[k] for k in
-                                    OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                                    if k in inputs}
-                    if extra_inputs:
-                        human_inputs.update(extra_inputs)
-                    human_message_prompt = pre_prompt + "\n" + human_message_prompt
-            # construct main prompt
-            human_message = PromptBuilder.to_human_message(
-                prompt_content=human_message_prompt,
-                inputs=human_inputs
-            )
+            if pre_prompt:
+                human_message_prompt += pre_prompt
+
+            query_prompt = "\nHuman: {query}\nAI: "

             if memory:
                 # append chat histories
-                tmp_messages = messages.copy() + [human_message]
-                curr_message_tokens = memory.llm.get_messages_tokens(tmp_messages)
-                rest_tokens = llm_constant.max_context_token_length[
-                                  memory.llm.model_name] - memory.llm.max_tokens - curr_message_tokens
+                tmp_human_message = PromptBuilder.to_human_message(
+                    prompt_content=human_message_prompt + query_prompt,
+                    inputs=human_inputs
+                )
+
+                curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
+                rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
+                              - memory.llm.max_tokens - curr_message_tokens
                 rest_tokens = max(rest_tokens, 0)
                 history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
-                messages += history_messages
+                human_message_prompt += "\n\n" + history_messages
+
+            human_message_prompt += query_prompt
+
+            # construct main prompt
+            human_message = PromptBuilder.to_human_message(
+                prompt_content=human_message_prompt,
+                inputs=human_inputs
+            )

             messages.append(human_message)

@@ -216,7 +220,8 @@ And answer according to the language of the user's question.

     @classmethod
     def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
-                                 streaming: bool, conversation_message_task: ConversationMessageTask) -> CallbackManager:
+                                 streaming: bool,
+                                 conversation_message_task: ConversationMessageTask) -> CallbackManager:
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
         if streaming:
             callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
@@ -228,7 +233,7 @@ And answer according to the language of the user's question.
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                          max_token_limit: int) -> \
-            List[BaseMessage]:
+            str:
         """Get memory messages."""
         """Get memory messages."""
         memory.max_token_limit = max_token_limit
         memory.max_token_limit = max_token_limit
         memory_key = memory.memory_variables[0]
         memory_key = memory.memory_variables[0]
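
Taken together, the diff collapses the old two-branch template construction into a single human_message_prompt built up incrementally: the CONTEXT instruction block (when a chain output exists), then the app's pre_prompt, then the trimmed chat history (now returned as a plain string rather than a List[BaseMessage]), and finally a "\nHuman: {query}\nAI: " suffix. Below is a minimal sketch of that assembly order in plain Python; build_chat_prompt and its parameters are hypothetical stand-ins, only the ordering and the token-budget formula follow the diff.

from typing import Optional

# Abridged from the template in the diff; the full text also tells the model
# how to answer and to avoid mentioning that it used the context.
CONTEXT_INSTRUCTION = """Use the following CONTEXT as your learned knowledge.
[CONTEXT]
{context}
[END CONTEXT]
And answer according to the language of the user's question.
"""

QUERY_PROMPT = "\nHuman: {query}\nAI: "

def build_chat_prompt(pre_prompt: Optional[str],
                      chain_output: Optional[str],
                      history_text: Optional[str]) -> str:
    """Assemble the single human-message template used in chat mode."""
    prompt = ""
    if chain_output:
        prompt += CONTEXT_INSTRUCTION    # context instructions come first
    if pre_prompt:
        prompt += pre_prompt             # then the app's pre-prompt
    if history_text:
        prompt += "\n\n" + history_text  # then the trimmed chat history
    return prompt + QUERY_PROMPT         # finally the Human:/AI: query suffix

# Before the history is appended, the remaining context budget is computed as
#   rest_tokens = max_context_token_length[model] - llm.max_tokens - curr_message_tokens
# (clamped to >= 0), where curr_message_tokens counts the prompt built so far
# plus the query suffix; get_history_messages_from_memory then trims the
# history string to fit within that budget.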