Browse Source

refactor advanced prompt core. (#1350)

Garfield Dai 1 year ago
parent
commit
fe14130b3c

+ 32 - 12
api/core/completion.py

@@ -16,6 +16,7 @@ from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.llm.base import BaseLLM
 from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_template import PromptTemplateParser
+from core.prompt.prompt_transform import PromptTransform
 from models.model import App, AppModelConfig, Account, Conversation, EndUser
 
 
@@ -156,24 +157,28 @@ class Completion:
                       conversation_message_task: ConversationMessageTask,
                       memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
                       fake_response: Optional[str]):
+        prompt_transform = PromptTransform()
+
         # get llm prompt
         if app_model_config.prompt_type == 'simple':
-            prompt_messages, stop_words = model_instance.get_prompt(
+            prompt_messages, stop_words = prompt_transform.get_prompt(
                 mode=mode,
                 pre_prompt=app_model_config.pre_prompt,
                 inputs=inputs,
                 query=query,
                 context=agent_execute_result.output if agent_execute_result else None,
-                memory=memory
+                memory=memory,
+                model_instance=model_instance
             )
         else:
-            prompt_messages = model_instance.get_advanced_prompt(
+            prompt_messages = prompt_transform.get_advanced_prompt(
                 app_mode=mode,
                 app_model_config=app_model_config,
                 inputs=inputs,
                 query=query,
                 context=agent_execute_result.output if agent_execute_result else None,
-                memory=memory
+                memory=memory,
+                model_instance=model_instance
             )
 
             model_config = app_model_config.model_dict
@@ -238,15 +243,30 @@ class Completion:
         if max_tokens is None:
             max_tokens = 0
 
+        prompt_transform = PromptTransform()
+        prompt_messages = []
+
         # get prompt without memory and context
-        prompt_messages, _ = model_instance.get_prompt(
-            mode=mode,
-            pre_prompt=app_model_config.pre_prompt,
-            inputs=inputs,
-            query=query,
-            context=None,
-            memory=None
-        )
+        if app_model_config.prompt_type == 'simple':
+            prompt_messages, _ = prompt_transform.get_prompt(
+                mode=mode,
+                pre_prompt=app_model_config.pre_prompt,
+                inputs=inputs,
+                query=query,
+                context=None,
+                memory=None,
+                model_instance=model_instance
+            )
+        else:
+            prompt_messages = prompt_transform.get_advanced_prompt(
+                app_mode=mode,
+                app_model_config=app_model_config,
+                inputs=inputs,
+                query=query,
+                context=None,
+                memory=None,
+                model_instance=model_instance
+            )
 
         prompt_tokens = model_instance.get_num_tokens(prompt_messages)
         rest_tokens = model_limited_tokens - max_tokens - prompt_tokens

+ 0 - 6
api/core/model_providers/models/llm/baichuan_model.py

@@ -37,12 +37,6 @@ class BaichuanModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return self._client.generate([prompts], stop, callbacks)
 
-    def prompt_file_name(self, mode: str) -> str:
-        if mode == 'completion':
-            return 'baichuan_completion'
-        else:
-            return 'baichuan_chat'
-
     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
         """
         get num tokens of prompt messages.

+ 5 - 213
api/core/model_providers/models/llm/base.py

@@ -1,28 +1,18 @@
-import json
-import os
-import re
-import time
 from abc import abstractmethod
-from typing import List, Optional, Any, Union, Tuple
+from typing import List, Optional, Any, Union
 import decimal
+import logging
 
 from langchain.callbacks.manager import Callbacks
-from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
+from langchain.schema import LLMResult, BaseMessage, ChatGeneration
 
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
 from core.helper import moderation
 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \
-    to_lc_messages
+from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_lc_messages
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
 from core.model_providers.providers.base import BaseModelProvider
-from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import PromptTemplateParser
 from core.third_party.langchain.llms.fake import FakeLLM
-import logging
-
-from extensions.ext_database import db
 
 logger = logging.getLogger(__name__)
 
@@ -320,206 +310,8 @@ class BaseLLM(BaseProviderModel):
     def support_streaming(self):
         return False
 
-    def get_prompt(self, mode: str,
-                   pre_prompt: str, inputs: dict,
-                   query: str,
-                   context: Optional[str],
-                   memory: Optional[BaseChatMemory]) -> \
-            Tuple[List[PromptMessage], Optional[List[str]]]:
-        prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
-        prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
-        return [PromptMessage(content=prompt)], stops
-
-    def get_advanced_prompt(self, app_mode: str,
-                   app_model_config: str, inputs: dict,
-                   query: str,
-                   context: Optional[str],
-                   memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
-
-        model_mode = app_model_config.model_dict['mode']
-        conversation_histories_role = {}
-
-        raw_prompt_list = []
-        prompt_messages = []
-
-        if app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
-            prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
-            raw_prompt_list = [{
-                'role': MessageType.USER.value,
-                'text': prompt_text
-            }]
-            conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
-        elif app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
-            raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
-        elif app_mode == 'completion' and model_mode == ModelMode.CHAT.value:
-            raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
-        elif app_mode == 'completion' and model_mode == ModelMode.COMPLETION.value:
-            prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
-            raw_prompt_list = [{
-                'role': MessageType.USER.value,
-                'text': prompt_text
-            }]
-        else:
-            raise Exception("app_mode or model_mode not support")
-
-        for prompt_item in raw_prompt_list:
-            prompt = prompt_item['text']
-
-            # set prompt template variables
-            prompt_template = PromptTemplateParser(template=prompt)
-            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
-
-            if '#context#' in prompt:
-                if context:
-                    prompt_inputs['#context#'] = context
-                else:
-                    prompt_inputs['#context#'] = ''
-
-            if '#query#' in prompt:
-                if query:
-                    prompt_inputs['#query#'] = query
-                else:
-                    prompt_inputs['#query#'] = ''
-
-            if '#histories#' in prompt:
-                if memory and app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
-                    memory.human_prefix = conversation_histories_role['user_prefix']
-                    memory.ai_prefix = conversation_histories_role['assistant_prefix']
-                    histories = self._get_history_messages_from_memory(memory, 2000)
-                    prompt_inputs['#histories#'] = histories
-                else:
-                    prompt_inputs['#histories#'] = ''
-
-            prompt = prompt_template.format(
-                prompt_inputs
-            )
-
-            prompt = re.sub(r'<\|.*?\|>', '', prompt)
-
-            prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
-
-        if memory and app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
-            memory.human_prefix = MessageType.USER.value
-            memory.ai_prefix = MessageType.ASSISTANT.value
-            histories = self._get_history_messages_list_from_memory(memory, 2000)
-            prompt_messages.extend(histories)
-
-        if app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
-            prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))
-
-        return prompt_messages
-
-    def prompt_file_name(self, mode: str) -> str:
-        if mode == 'completion':
-            return 'common_completion'
-        else:
-            return 'common_chat'
-
-    def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
-                             query: str,
-                             context: Optional[str],
-                             memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
-        context_prompt_content = ''
-        if context and 'context_prompt' in prompt_rules:
-            prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
-            context_prompt_content = prompt_template.format(
-                {'context': context}
-            )
-
-        pre_prompt_content = ''
-        if pre_prompt:
-            prompt_template = PromptTemplateParser(template=pre_prompt)
-            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
-            pre_prompt_content = prompt_template.format(
-                prompt_inputs
-            )
-
-        prompt = ''
-        for order in prompt_rules['system_prompt_orders']:
-            if order == 'context_prompt':
-                prompt += context_prompt_content
-            elif order == 'pre_prompt':
-                prompt += pre_prompt_content
-
-        query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'
-
-        if memory and 'histories_prompt' in prompt_rules:
-            # append chat histories
-            tmp_human_message = PromptBuilder.to_human_message(
-                prompt_content=prompt + query_prompt,
-                inputs={
-                    'query': query
-                }
-            )
-
-            if self.model_rules.max_tokens.max:
-                curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
-                max_tokens = self.model_kwargs.max_tokens
-                rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
-                rest_tokens = max(rest_tokens, 0)
-            else:
-                rest_tokens = 2000
-
-            memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
-            memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
-
-            histories = self._get_history_messages_from_memory(memory, rest_tokens)
-            prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
-            histories_prompt_content = prompt_template.format({'histories': histories})
-
-            prompt = ''
-            for order in prompt_rules['system_prompt_orders']:
-                if order == 'context_prompt':
-                    prompt += context_prompt_content
-                elif order == 'pre_prompt':
-                    prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
-                elif order == 'histories_prompt':
-                    prompt += histories_prompt_content
-
-        prompt_template = PromptTemplateParser(template=query_prompt)
-        query_prompt_content = prompt_template.format({'query': query})
-
-        prompt += query_prompt_content
-
-        prompt = re.sub(r'<\|.*?\|>', '', prompt)
-
-        stops = prompt_rules.get('stops')
-        if stops is not None and len(stops) == 0:
-            stops = None
-
-        return prompt, stops
-
-    def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
-        # Get the absolute path of the subdirectory
-        prompt_path = os.path.join(
-            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
-            'prompt/generate_prompts')
-
-        json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
-        # Open the JSON file and read its content
-        with open(json_file_path, 'r') as json_file:
-            return json.load(json_file)
-
-    def _get_history_messages_from_memory(self, memory: BaseChatMemory,
-                                          max_token_limit: int) -> str:
-        """Get memory messages."""
-        memory.max_token_limit = max_token_limit
-        memory_key = memory.memory_variables[0]
-        external_context = memory.load_memory_variables({})
-        return external_context[memory_key]
-
-    def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
-                                          max_token_limit: int) -> List[PromptMessage]:
-        """Get memory messages."""
-        memory.max_token_limit = max_token_limit
-        memory.return_messages = True
-        memory_key = memory.memory_variables[0]
-        external_context = memory.load_memory_variables({})
-        memory.return_messages = False
-        return to_prompt_messages(external_context[memory_key])
-
     def _get_prompt_from_messages(self, messages: List[PromptMessage],
-                                  model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
+                                  model_mode: Optional[ModelMode] = None) -> Union[str , List[BaseMessage]]:
         if not model_mode:
             model_mode = self.model_mode
 

+ 0 - 9
api/core/model_providers/models/llm/huggingface_hub_model.py

@@ -66,15 +66,6 @@ class HuggingfaceHubModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return self._client.get_num_tokens(prompts)
 
-    def prompt_file_name(self, mode: str) -> str:
-        if 'baichuan' in self.name.lower():
-            if mode == 'completion':
-                return 'baichuan_completion'
-            else:
-                return 'baichuan_chat'
-        else:
-            return super().prompt_file_name(mode)
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         self.client.model_kwargs = provider_model_kwargs

+ 0 - 9
api/core/model_providers/models/llm/openllm_model.py

@@ -49,15 +49,6 @@ class OpenLLMModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)
 
-    def prompt_file_name(self, mode: str) -> str:
-        if 'baichuan' in self.name.lower():
-            if mode == 'completion':
-                return 'baichuan_completion'
-            else:
-                return 'baichuan_chat'
-        else:
-            return super().prompt_file_name(mode)
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         pass
 

+ 0 - 9
api/core/model_providers/models/llm/xinference_model.py

@@ -59,15 +59,6 @@ class XinferenceModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)
 
-    def prompt_file_name(self, mode: str) -> str:
-        if 'baichuan' in self.name.lower():
-            if mode == 'completion':
-                return 'baichuan_completion'
-            else:
-                return 'baichuan_chat'
-        else:
-            return super().prompt_file_name(mode)
-
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         pass
 

+ 344 - 0
api/core/prompt/prompt_transform.py

@@ -0,0 +1,344 @@
+import json
+import os
+import re
+import enum
+from typing import List, Optional, Tuple
+
+from langchain.memory.chat_memory import BaseChatMemory
+from langchain.schema import BaseMessage
+
+from core.model_providers.models.entity.model_params import ModelMode
+from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.llm.baichuan_model import BaichuanModel
+from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
+from core.model_providers.models.llm.openllm_model import OpenLLMModel
+from core.model_providers.models.llm.xinference_model import XinferenceModel
+from core.prompt.prompt_builder import PromptBuilder
+from core.prompt.prompt_template import PromptTemplateParser
+
+class AppMode(enum.Enum):
+    COMPLETION = 'completion'
+    CHAT = 'chat'
+
+class PromptTransform:
+    def get_prompt(self, mode: str,
+                   pre_prompt: str, inputs: dict,
+                   query: str,
+                   context: Optional[str],
+                   memory: Optional[BaseChatMemory],
+                   model_instance: BaseLLM) -> \
+            Tuple[List[PromptMessage], Optional[List[str]]]:
+        prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(mode, model_instance))
+        prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory, model_instance)
+        return [PromptMessage(content=prompt)], stops
+
+    def get_advanced_prompt(self, 
+            app_mode: str,
+            app_model_config: str, 
+            inputs: dict,
+            query: str,
+            context: Optional[str],
+            memory: Optional[BaseChatMemory],
+            model_instance: BaseLLM) -> List[PromptMessage]:
+        
+        model_mode = app_model_config.model_dict['mode']
+
+        app_mode_enum = AppMode(app_mode)
+        model_mode_enum = ModelMode(model_mode)
+
+        prompt_messages = []
+
+        if app_mode_enum == AppMode.CHAT:
+            if model_mode_enum == ModelMode.COMPLETION:
+                prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
+            elif model_mode_enum == ModelMode.CHAT:
+                prompt_messages =  self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
+        elif app_mode_enum == AppMode.COMPLETION:
+            if model_mode_enum == ModelMode.CHAT:
+                prompt_messages =  self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs, context)
+            elif model_mode_enum == ModelMode.COMPLETION:
+                prompt_messages =  self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs, context)
+            
+        return prompt_messages
+
+    def _get_history_messages_from_memory(self, memory: BaseChatMemory,
+                                          max_token_limit: int) -> str:
+        """Get memory messages."""
+        memory.max_token_limit = max_token_limit
+        memory_key = memory.memory_variables[0]
+        external_context = memory.load_memory_variables({})
+        return external_context[memory_key]
+
+    def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
+                                          max_token_limit: int) -> List[PromptMessage]:
+        """Get memory messages."""
+        memory.max_token_limit = max_token_limit
+        memory.return_messages = True
+        memory_key = memory.memory_variables[0]
+        external_context = memory.load_memory_variables({})
+        memory.return_messages = False
+        return to_prompt_messages(external_context[memory_key])
+    
+    def _prompt_file_name(self, mode: str, model_instance: BaseLLM) -> str:
+        # baichuan
+        if isinstance(model_instance, BaichuanModel):
+            return self._prompt_file_name_for_baichuan(mode)
+
+        baichuan_model_hosted_platforms = (HuggingfaceHubModel, OpenLLMModel, XinferenceModel)
+        if isinstance(model_instance, baichuan_model_hosted_platforms) and 'baichuan' in model_instance.name.lower():
+            return self._prompt_file_name_for_baichuan(mode)
+
+        # common
+        if mode == 'completion':
+            return 'common_completion'
+        else:
+            return 'common_chat'
+        
+    def _prompt_file_name_for_baichuan(self, mode: str) -> str:
+        if mode == 'completion':
+            return 'baichuan_completion'
+        else:
+            return 'baichuan_chat'
+    
+    def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
+        # Get the absolute path of the subdirectory
+        prompt_path = os.path.join(
+            os.path.dirname(os.path.realpath(__file__)),
+            'generate_prompts')
+
+        json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
+        # Open the JSON file and read its content
+        with open(json_file_path, 'r') as json_file:
+            return json.load(json_file)
+        
+    def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
+                             query: str,
+                             context: Optional[str],
+                             memory: Optional[BaseChatMemory],
+                             model_instance: BaseLLM) -> Tuple[str, Optional[list]]:
+        context_prompt_content = ''
+        if context and 'context_prompt' in prompt_rules:
+            prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
+            context_prompt_content = prompt_template.format(
+                {'context': context}
+            )
+
+        pre_prompt_content = ''
+        if pre_prompt:
+            prompt_template = PromptTemplateParser(template=pre_prompt)
+            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+            pre_prompt_content = prompt_template.format(
+                prompt_inputs
+            )
+
+        prompt = ''
+        for order in prompt_rules['system_prompt_orders']:
+            if order == 'context_prompt':
+                prompt += context_prompt_content
+            elif order == 'pre_prompt':
+                prompt += pre_prompt_content
+
+        query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'
+
+        if memory and 'histories_prompt' in prompt_rules:
+            # append chat histories
+            tmp_human_message = PromptBuilder.to_human_message(
+                prompt_content=prompt + query_prompt,
+                inputs={
+                    'query': query
+                }
+            )
+
+            rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)
+
+            memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
+            memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+
+            histories = self._get_history_messages_from_memory(memory, rest_tokens)
+            prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
+            histories_prompt_content = prompt_template.format({'histories': histories})
+
+            prompt = ''
+            for order in prompt_rules['system_prompt_orders']:
+                if order == 'context_prompt':
+                    prompt += context_prompt_content
+                elif order == 'pre_prompt':
+                    prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
+                elif order == 'histories_prompt':
+                    prompt += histories_prompt_content
+
+        prompt_template = PromptTemplateParser(template=query_prompt)
+        query_prompt_content = prompt_template.format({'query': query})
+
+        prompt += query_prompt_content
+
+        prompt = re.sub(r'<\|.*?\|>', '', prompt)
+
+        stops = prompt_rules.get('stops')
+        if stops is not None and len(stops) == 0:
+            stops = None
+
+        return prompt, stops
+    
+    def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
+        if '#context#' in prompt_template.variable_keys:
+            if context:
+                prompt_inputs['#context#'] = context    
+            else:
+                prompt_inputs['#context#'] = ''
+
+    def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
+        if '#query#' in prompt_template.variable_keys:
+            if query:
+                prompt_inputs['#query#'] = query
+            else:
+                prompt_inputs['#query#'] = ''
+
+    def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict, 
+                                prompt_template: PromptTemplateParser, prompt_inputs: dict, model_instance: BaseLLM) -> None:
+        if '#histories#' in prompt_template.variable_keys:
+            if memory:
+                tmp_human_message = PromptBuilder.to_human_message(
+                    prompt_content=raw_prompt,
+                    inputs={ '#histories#': '', **prompt_inputs }
+                )
+
+                rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)
+                
+                memory.human_prefix = conversation_histories_role['user_prefix']
+                memory.ai_prefix = conversation_histories_role['assistant_prefix']
+                histories = self._get_history_messages_from_memory(memory, rest_tokens)
+                prompt_inputs['#histories#'] = histories
+            else:
+                prompt_inputs['#histories#'] = ''
+
+    def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: list[PromptMessage], model_instance: BaseLLM) -> None:
+        if memory:
+            rest_tokens = self._calculate_rest_token(prompt_messages, model_instance)
+
+            memory.human_prefix = MessageType.USER.value
+            memory.ai_prefix = MessageType.ASSISTANT.value
+            histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
+            prompt_messages.extend(histories)
+
+    def _calculate_rest_token(self, prompt_messages: BaseMessage, model_instance: BaseLLM) -> int:
+        rest_tokens = 2000
+
+        if model_instance.model_rules.max_tokens.max:
+            curr_message_tokens = model_instance.get_num_tokens(to_prompt_messages(prompt_messages))
+            max_tokens = model_instance.model_kwargs.max_tokens
+            rest_tokens = model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
+            rest_tokens = max(rest_tokens, 0)
+
+        return rest_tokens
+
+    def _format_prompt(self, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> str:
+        prompt = prompt_template.format(
+            prompt_inputs
+        )
+
+        prompt = re.sub(r'<\|.*?\|>', '', prompt)
+        return prompt
+
+    def _get_chat_app_completion_model_prompt_messages(self,
+            app_model_config: str,
+            inputs: dict,
+            query: str,
+            context: Optional[str],
+            memory: Optional[BaseChatMemory],
+            model_instance: BaseLLM) -> List[PromptMessage]:
+        
+        raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
+        conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
+
+        prompt_messages = []
+        prompt = ''
+        
+        prompt_template = PromptTemplateParser(template=raw_prompt)
+        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+
+        self._set_context_variable(context, prompt_template, prompt_inputs)
+
+        self._set_query_variable(query, prompt_template, prompt_inputs)
+
+        self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs, model_instance)
+
+        prompt = self._format_prompt(prompt_template, prompt_inputs)
+
+        prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))
+
+        return prompt_messages
+
+    def _get_chat_app_chat_model_prompt_messages(self,
+            app_model_config: str,
+            inputs: dict,
+            query: str,
+            context: Optional[str],
+            memory: Optional[BaseChatMemory],
+            model_instance: BaseLLM) -> List[PromptMessage]:
+        raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
+
+        prompt_messages = []
+
+        for prompt_item in raw_prompt_list:
+            raw_prompt = prompt_item['text']
+            prompt = ''
+
+            prompt_template = PromptTemplateParser(template=raw_prompt)
+            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+
+            self._set_context_variable(context, prompt_template, prompt_inputs)
+
+            prompt = self._format_prompt(prompt_template, prompt_inputs)
+
+            prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
+        
+        self._append_chat_histories(memory, prompt_messages, model_instance)
+
+        prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))
+
+        return prompt_messages
+
+    def _get_completion_app_completion_model_prompt_messages(self,
+                   app_model_config: str,
+                   inputs: dict,
+                   context: Optional[str]) -> List[PromptMessage]:
+        raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
+
+        prompt_messages = []
+        prompt = ''
+        
+        prompt_template = PromptTemplateParser(template=raw_prompt)
+        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+
+        self._set_context_variable(context, prompt_template, prompt_inputs)
+
+        prompt = self._format_prompt(prompt_template, prompt_inputs)
+
+        prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))
+
+        return prompt_messages
+
+    def _get_completion_app_chat_model_prompt_messages(self,
+                   app_model_config: str,
+                   inputs: dict,
+                   context: Optional[str]) -> List[PromptMessage]:
+        raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
+
+        prompt_messages = []
+
+        for prompt_item in raw_prompt_list:
+            raw_prompt = prompt_item['text']
+            prompt = ''
+
+            prompt_template = PromptTemplateParser(template=raw_prompt)
+            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+
+            self._set_context_variable(context, prompt_template, prompt_inputs)
+
+            prompt = self._format_prompt(prompt_template, prompt_inputs)
+
+            prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
+        
+        return prompt_messages

+ 15 - 13
api/services/advanced_prompt_template_service.py

@@ -1,6 +1,8 @@
 
 import copy
 
+from core.model_providers.models.entity.model_params import ModelMode
+from core.prompt.prompt_transform import AppMode
 from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \
     BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT
 
@@ -13,7 +15,7 @@ class AdvancedPromptTemplateService:
         model_name = args['model_name']
         has_context = args['has_context']
 
-        if 'baichuan' in model_name:
+        if 'baichuan' in model_name.lower():
             return cls.get_baichuan_prompt(app_mode, model_mode, has_context)
         else:
             return cls.get_common_prompt(app_mode, model_mode, has_context)
@@ -22,15 +24,15 @@ class AdvancedPromptTemplateService:
     def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
         context_prompt = copy.deepcopy(CONTEXT)
 
-        if app_mode == 'chat':
-            if model_mode == 'completion':
+        if app_mode == AppMode.CHAT.value:
+            if model_mode == ModelMode.COMPLETION.value:
                 return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
-            elif model_mode == 'chat':
+            elif model_mode == ModelMode.CHAT.value:
                 return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
-        elif app_mode == 'completion':
-            if model_mode == 'completion':
+        elif app_mode == AppMode.COMPLETION.value:
+            if model_mode == ModelMode.COMPLETION.value:
                 return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
-            elif model_mode == 'chat':
+            elif model_mode == ModelMode.CHAT.value:
                 return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
             
     @classmethod
@@ -51,13 +53,13 @@ class AdvancedPromptTemplateService:
     def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
         baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
 
-        if app_mode == 'chat':
-            if model_mode == 'completion':
+        if app_mode == AppMode.CHAT.value:
+            if model_mode == ModelMode.COMPLETION.value:
                 return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
-            elif model_mode == 'chat':
+            elif model_mode == ModelMode.CHAT.value:
                 return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
-        elif app_mode == 'completion':
-            if model_mode == 'completion':
+        elif app_mode == AppMode.COMPLETION.value:
+            if model_mode == ModelMode.COMPLETION.value:
                 return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
-            elif model_mode == 'chat':
+            elif model_mode == ModelMode.CHAT.value:
                 return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)

+ 9 - 1
api/services/app_model_config_service.py

@@ -1,6 +1,7 @@
 import re
 import uuid
 
+from core.prompt.prompt_transform import AppMode
 from core.agent.agent_executor import PlanningStrategy
 from core.model_providers.model_provider_factory import ModelProviderFactory
 from core.model_providers.models.entity.model_params import ModelType, ModelMode
@@ -418,7 +419,7 @@ class AppModelConfigService:
             if config['model']["mode"] not in ['chat', 'completion']:
                 raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
             
-            if app_mode == 'chat' and config['model']["mode"] == ModelMode.COMPLETION.value:
+            if app_mode == AppMode.CHAT.value and config['model']["mode"] == ModelMode.COMPLETION.value:
                 user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
                 assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
 
@@ -427,3 +428,10 @@ class AppModelConfigService:
 
                 if not assistant_prefix:
                     config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
+
+
+            if config['model']["mode"] == ModelMode.CHAT.value:
+                prompt_list = config['chat_prompt_config']['prompt']
+
+                if len(prompt_list) > 10:
+                    raise ValueError("prompt messages must be less than 10")