message_entities.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. from abc import ABC
  2. from enum import Enum
  3. from typing import Optional
  4. from pydantic import BaseModel, Field, field_validator
  5. class PromptMessageRole(Enum):
  6. """
  7. Enum class for prompt message.
  8. """
  9. SYSTEM = "system"
  10. USER = "user"
  11. ASSISTANT = "assistant"
  12. TOOL = "tool"
  13. @classmethod
  14. def value_of(cls, value: str) -> "PromptMessageRole":
  15. """
  16. Get value of given mode.
  17. :param value: mode value
  18. :return: mode
  19. """
  20. for mode in cls:
  21. if mode.value == value:
  22. return mode
  23. raise ValueError(f"invalid prompt message type value {value}")
  24. class PromptMessageTool(BaseModel):
  25. """
  26. Model class for prompt message tool.
  27. """
  28. name: str
  29. description: str
  30. parameters: dict
  31. class PromptMessageFunction(BaseModel):
  32. """
  33. Model class for prompt message function.
  34. """
  35. type: str = "function"
  36. function: PromptMessageTool
  37. class PromptMessageContentType(Enum):
  38. """
  39. Enum class for prompt message content type.
  40. """
  41. TEXT = "text"
  42. IMAGE = "image"
  43. AUDIO = "audio"
  44. class PromptMessageContent(BaseModel):
  45. """
  46. Model class for prompt message content.
  47. """
  48. type: PromptMessageContentType
  49. data: str
  50. class TextPromptMessageContent(PromptMessageContent):
  51. """
  52. Model class for text prompt message content.
  53. """
  54. type: PromptMessageContentType = PromptMessageContentType.TEXT
  55. class AudioPromptMessageContent(PromptMessageContent):
  56. type: PromptMessageContentType = PromptMessageContentType.AUDIO
  57. data: str = Field(..., description="Base64 encoded audio data")
  58. format: str = Field(..., description="Audio format")
  59. class ImagePromptMessageContent(PromptMessageContent):
  60. """
  61. Model class for image prompt message content.
  62. """
  63. class DETAIL(str, Enum):
  64. LOW = "low"
  65. HIGH = "high"
  66. type: PromptMessageContentType = PromptMessageContentType.IMAGE
  67. detail: DETAIL = DETAIL.LOW
  68. class PromptMessage(ABC, BaseModel):
  69. """
  70. Model class for prompt message.
  71. """
  72. role: PromptMessageRole
  73. content: Optional[str | list[PromptMessageContent]] = None
  74. name: Optional[str] = None
  75. def is_empty(self) -> bool:
  76. """
  77. Check if prompt message is empty.
  78. :return: True if prompt message is empty, False otherwise
  79. """
  80. return not self.content
  81. class UserPromptMessage(PromptMessage):
  82. """
  83. Model class for user prompt message.
  84. """
  85. role: PromptMessageRole = PromptMessageRole.USER
  86. class AssistantPromptMessage(PromptMessage):
  87. """
  88. Model class for assistant prompt message.
  89. """
  90. class ToolCall(BaseModel):
  91. """
  92. Model class for assistant prompt message tool call.
  93. """
  94. class ToolCallFunction(BaseModel):
  95. """
  96. Model class for assistant prompt message tool call function.
  97. """
  98. name: str
  99. arguments: str
  100. id: str
  101. type: str
  102. function: ToolCallFunction
  103. @field_validator("id", mode="before")
  104. @classmethod
  105. def transform_id_to_str(cls, value) -> str:
  106. if not isinstance(value, str):
  107. return str(value)
  108. else:
  109. return value
  110. role: PromptMessageRole = PromptMessageRole.ASSISTANT
  111. tool_calls: list[ToolCall] = []
  112. def is_empty(self) -> bool:
  113. """
  114. Check if prompt message is empty.
  115. :return: True if prompt message is empty, False otherwise
  116. """
  117. if not super().is_empty() and not self.tool_calls:
  118. return False
  119. return True
  120. class SystemPromptMessage(PromptMessage):
  121. """
  122. Model class for system prompt message.
  123. """
  124. role: PromptMessageRole = PromptMessageRole.SYSTEM
  125. class ToolPromptMessage(PromptMessage):
  126. """
  127. Model class for tool prompt message.
  128. """
  129. role: PromptMessageRole = PromptMessageRole.TOOL
  130. tool_call_id: str
  131. def is_empty(self) -> bool:
  132. """
  133. Check if prompt message is empty.
  134. :return: True if prompt message is empty, False otherwise
  135. """
  136. if not super().is_empty() and not self.tool_call_id:
  137. return False
  138. return True