calc_token_mixin.py

from typing import List, cast

from langchain.schema import BaseMessage

from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
class CalcTokenMixin:

    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int:
        """
        Get the tokens remaining for the model after excluding the prompt
        message tokens and the configured completion max tokens.

        :param model_config: model configuration, including schema and parameters
        :param messages: prompt messages to count
        :return: number of remaining tokens (negative if the prompt already overflows)
        """
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        # Context window size declared in the model schema.
        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

        # Resolve the configured completion max tokens, matching the parameter
        # rule either by name or by its 'max_tokens' template.
        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                max_tokens = (model_config.parameters.get(parameter_rule.name)
                              or model_config.parameters.get(parameter_rule.use_template)) or 0

        if model_context_tokens is None:
            return 0

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            messages
        )

        rest_tokens = model_context_tokens - max_tokens - prompt_tokens

        return rest_tokens
class ExceededLLMTokensLimitError(Exception):
    pass
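

# A minimal usage sketch (an assumption, not part of the original file): a
# caller mixes in CalcTokenMixin and raises ExceededLLMTokensLimitError when
# the prompt no longer fits the model's context window. The class and method
# names below are hypothetical.
class ExampleAgent(CalcTokenMixin):

    def ensure_prompt_fits(self, model_config: ModelConfigEntity, messages: List[PromptMessage]) -> None:
        # get_message_rest_tokens() goes negative when prompt tokens plus the
        # configured completion max tokens exceed the model's context size.
        rest_tokens = self.get_message_rest_tokens(model_config, messages)
        if rest_tokens < 0:
            raise ExceededLLMTokensLimitError(
                "Prompt exceeds the model's context size; trim the messages or conversation history."
            )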