# base.py
  1. from abc import abstractmethod
  2. from typing import Any
  3. import tiktoken
  4. from langchain.schema.language_model import _get_token_ids_default_method
  5. from core.model_providers.models.base import BaseProviderModel
  6. from core.model_providers.models.entity.model_params import ModelType
  7. from core.model_providers.providers.base import BaseModelProvider
  8. class BaseEmbedding(BaseProviderModel):
  9. name: str
  10. type: ModelType = ModelType.EMBEDDINGS
  11. def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
  12. super().__init__(model_provider, client)
  13. self.name = name
  14. def get_num_tokens(self, text: str) -> int:
  15. """
  16. get num tokens of text.
  17. :param text:
  18. :return:
  19. """
  20. if len(text) == 0:
  21. return 0
  22. return len(_get_token_ids_default_method(text))
  23. def get_token_price(self, tokens: int):
  24. return 0
  25. def get_currency(self):
  26. return 'USD'
  27. @abstractmethod
  28. def handle_exceptions(self, ex: Exception) -> Exception:
  29. raise NotImplementedError