model_entities.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. from decimal import Decimal
  2. from enum import Enum
  3. from typing import Any, Optional
  4. from pydantic import BaseModel
  5. from core.model_runtime.entities.common_entities import I18nObject
  6. class ModelType(Enum):
  7. """
  8. Enum class for model type.
  9. """
  10. LLM = "llm"
  11. TEXT_EMBEDDING = "text-embedding"
  12. RERANK = "rerank"
  13. SPEECH2TEXT = "speech2text"
  14. MODERATION = "moderation"
  15. TTS = "tts"
  16. # TEXT2IMG = "text2img"
  17. @classmethod
  18. def value_of(cls, origin_model_type: str) -> "ModelType":
  19. """
  20. Get model type from origin model type.
  21. :return: model type
  22. """
  23. if origin_model_type == 'text-generation' or origin_model_type == cls.LLM.value:
  24. return cls.LLM
  25. elif origin_model_type == 'embeddings' or origin_model_type == cls.TEXT_EMBEDDING.value:
  26. return cls.TEXT_EMBEDDING
  27. elif origin_model_type == 'reranking' or origin_model_type == cls.RERANK.value:
  28. return cls.RERANK
  29. elif origin_model_type == 'speech2text' or origin_model_type == cls.SPEECH2TEXT.value:
  30. return cls.SPEECH2TEXT
  31. elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value:
  32. return cls.TTS
  33. elif origin_model_type == cls.MODERATION.value:
  34. return cls.MODERATION
  35. else:
  36. raise ValueError(f'invalid origin model type {origin_model_type}')
  37. def to_origin_model_type(self) -> str:
  38. """
  39. Get origin model type from model type.
  40. :return: origin model type
  41. """
  42. if self == self.LLM:
  43. return 'text-generation'
  44. elif self == self.TEXT_EMBEDDING:
  45. return 'embeddings'
  46. elif self == self.RERANK:
  47. return 'reranking'
  48. elif self == self.SPEECH2TEXT:
  49. return 'speech2text'
  50. elif self == self.TTS:
  51. return 'tts'
  52. elif self == self.MODERATION:
  53. return 'moderation'
  54. else:
  55. raise ValueError(f'invalid model type {self}')
class FetchFrom(Enum):
    """
    Enum class for fetch from.

    Records where a model definition comes from; stored on
    ``ProviderModel.fetch_from``. Presumably "predefined" models ship with the
    provider and "customizable" ones are configured by the user — confirm.
    """

    PREDEFINED_MODEL = "predefined-model"
    CUSTOMIZABLE_MODEL = "customizable-model"
class ModelFeature(Enum):
    """
    Enum class for llm feature.

    Optional capabilities a model can advertise through
    ``ProviderModel.features`` (a list of these members, or None).
    """

    TOOL_CALL = "tool-call"
    MULTI_TOOL_CALL = "multi-tool-call"
    AGENT_THOUGHT = "agent-thought"
    VISION = "vision"
    # presumably: tool calls can be emitted while streaming — confirm with callers
    STREAM_TOOL_CALL = "stream-tool-call"
  71. class DefaultParameterName(Enum):
  72. """
  73. Enum class for parameter template variable.
  74. """
  75. TEMPERATURE = "temperature"
  76. TOP_P = "top_p"
  77. PRESENCE_PENALTY = "presence_penalty"
  78. FREQUENCY_PENALTY = "frequency_penalty"
  79. MAX_TOKENS = "max_tokens"
  80. RESPONSE_FORMAT = "response_format"
  81. @classmethod
  82. def value_of(cls, value: Any) -> 'DefaultParameterName':
  83. """
  84. Get parameter name from value.
  85. :param value: parameter value
  86. :return: parameter name
  87. """
  88. for name in cls:
  89. if name.value == value:
  90. return name
  91. raise ValueError(f'invalid parameter name {value}')
class ParameterType(Enum):
    """
    Enum class for parameter type.

    Declared value type of a parameter rule (see ``ParameterRule.type``).
    """

    FLOAT = "float"
    INT = "int"
    STRING = "string"
    BOOLEAN = "boolean"
  100. class ModelPropertyKey(Enum):
  101. """
  102. Enum class for model property key.
  103. """
  104. MODE = "mode"
  105. CONTEXT_SIZE = "context_size"
  106. MAX_CHUNKS = "max_chunks"
  107. FILE_UPLOAD_LIMIT = "file_upload_limit"
  108. SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions"
  109. MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk"
  110. DEFAULT_VOICE = "default_voice"
  111. VOICES = "voices"
  112. WORD_LIMIT = "word_limit"
  113. AUDOI_TYPE = "audio_type"
  114. MAX_WORKERS = "max_workers"
class ProviderModel(BaseModel):
    """
    Model class for provider model.

    Describes one model exposed by a provider: its name, display label, type,
    optional feature list, definition source, property map and a deprecation
    flag.
    """

    model: str
    label: I18nObject  # presumably a localized display label — confirm against I18nObject
    model_type: ModelType
    features: Optional[list[ModelFeature]] = None  # None when no features are declared
    fetch_from: FetchFrom
    model_properties: dict[ModelPropertyKey, Any]  # values are untyped; their meaning depends on the key
    deprecated: bool = False

    class Config:
        # Pydantic v2 reserves the ``model_`` prefix and warns about fields such
        # as ``model_type`` / ``model_properties``; an empty tuple disables that
        # protection.
        protected_namespaces = ()
class ParameterRule(BaseModel):
    """
    Model class for parameter rule.

    Declares one tunable parameter of a model: its name, optional template
    reference, label/help text, value type, whether it is required, and an
    optional default, numeric bounds, precision and list of allowed options.
    """

    name: str
    use_template: Optional[str] = None  # presumably a DefaultParameterName value — confirm with callers
    label: I18nObject
    type: ParameterType
    help: Optional[I18nObject] = None
    required: bool = False
    default: Optional[Any] = None
    min: Optional[float] = None  # lower bound; None means unset
    max: Optional[float] = None  # upper bound; None means unset
    precision: Optional[int] = None  # presumably decimal places for numeric values — confirm
    # pydantic copies mutable defaults per instance, so this [] is not shared.
    options: list[str] = []
class PriceConfig(BaseModel):
    """
    Model class for pricing info.

    Static pricing attached to a model definition (``AIModelEntity.pricing``).
    How ``unit`` scales the prices is decided by the caller — confirm before
    relying on a specific formula.
    """

    input: Decimal
    output: Optional[Decimal] = None  # None when no separate output price is configured
    unit: Decimal
    currency: str
class AIModelEntity(ProviderModel):
    """
    Model class for AI model.

    Extends ``ProviderModel`` with the model's parameter rules and optional
    pricing configuration.
    """

    # pydantic copies mutable defaults per instance, so this [] is not shared.
    parameter_rules: list[ParameterRule] = []
    pricing: Optional[PriceConfig] = None  # None when no pricing is configured
class ModelUsage(BaseModel):
    """
    Base class for model usage records.

    Declares no fields of its own; presumably subclassed by model-type-specific
    usage classes elsewhere — confirm.
    """
    pass
class PriceType(Enum):
    """
    Enum class for price type.

    Selects the input-side or output-side price; presumably corresponds to
    ``PriceConfig.input`` / ``PriceConfig.output`` — confirm with callers.
    """

    INPUT = "input"
    OUTPUT = "output"
class PriceInfo(BaseModel):
    """
    Model class for price info.

    Holds a priced result: the unit price, the unit it applies to, the total
    amount and the currency. How ``total_amount`` is derived is not defined
    here — confirm with the code that builds it.
    """

    unit_price: Decimal
    unit: Decimal
    total_amount: Decimal
    currency: str