|  | @@ -1,5 +1,6 @@
 | 
	
		
			
				|  |  |  from abc import abstractmethod
 | 
	
		
			
				|  |  |  from typing import List, Optional, Any, Union
 | 
	
		
			
				|  |  | +import decimal
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from langchain.callbacks.manager import Callbacks
 | 
	
		
			
				|  |  |  from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
 | 
	
	
		
			
				|  | @@ -10,6 +11,8 @@ from core.model_providers.models.entity.message import PromptMessage, MessageTyp
 | 
	
		
			
				|  |  |  from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
 | 
	
		
			
				|  |  |  from core.model_providers.providers.base import BaseModelProvider
 | 
	
		
			
				|  |  |  from core.third_party.langchain.llms.fake import FakeLLM
 | 
	
		
			
				|  |  | +import logging
 | 
	
		
			
				|  |  | +logger = logging.getLogger(__name__)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class BaseLLM(BaseProviderModel):
 | 
	
	
		
			
				|  | @@ -60,6 +63,39 @@ class BaseLLM(BaseProviderModel):
 | 
	
		
			
				|  |  |      def _init_client(self) -> Any:
 | 
	
		
			
				|  |  |          raise NotImplementedError
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    @property
 | 
	
		
			
				|  |  | +    def base_model_name(self) -> str:
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  | +        get llm base model name
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        :return: str
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  | +        return self.name
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    @property
 | 
	
		
			
				|  |  | +    def price_config(self) -> dict:
 | 
	
		
			
				|  |  | +        def get_or_default():
 | 
	
		
			
				|  |  | +            default_price_config = {
 | 
	
		
			
				|  |  | +                    'prompt': decimal.Decimal('0'),
 | 
	
		
			
				|  |  | +                    'completion': decimal.Decimal('0'),
 | 
	
		
			
				|  |  | +                    'unit': decimal.Decimal('0'),
 | 
	
		
			
				|  |  | +                    'currency': 'USD'
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +            rules = self.model_provider.get_rules()
 | 
	
		
			
				|  |  | +            price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
 | 
	
		
			
				|  |  | +            price_config = {
 | 
	
		
			
				|  |  | +                'prompt': decimal.Decimal(price_config['prompt']),
 | 
	
		
			
				|  |  | +                'completion': decimal.Decimal(price_config['completion']),
 | 
	
		
			
				|  |  | +                'unit': decimal.Decimal(price_config['unit']),
 | 
	
		
			
				|  |  | +                'currency': price_config['currency']
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +            return price_config
 | 
	
		
			
				|  |  | +        
 | 
	
		
			
				|  |  | +        self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        logger.debug(f"model: {self.name} price_config: {self._price_config}")
 | 
	
		
			
				|  |  | +        return self._price_config
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      def run(self, messages: List[PromptMessage],
 | 
	
		
			
				|  |  |              stop: Optional[List[str]] = None,
 | 
	
		
			
				|  |  |              callbacks: Callbacks = None,
 | 
	
	
		
			
				|  | @@ -161,25 +197,48 @@ class BaseLLM(BaseProviderModel):
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  |          raise NotImplementedError
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    @abstractmethod
 | 
	
		
			
				|  |  | -    def get_token_price(self, tokens: int, message_type: MessageType):
 | 
	
		
			
				|  |  | +    def calc_tokens_price(self, tokens:int, message_type: MessageType):
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  | -        get token price.
 | 
	
		
			
				|  |  | +        calc tokens total price.
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          :param tokens:
 | 
	
		
			
				|  |  |          :param message_type:
 | 
	
		
			
				|  |  |          :return:
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  | -        raise NotImplementedError
 | 
	
		
			
				|  |  | +        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
 | 
	
		
			
				|  |  | +            unit_price = self.price_config['prompt']
 | 
	
		
			
				|  |  | +        else:
 | 
	
		
			
				|  |  | +            unit_price = self.price_config['completion']
 | 
	
		
			
				|  |  | +        unit = self.price_config['unit']
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        total_price = tokens * unit_price * unit
 | 
	
		
			
				|  |  | +        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
 | 
	
		
			
				|  |  | +        logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
 | 
	
		
			
				|  |  | +        return total_price
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def get_tokens_unit_price(self, message_type: MessageType):
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  | +        get token price.
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        :param message_type:
 | 
	
		
			
				|  |  | +        :return: decimal.Decimal('0.0001')
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  | +        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
 | 
	
		
			
				|  |  | +            unit_price = self.price_config['prompt']
 | 
	
		
			
				|  |  | +        else:
 | 
	
		
			
				|  |  | +            unit_price = self.price_config['completion']
 | 
	
		
			
				|  |  | +        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
 | 
	
		
			
				|  |  | +        logging.debug(f"unit_price={unit_price}")
 | 
	
		
			
				|  |  | +        return unit_price
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    @abstractmethod
 | 
	
		
			
				|  |  |      def get_currency(self):
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  |          get token currency.
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        :return:
 | 
	
		
			
				|  |  | +        :return: get from price config, default 'USD'
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  | -        raise NotImplementedError
 | 
	
		
			
				|  |  | +        currency = self.price_config['currency']
 | 
	
		
			
				|  |  | +        return currency
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def get_model_kwargs(self):
 | 
	
		
			
				|  |  |          return self.model_kwargs
 |