message_entities.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. from abc import ABC
  2. from enum import Enum
  3. from typing import Optional
  4. from pydantic import BaseModel, Field, field_validator
  5. class PromptMessageRole(Enum):
  6. """
  7. Enum class for prompt message.
  8. """
  9. SYSTEM = "system"
  10. USER = "user"
  11. ASSISTANT = "assistant"
  12. TOOL = "tool"
  13. @classmethod
  14. def value_of(cls, value: str) -> "PromptMessageRole":
  15. """
  16. Get value of given mode.
  17. :param value: mode value
  18. :return: mode
  19. """
  20. for mode in cls:
  21. if mode.value == value:
  22. return mode
  23. raise ValueError(f"invalid prompt message type value {value}")
  24. class PromptMessageTool(BaseModel):
  25. """
  26. Model class for prompt message tool.
  27. """
  28. name: str
  29. description: str
  30. parameters: dict
  31. class PromptMessageFunction(BaseModel):
  32. """
  33. Model class for prompt message function.
  34. """
  35. type: str = "function"
  36. function: PromptMessageTool
  37. class PromptMessageContentType(Enum):
  38. """
  39. Enum class for prompt message content type.
  40. """
  41. TEXT = "text"
  42. IMAGE = "image"
  43. AUDIO = "audio"
  44. VIDEO = "video"
  45. class PromptMessageContent(BaseModel):
  46. """
  47. Model class for prompt message content.
  48. """
  49. type: PromptMessageContentType
  50. data: str
  51. class TextPromptMessageContent(PromptMessageContent):
  52. """
  53. Model class for text prompt message content.
  54. """
  55. type: PromptMessageContentType = PromptMessageContentType.TEXT
  56. class VideoPromptMessageContent(PromptMessageContent):
  57. type: PromptMessageContentType = PromptMessageContentType.VIDEO
  58. data: str = Field(..., description="Base64 encoded video data")
  59. format: str = Field(..., description="Video format")
  60. class AudioPromptMessageContent(PromptMessageContent):
  61. type: PromptMessageContentType = PromptMessageContentType.AUDIO
  62. data: str = Field(..., description="Base64 encoded audio data")
  63. format: str = Field(..., description="Audio format")
  64. class ImagePromptMessageContent(PromptMessageContent):
  65. """
  66. Model class for image prompt message content.
  67. """
  68. class DETAIL(str, Enum):
  69. LOW = "low"
  70. HIGH = "high"
  71. type: PromptMessageContentType = PromptMessageContentType.IMAGE
  72. detail: DETAIL = DETAIL.LOW
  73. class PromptMessage(ABC, BaseModel):
  74. """
  75. Model class for prompt message.
  76. """
  77. role: PromptMessageRole
  78. content: Optional[str | list[PromptMessageContent]] = None
  79. name: Optional[str] = None
  80. def is_empty(self) -> bool:
  81. """
  82. Check if prompt message is empty.
  83. :return: True if prompt message is empty, False otherwise
  84. """
  85. return not self.content
  86. class UserPromptMessage(PromptMessage):
  87. """
  88. Model class for user prompt message.
  89. """
  90. role: PromptMessageRole = PromptMessageRole.USER
  91. class AssistantPromptMessage(PromptMessage):
  92. """
  93. Model class for assistant prompt message.
  94. """
  95. class ToolCall(BaseModel):
  96. """
  97. Model class for assistant prompt message tool call.
  98. """
  99. class ToolCallFunction(BaseModel):
  100. """
  101. Model class for assistant prompt message tool call function.
  102. """
  103. name: str
  104. arguments: str
  105. id: str
  106. type: str
  107. function: ToolCallFunction
  108. @field_validator("id", mode="before")
  109. @classmethod
  110. def transform_id_to_str(cls, value) -> str:
  111. if not isinstance(value, str):
  112. return str(value)
  113. else:
  114. return value
  115. role: PromptMessageRole = PromptMessageRole.ASSISTANT
  116. tool_calls: list[ToolCall] = []
  117. def is_empty(self) -> bool:
  118. """
  119. Check if prompt message is empty.
  120. :return: True if prompt message is empty, False otherwise
  121. """
  122. if not super().is_empty() and not self.tool_calls:
  123. return False
  124. return True
  125. class SystemPromptMessage(PromptMessage):
  126. """
  127. Model class for system prompt message.
  128. """
  129. role: PromptMessageRole = PromptMessageRole.SYSTEM
  130. class ToolPromptMessage(PromptMessage):
  131. """
  132. Model class for tool prompt message.
  133. """
  134. role: PromptMessageRole = PromptMessageRole.TOOL
  135. tool_call_id: str
  136. def is_empty(self) -> bool:
  137. """
  138. Check if prompt message is empty.
  139. :return: True if prompt message is empty, False otherwise
  140. """
  141. if not super().is_empty() and not self.tool_call_id:
  142. return False
  143. return True