@@ -1,17 +1,24 @@
+import json
+import os
+import re
 from abc import abstractmethod
-from typing import List, Optional, Any, Union
+from typing import List, Optional, Any, Union, Tuple
 import decimal
 
 from langchain.callbacks.manager import Callbacks
+from langchain.memory.chat_memory import BaseChatMemory
 from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
 
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
+from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
 from core.model_providers.providers.base import BaseModelProvider
+from core.prompt.prompt_builder import PromptBuilder
+from core.prompt.prompt_template import JinjaPromptTemplate
 from core.third_party.langchain.llms.fake import FakeLLM
 import logging
+
 logger = logging.getLogger(__name__)
 
 
@@ -76,13 +83,14 @@ class BaseLLM(BaseProviderModel):
     def price_config(self) -> dict:
         def get_or_default():
             default_price_config = {
-                    'prompt': decimal.Decimal('0'),
-                    'completion': decimal.Decimal('0'),
-                    'unit': decimal.Decimal('0'),
-                    'currency': 'USD'
-                }
+                'prompt': decimal.Decimal('0'),
+                'completion': decimal.Decimal('0'),
+                'unit': decimal.Decimal('0'),
+                'currency': 'USD'
+            }
             rules = self.model_provider.get_rules()
-            price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
+            price_config = rules['price_config'][
+                self.base_model_name] if 'price_config' in rules else default_price_config
             price_config = {
                 'prompt': decimal.Decimal(price_config['prompt']),
                 'completion': decimal.Decimal(price_config['completion']),
@@ -90,7 +98,7 @@ class BaseLLM(BaseProviderModel):
                 'currency': price_config['currency']
             }
             return price_config
-        
+
         self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
 
         logger.debug(f"model: {self.name} price_config: {self._price_config}")
@@ -158,7 +166,8 @@ class BaseLLM(BaseProviderModel):
             total_tokens = result.llm_output['token_usage']['total_tokens']
         else:
             prompt_tokens = self.get_num_tokens(messages)
-            completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
+            completion_tokens = self.get_num_tokens(
+                [PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
             total_tokens = prompt_tokens + completion_tokens
 
         self.model_provider.update_last_used()
@@ -293,6 +302,119 @@ class BaseLLM(BaseProviderModel):
     def support_streaming(cls):
         return False
 
+    def get_prompt(self, mode: str,
+                   pre_prompt: str, inputs: dict,
+                   query: str,
+                   context: Optional[str],
+                   memory: Optional[BaseChatMemory]) -> \
+            Tuple[List[PromptMessage], Optional[List[str]]]:
+        prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
+        prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
+        return [PromptMessage(content=prompt)], stops
+
+    def prompt_file_name(self, mode: str) -> str:
+        if mode == 'completion':
+            return 'common_completion'
+        else:
+            return 'common_chat'
+
+    def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
+                             query: str,
+                             context: Optional[str],
+                             memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
+        context_prompt_content = ''
+        if context and 'context_prompt' in prompt_rules:
+            prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
+            context_prompt_content = prompt_template.format(
+                context=context
+            )
+
+        pre_prompt_content = ''
+        if pre_prompt:
+            prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
+            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
+            pre_prompt_content = prompt_template.format(
+                **prompt_inputs
+            )
+
+        prompt = ''
+        for order in prompt_rules['system_prompt_orders']:
+            if order == 'context_prompt':
+                prompt += context_prompt_content
+            elif order == 'pre_prompt':
+                prompt += (pre_prompt_content + '\n\n') if pre_prompt_content else ''
+
+        query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'
+
+        if memory and 'histories_prompt' in prompt_rules:
+            # append chat histories
+            tmp_human_message = PromptBuilder.to_human_message(
+                prompt_content=prompt + query_prompt,
+                inputs={
+                    'query': query
+                }
+            )
+
+            if self.model_rules.max_tokens.max:
+                curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
+                max_tokens = self.model_kwargs.max_tokens
+                rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
+                rest_tokens = max(rest_tokens, 0)
+            else:
+                rest_tokens = 2000
+
+            memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
+            memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+
+            histories = self._get_history_messages_from_memory(memory, rest_tokens)
+            prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
+            histories_prompt_content = prompt_template.format(
+                histories=histories
+            )
+
+            prompt = ''
+            for order in prompt_rules['system_prompt_orders']:
+                if order == 'context_prompt':
+                    prompt += context_prompt_content
+                elif order == 'pre_prompt':
+                    prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
+                elif order == 'histories_prompt':
+                    prompt += histories_prompt_content
+
+        prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
+        query_prompt_content = prompt_template.format(
+            query=query
+        )
+
+        prompt += query_prompt_content
+
+        prompt = re.sub(r'<\|.*?\|>', '', prompt)
+
+        stops = prompt_rules.get('stops')
+        if stops is not None and len(stops) == 0:
+            stops = None
+
+        return prompt, stops
+
+    def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
+        # Get the absolute path of the subdirectory
+        prompt_path = os.path.join(
+            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
+            'prompt/generate_prompts')
+
+        json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
+        # Open the JSON file and read its content
+        with open(json_file_path, 'r') as json_file:
+            return json.load(json_file)
+
+    def _get_history_messages_from_memory(self, memory: BaseChatMemory,
+                                          max_token_limit: int) -> str:
+        """Get memory messages."""
+        memory.max_token_limit = max_token_limit
+        memory_key = memory.memory_variables[0]
+        external_context = memory.load_memory_variables({})
+        return external_context[memory_key]
+
     def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                   model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
         if not model_mode:
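
Note: the diff does not include the rule files that _read_prompt_rules_from_file loads from prompt/generate_prompts (common_completion.json / common_chat.json), but _get_prompt_and_stop implies their shape. A minimal sketch of the expected structure, shown as a Python dict mirroring the JSON; only the keys the code reads are listed, and the template strings are illustrative placeholders, not the actual Dify templates:

    # Hypothetical sketch of a prompt rule file such as common_chat.json.
    EXAMPLE_PROMPT_RULES = {
        "human_prefix": "Human",              # optional; code defaults to 'Human'
        "assistant_prefix": "Assistant",      # optional; code defaults to 'Assistant'
        "context_prompt": "Use the following context:\n{{context}}\n",  # Jinja-style, gets `context`
        "histories_prompt": "{{histories}}\n",                          # used only when memory is passed
        "query_prompt": "\nHuman: {{query}}\nAssistant: ",              # optional; falls back to '{{query}}'
        "system_prompt_orders": [             # concatenation order of the system portion
            "context_prompt",
            "pre_prompt",
            "histories_prompt"
        ],
        "stops": ["\nHuman:"]                 # an empty list is normalized to None
    }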
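As a rough caller-side sketch of the new get_prompt entry point (the surrounding wiring is assumed, not part of this diff; model_instance is a hypothetical BaseLLM subclass instance and conversation_memory a hypothetical BaseChatMemory):

    # Hypothetical usage; names outside the method signature are illustrative.
    prompt_messages, stops = model_instance.get_prompt(
        mode='chat',                                            # anything other than 'completion' uses common_chat
        pre_prompt='You are a helpful assistant for {{product}}.',
        inputs={'product': 'Dify'},
        query='How do I reset my API key?',
        context='API keys can be reset from the console settings page.',
        memory=conversation_memory                              # Optional[BaseChatMemory]; None skips histories
    )
    # prompt_messages is a single-element List[PromptMessage]; stops is Optional[List[str]]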