base.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. from abc import abstractmethod
  2. from typing import Any
  3. import decimal
  4. import tiktoken
  5. from langchain.schema.language_model import _get_token_ids_default_method
  6. from core.model_providers.models.base import BaseProviderModel
  7. from core.model_providers.models.entity.model_params import ModelType
  8. from core.model_providers.providers.base import BaseModelProvider
  9. import logging
  10. logger = logging.getLogger(__name__)
  11. class BaseEmbedding(BaseProviderModel):
  12. name: str
  13. type: ModelType = ModelType.EMBEDDINGS
  14. def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
  15. super().__init__(model_provider, client)
  16. self.name = name
  17. @property
  18. def base_model_name(self) -> str:
  19. """
  20. get base model name
  21. :return: str
  22. """
  23. return self.name
  24. @property
  25. def price_config(self) -> dict:
  26. def get_or_default():
  27. default_price_config = {
  28. 'prompt': decimal.Decimal('0'),
  29. 'completion': decimal.Decimal('0'),
  30. 'unit': decimal.Decimal('0'),
  31. 'currency': 'USD'
  32. }
  33. rules = self.model_provider.get_rules()
  34. price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
  35. price_config = {
  36. 'prompt': decimal.Decimal(price_config['prompt']),
  37. 'completion': decimal.Decimal(price_config['completion']),
  38. 'unit': decimal.Decimal(price_config['unit']),
  39. 'currency': price_config['currency']
  40. }
  41. return price_config
  42. self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
  43. logger.debug(f"model: {self.name} price_config: {self._price_config}")
  44. return self._price_config
  45. def calc_tokens_price(self, tokens:int) -> decimal.Decimal:
  46. """
  47. calc tokens total price.
  48. :param tokens:
  49. :return: decimal.Decimal('0.0000001')
  50. """
  51. unit_price = self._price_config['completion']
  52. unit = self._price_config['unit']
  53. total_price = tokens * unit_price * unit
  54. total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
  55. logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
  56. return total_price
  57. def get_tokens_unit_price(self) -> decimal.Decimal:
  58. """
  59. get token price.
  60. :return: decimal.Decimal('0.0001')
  61. """
  62. unit_price = self._price_config['completion']
  63. unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
  64. logger.debug(f'unit_price:{unit_price}')
  65. return unit_price
  66. def get_num_tokens(self, text: str) -> int:
  67. """
  68. get num tokens of text.
  69. :param text:
  70. :return:
  71. """
  72. if len(text) == 0:
  73. return 0
  74. return len(_get_token_ids_default_method(text))
  75. def get_currency(self):
  76. """
  77. get token currency.
  78. :return: get from price config, default 'USD'
  79. """
  80. currency = self._price_config['currency']
  81. return currency
  82. @abstractmethod
  83. def handle_exceptions(self, ex: Exception) -> Exception:
  84. raise NotImplementedError