base.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from abc import abstractmethod
  2. from typing import Any
  3. import decimal
  4. import tiktoken
  5. from langchain.schema.language_model import _get_token_ids_default_method
  6. from core.model_providers.models.base import BaseProviderModel
  7. from core.model_providers.models.entity.model_params import ModelType
  8. from core.model_providers.providers.base import BaseModelProvider
  9. import logging
  10. logger = logging.getLogger(__name__)
  11. class BaseEmbedding(BaseProviderModel):
  12. name: str
  13. type: ModelType = ModelType.EMBEDDINGS
  14. def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
  15. super().__init__(model_provider, client)
  16. self.name = name
  17. @property
  18. def base_model_name(self) -> str:
  19. """
  20. get base model name
  21. :return: str
  22. """
  23. return self.name
  24. @property
  25. def price_config(self) -> dict:
  26. def get_or_default():
  27. default_price_config = {
  28. 'completion': decimal.Decimal('0'),
  29. 'unit': decimal.Decimal('0'),
  30. 'currency': 'USD'
  31. }
  32. rules = self.model_provider.get_rules()
  33. price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
  34. price_config = {
  35. 'completion': decimal.Decimal(price_config['completion']),
  36. 'unit': decimal.Decimal(price_config['unit']),
  37. 'currency': price_config['currency']
  38. }
  39. return price_config
  40. self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
  41. logger.debug(f"model: {self.name} price_config: {self._price_config}")
  42. return self._price_config
  43. def calc_tokens_price(self, tokens:int) -> decimal.Decimal:
  44. """
  45. calc tokens total price.
  46. :param tokens:
  47. :return: decimal.Decimal('0.0000001')
  48. """
  49. unit_price = self.price_config['completion']
  50. unit = self.price_config['unit']
  51. total_price = tokens * unit_price * unit
  52. total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
  53. logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
  54. return total_price
  55. def get_tokens_unit_price(self) -> decimal.Decimal:
  56. """
  57. get token price.
  58. :return: decimal.Decimal('0.0001')
  59. """
  60. unit_price = self.price_config['completion']
  61. unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
  62. logger.debug(f'unit_price:{unit_price}')
  63. return unit_price
  64. def get_num_tokens(self, text: str) -> int:
  65. """
  66. get num tokens of text.
  67. :param text:
  68. :return:
  69. """
  70. if len(text) == 0:
  71. return 0
  72. return len(_get_token_ids_default_method(text))
  73. def get_currency(self):
  74. """
  75. get token currency.
  76. :return: get from price config, default 'USD'
  77. """
  78. currency = self.price_config['currency']
  79. return currency
  80. @abstractmethod
  81. def handle_exceptions(self, ex: Exception) -> Exception:
  82. raise NotImplementedError