encoders.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. import dataclasses
  2. import datetime
  3. from collections import defaultdict, deque
  4. from collections.abc import Callable
  5. from decimal import Decimal
  6. from enum import Enum
  7. from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
  8. from pathlib import Path, PurePath
  9. from re import Pattern
  10. from types import GeneratorType
  11. from typing import Any, Literal, Optional, Union
  12. from uuid import UUID
  13. from pydantic import BaseModel
  14. from pydantic.networks import AnyUrl, NameEmail
  15. from pydantic.types import SecretBytes, SecretStr
  16. from pydantic_core import Url
  17. from pydantic_extra_types.color import Color
  18. def _model_dump(
  19. model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
  20. ) -> Any:
  21. return model.model_dump(mode=mode, **kwargs)
  22. # Taken from Pydantic v1 as is
  23. def isoformat(o: Union[datetime.date, datetime.time]) -> str:
  24. return o.isoformat()
  25. # Taken from Pydantic v1 as is
  26. # TODO: pv2 should this return strings instead?
  27. def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
  28. """
  29. Encodes a Decimal as int of there's no exponent, otherwise float
  30. This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
  31. where a integer (but not int typed) is used. Encoding this as a float
  32. results in failed round-tripping between encode and parse.
  33. Our Id type is a prime example of this.
  34. >>> decimal_encoder(Decimal("1.0"))
  35. 1.0
  36. >>> decimal_encoder(Decimal("1"))
  37. 1
  38. """
  39. if dec_value.as_tuple().exponent >= 0: # type: ignore[operator]
  40. return int(dec_value)
  41. else:
  42. return float(dec_value)
# Default encoders keyed by type. jsonable_encoder first looks up the exact
# type(obj) in this mapping and, failing that, falls back to isinstance
# matching through encoders_by_class_tuples, which is derived from it.
# Insertion order is significant for that fallback, so keep new entries
# deliberate about where they are placed.
ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = {
    bytes: lambda o: o.decode(),  # uses bytes.decode() with its default arguments
    Color: str,
    datetime.date: isoformat,
    datetime.datetime: isoformat,
    datetime.time: isoformat,
    datetime.timedelta: lambda td: td.total_seconds(),  # duration as a number of seconds
    Decimal: decimal_encoder,
    Enum: lambda o: o.value,
    frozenset: list,
    deque: list,
    GeneratorType: list,
    IPv4Address: str,
    IPv4Interface: str,
    IPv4Network: str,
    IPv6Address: str,
    IPv6Interface: str,
    IPv6Network: str,
    NameEmail: str,
    Path: str,
    Pattern: lambda o: o.pattern,  # the pattern attribute of the compiled regex
    SecretBytes: str,
    SecretStr: str,
    set: list,
    UUID: str,
    Url: str,
    AnyUrl: str,
}
  71. def generate_encoders_by_class_tuples(
  72. type_encoder_map: dict[Any, Callable[[Any], Any]]
  73. ) -> dict[Callable[[Any], Any], tuple[Any, ...]]:
  74. encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(
  75. tuple
  76. )
  77. for type_, encoder in type_encoder_map.items():
  78. encoders_by_class_tuples[encoder] += (type_,)
  79. return encoders_by_class_tuples
# Encoder -> tuple-of-types view of ENCODERS_BY_TYPE, built once at import
# time; jsonable_encoder iterates it for its isinstance-based fallback.
encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
  81. def jsonable_encoder(
  82. obj: Any,
  83. by_alias: bool = True,
  84. exclude_unset: bool = False,
  85. exclude_defaults: bool = False,
  86. exclude_none: bool = False,
  87. custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None,
  88. sqlalchemy_safe: bool = True,
  89. ) -> Any:
  90. custom_encoder = custom_encoder or {}
  91. if custom_encoder:
  92. if type(obj) in custom_encoder:
  93. return custom_encoder[type(obj)](obj)
  94. else:
  95. for encoder_type, encoder_instance in custom_encoder.items():
  96. if isinstance(obj, encoder_type):
  97. return encoder_instance(obj)
  98. if isinstance(obj, BaseModel):
  99. obj_dict = _model_dump(
  100. obj,
  101. mode="json",
  102. include=None,
  103. exclude=None,
  104. by_alias=by_alias,
  105. exclude_unset=exclude_unset,
  106. exclude_none=exclude_none,
  107. exclude_defaults=exclude_defaults,
  108. )
  109. if "__root__" in obj_dict:
  110. obj_dict = obj_dict["__root__"]
  111. return jsonable_encoder(
  112. obj_dict,
  113. exclude_none=exclude_none,
  114. exclude_defaults=exclude_defaults,
  115. sqlalchemy_safe=sqlalchemy_safe,
  116. )
  117. if dataclasses.is_dataclass(obj):
  118. obj_dict = dataclasses.asdict(obj)
  119. return jsonable_encoder(
  120. obj_dict,
  121. by_alias=by_alias,
  122. exclude_unset=exclude_unset,
  123. exclude_defaults=exclude_defaults,
  124. exclude_none=exclude_none,
  125. custom_encoder=custom_encoder,
  126. sqlalchemy_safe=sqlalchemy_safe,
  127. )
  128. if isinstance(obj, Enum):
  129. return obj.value
  130. if isinstance(obj, PurePath):
  131. return str(obj)
  132. if isinstance(obj, str | int | float | type(None)):
  133. return obj
  134. if isinstance(obj, Decimal):
  135. return format(obj, 'f')
  136. if isinstance(obj, dict):
  137. encoded_dict = {}
  138. allowed_keys = set(obj.keys())
  139. for key, value in obj.items():
  140. if (
  141. (
  142. not sqlalchemy_safe
  143. or (not isinstance(key, str))
  144. or (not key.startswith("_sa"))
  145. )
  146. and (value is not None or not exclude_none)
  147. and key in allowed_keys
  148. ):
  149. encoded_key = jsonable_encoder(
  150. key,
  151. by_alias=by_alias,
  152. exclude_unset=exclude_unset,
  153. exclude_none=exclude_none,
  154. custom_encoder=custom_encoder,
  155. sqlalchemy_safe=sqlalchemy_safe,
  156. )
  157. encoded_value = jsonable_encoder(
  158. value,
  159. by_alias=by_alias,
  160. exclude_unset=exclude_unset,
  161. exclude_none=exclude_none,
  162. custom_encoder=custom_encoder,
  163. sqlalchemy_safe=sqlalchemy_safe,
  164. )
  165. encoded_dict[encoded_key] = encoded_value
  166. return encoded_dict
  167. if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque):
  168. encoded_list = []
  169. for item in obj:
  170. encoded_list.append(
  171. jsonable_encoder(
  172. item,
  173. by_alias=by_alias,
  174. exclude_unset=exclude_unset,
  175. exclude_defaults=exclude_defaults,
  176. exclude_none=exclude_none,
  177. custom_encoder=custom_encoder,
  178. sqlalchemy_safe=sqlalchemy_safe,
  179. )
  180. )
  181. return encoded_list
  182. if type(obj) in ENCODERS_BY_TYPE:
  183. return ENCODERS_BY_TYPE[type(obj)](obj)
  184. for encoder, classes_tuple in encoders_by_class_tuples.items():
  185. if isinstance(obj, classes_tuple):
  186. return encoder(obj)
  187. try:
  188. data = dict(obj)
  189. except Exception as e:
  190. errors: list[Exception] = []
  191. errors.append(e)
  192. try:
  193. data = vars(obj)
  194. except Exception as e:
  195. errors.append(e)
  196. raise ValueError(errors) from e
  197. return jsonable_encoder(
  198. data,
  199. by_alias=by_alias,
  200. exclude_unset=exclude_unset,
  201. exclude_defaults=exclude_defaults,
  202. exclude_none=exclude_none,
  203. custom_encoder=custom_encoder,
  204. sqlalchemy_safe=sqlalchemy_safe,
  205. )