# llm_entities.py
from decimal import Decimal
from enum import Enum
from typing import Optional

from pydantic import BaseModel

from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo
  7. class LLMMode(Enum):
  8. """
  9. Enum class for large language model mode.
  10. """
  11. COMPLETION = "completion"
  12. CHAT = "chat"
  13. @classmethod
  14. def value_of(cls, value: str) -> 'LLMMode':
  15. """
  16. Get value of given mode.
  17. :param value: mode value
  18. :return: mode
  19. """
  20. for mode in cls:
  21. if mode.value == value:
  22. return mode
  23. raise ValueError(f'invalid mode value {value}')
class LLMUsage(ModelUsage):
    """
    Model class for llm usage.

    Captures per-call token counts, pricing breakdowns for the prompt and
    completion sides, the combined totals, and the call latency.
    """
    # Prompt-side token count and pricing.
    prompt_tokens: int
    prompt_unit_price: Decimal
    prompt_price_unit: Decimal
    prompt_price: Decimal
    # Completion-side token count and pricing.
    completion_tokens: int
    completion_unit_price: Decimal
    completion_price_unit: Decimal
    completion_price: Decimal
    # Aggregate totals across prompt and completion.
    total_tokens: int
    total_price: Decimal
    # Currency code for the prices, e.g. 'USD' (the empty_usage default).
    currency: str
    # Call latency — presumably in seconds; TODO confirm unit against callers.
    latency: float

    @classmethod
    def empty_usage(cls) -> 'LLMUsage':
        """
        Build a zeroed-out usage record (all counts 0, all prices 0.0, USD).

        Useful as a neutral starting value before real usage is accumulated.

        :return: a new LLMUsage with every numeric field set to zero
        """
        return cls(
            prompt_tokens=0,
            prompt_unit_price=Decimal('0.0'),
            prompt_price_unit=Decimal('0.0'),
            prompt_price=Decimal('0.0'),
            completion_tokens=0,
            completion_unit_price=Decimal('0.0'),
            completion_price_unit=Decimal('0.0'),
            completion_price=Decimal('0.0'),
            total_tokens=0,
            total_price=Decimal('0.0'),
            currency='USD',
            latency=0.0
        )
class LLMResult(BaseModel):
    """
    Model class for llm result.

    The complete (non-streamed) outcome of one LLM invocation.
    """
    # Name/identifier of the model that produced this result.
    model: str
    # The prompt messages that were sent to the model.
    prompt_messages: list[PromptMessage]
    # The assistant message returned by the model.
    message: AssistantPromptMessage
    # Token/pricing usage for this call.
    usage: LLMUsage
    # Provider-reported fingerprint, when available — presumably the backend
    # configuration identifier (as in the OpenAI API); TODO confirm.
    system_fingerprint: Optional[str] = None
class LLMResultChunkDelta(BaseModel):
    """
    Model class for llm result chunk delta.

    The incremental payload carried by a single streamed chunk.
    """
    # Index of this delta — presumably the chunk/choice position in the
    # stream; TODO confirm against producers.
    index: int
    # The partial assistant message content for this chunk.
    message: AssistantPromptMessage
    # Usage is only populated when the provider reports it (typically on the
    # final chunk), hence Optional.
    usage: Optional[LLMUsage] = None
    # Reason the stream finished, if this is the terminating chunk.
    finish_reason: Optional[str] = None
class LLMResultChunk(BaseModel):
    """
    Model class for llm result chunk.

    One streamed chunk of an LLM response: call-level metadata plus the
    incremental delta for this step.
    """
    # Name/identifier of the model producing the stream.
    model: str
    # The prompt messages that were sent to the model.
    prompt_messages: list[PromptMessage]
    # Provider-reported fingerprint, when available.
    system_fingerprint: Optional[str] = None
    # The incremental content/usage carried by this chunk.
    delta: LLMResultChunkDelta
class NumTokensResult(PriceInfo):
    """
    Model class for number of tokens result.

    Extends PriceInfo with the token count the price was computed for.
    """
    # Number of tokens counted.
    tokens: int