encoders.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. import dataclasses
  2. import datetime
  3. from collections import defaultdict, deque
  4. from collections.abc import Callable
  5. from decimal import Decimal
  6. from enum import Enum
  7. from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
  8. from pathlib import Path, PurePath
  9. from re import Pattern
  10. from types import GeneratorType
  11. from typing import Any, Optional, Union
  12. from uuid import UUID
  13. from pydantic import BaseModel
  14. from pydantic.color import Color
  15. from pydantic.networks import AnyUrl, NameEmail
  16. from pydantic.types import SecretBytes, SecretStr
  17. from ._compat import PYDANTIC_V2, Url, _model_dump
  # Ported unchanged from Pydantic v1.
  def isoformat(o: Union[datetime.date, datetime.time]) -> str:
      """Return ``o`` rendered by its own ``isoformat()`` method (ISO 8601 text)."""
      text = o.isoformat()
      return text
  # Taken from Pydantic v1 (special-value handling added).
  # TODO: pv2 should this return strings instead?
  def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
      """
      Encode a Decimal as int if it has no fractional exponent, otherwise float.

      This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
      where an integer (but not int typed) is used. Encoding this as a float
      results in failed round-tripping between encode and parse.
      Our Id type is a prime example of this.

      NaN and Infinity have a string exponent marker instead of an int, so
      they are encoded as float rather than raising TypeError.

      >>> decimal_encoder(Decimal("1.0"))
      1.0

      >>> decimal_encoder(Decimal("1"))
      1
      """
      exponent = dec_value.as_tuple().exponent
      # For special values `exponent` is 'n', 'N' or 'F'; comparing that str
      # with 0 would raise TypeError, so only integral exponents take the int path.
      if isinstance(exponent, int) and exponent >= 0:
          return int(dec_value)
      return float(dec_value)
  def _decode_bytes(o: bytes) -> str:
      # Uses bytes.decode() with its default codec.
      return o.decode()


  def _timedelta_to_seconds(td: datetime.timedelta) -> float:
      return td.total_seconds()


  def _enum_to_value(o: Enum) -> Any:
      return o.value


  def _pattern_to_source(o: Pattern) -> Any:
      return o.pattern


  # Fallback encoders keyed by exact type. Entry order matters:
  # generate_encoders_by_class_tuples() groups these by encoder in order of
  # first appearance, and jsonable_encoder() tries those groups in that order
  # with isinstance(), so earlier encoders win for subclasses.
  ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = {
      bytes: _decode_bytes,
      Color: str,
      # Temporal values
      datetime.date: isoformat,
      datetime.datetime: isoformat,
      datetime.time: isoformat,
      datetime.timedelta: _timedelta_to_seconds,
      Decimal: decimal_encoder,
      Enum: _enum_to_value,
      # Iterables flattened into lists
      frozenset: list,
      deque: list,
      GeneratorType: list,
      # Network addresses
      IPv4Address: str,
      IPv4Interface: str,
      IPv4Network: str,
      IPv6Address: str,
      IPv6Interface: str,
      IPv6Network: str,
      NameEmail: str,
      Path: str,
      Pattern: _pattern_to_source,
      SecretBytes: str,
      SecretStr: str,
      set: list,
      UUID: str,
      Url: str,
      AnyUrl: str,
  }
  def generate_encoders_by_class_tuples(
      type_encoder_map: dict[Any, Callable[[Any], Any]]
  ) -> dict[Callable[[Any], Any], tuple[Any, ...]]:
      """
      Invert a ``type -> encoder`` mapping into ``encoder -> (types, ...)``.

      Encoders appear in the order they are first seen in ``type_encoder_map``.
      Each tuple lists its types in mapping order, ready to pass to isinstance().
      The result is a ``defaultdict(tuple)``, so unknown encoders map to ``()``.
      """
      buckets: dict[Callable[[Any], Any], list[Any]] = {}
      for type_key, encoder_fn in type_encoder_map.items():
          buckets.setdefault(encoder_fn, []).append(type_key)
      return defaultdict(
          tuple, ((encoder_fn, tuple(members)) for encoder_fn, members in buckets.items())
      )
  # Encoder -> (types, ...) groups derived from ENCODERS_BY_TYPE. jsonable_encoder()
  # uses them as an isinstance()-based fallback when an object's exact type is
  # not itself a key of ENCODERS_BY_TYPE.
  encoders_by_class_tuples = generate_encoders_by_class_tuples(
      type_encoder_map=ENCODERS_BY_TYPE
  )
  def jsonable_encoder(
      obj: Any,
      by_alias: bool = True,
      exclude_unset: bool = False,
      exclude_defaults: bool = False,
      exclude_none: bool = False,
      custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None,
      sqlalchemy_safe: bool = True,
  ) -> Any:
      """
      Convert ``obj`` into a structure built only from JSON-compatible values.

      Rules are tried in this order:

      1. ``custom_encoder``: an exact-type match first, then the first entry
         whose key ``obj`` is an instance of.
      2. Pydantic models are dumped via ``_model_dump(mode="json")`` and the
         result is encoded again (a ``__root__`` key is unwrapped).
      3. Dataclass instances are converted with ``dataclasses.asdict``.
      4. Enum -> ``.value``; PurePath -> ``str``; str/int/float/None are
         returned unchanged; Decimal -> fixed-point string (``format(obj, 'f')``).
      5. dicts are encoded key by key and value by value; list, set, frozenset,
         tuple, deque and generators become lists of encoded items.
      6. ``ENCODERS_BY_TYPE`` by exact type, then by isinstance() through
         ``encoders_by_class_tuples``.
      7. Last resort: ``dict(obj)``, else ``vars(obj)``, and encode that.

      Args:
          obj: The value to encode.
          by_alias: Passed to ``_model_dump`` for Pydantic models.
          exclude_unset: Passed to ``_model_dump`` for Pydantic models.
          exclude_defaults: Passed to ``_model_dump`` for Pydantic models.
          exclude_none: Passed to ``_model_dump``; also drops dict entries whose
              value is ``None``.
          custom_encoder: ``type -> callable`` overrides, checked before any
              built-in rule.
          sqlalchemy_safe: If true, dict entries whose key is a str starting
              with ``"_sa"`` are dropped.

      Returns:
          The JSON-compatible representation of ``obj``.

      Raises:
          ValueError: If no rule matches and both ``dict(obj)`` and
              ``vars(obj)`` fail; its argument is the list of both errors.
      """
      custom_encoder = custom_encoder or {}

      def encode(value: Any) -> Any:
          # Recurse with every option of this call forwarded unchanged.
          return jsonable_encoder(
              value,
              by_alias=by_alias,
              exclude_unset=exclude_unset,
              exclude_defaults=exclude_defaults,
              exclude_none=exclude_none,
              custom_encoder=custom_encoder,
              sqlalchemy_safe=sqlalchemy_safe,
          )

      if custom_encoder:
          if type(obj) in custom_encoder:
              return custom_encoder[type(obj)](obj)
          for encoder_type, encoder_instance in custom_encoder.items():
              if isinstance(obj, encoder_type):
                  return encoder_instance(obj)

      if isinstance(obj, BaseModel):
          # TODO: remove when deprecating Pydantic v1
          encoders: dict[Any, Any] = {}
          if not PYDANTIC_V2:
              # Copy: getattr() returns the config's own dict, and updating it
              # in place would permanently add custom_encoder's entries to the
              # model's configuration.
              encoders = dict(getattr(obj.__config__, "json_encoders", {}))  # type: ignore[attr-defined]
              if custom_encoder:
                  encoders.update(custom_encoder)
          obj_dict = _model_dump(
              obj,
              mode="json",
              include=None,
              exclude=None,
              by_alias=by_alias,
              exclude_unset=exclude_unset,
              exclude_none=exclude_none,
              exclude_defaults=exclude_defaults,
          )
          if "__root__" in obj_dict:
              obj_dict = obj_dict["__root__"]
          # by_alias/exclude_unset are not forwarded here; presumably they only
          # matter for the model dump above, which already ran.
          return jsonable_encoder(
              obj_dict,
              exclude_none=exclude_none,
              exclude_defaults=exclude_defaults,
              # TODO: remove when deprecating Pydantic v1
              custom_encoder=encoders,
              sqlalchemy_safe=sqlalchemy_safe,
          )

      if dataclasses.is_dataclass(obj):
          return encode(dataclasses.asdict(obj))
      if isinstance(obj, Enum):
          return obj.value
      if isinstance(obj, PurePath):
          return str(obj)
      if isinstance(obj, str | int | float | type(None)):
          return obj
      if isinstance(obj, Decimal):
          # Fixed-point string. This check runs before the ENCODERS_BY_TYPE
          # lookup below, so that table's Decimal entry is not reached from here.
          return format(obj, 'f')

      if isinstance(obj, dict):
          encoded_dict = {}
          for key, value in obj.items():
              # Drop "_sa*" string keys when sqlalchemy_safe is set (presumably
              # SQLAlchemy internals such as `_sa_instance_state`).
              if sqlalchemy_safe and isinstance(key, str) and key.startswith("_sa"):
                  continue
              if exclude_none and value is None:
                  continue
              # Encode the key before the value, matching the original call order.
              encoded_key = encode(key)
              encoded_dict[encoded_key] = encode(value)
          return encoded_dict

      if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque):
          return [encode(item) for item in obj]

      if type(obj) in ENCODERS_BY_TYPE:
          return ENCODERS_BY_TYPE[type(obj)](obj)
      for encoder, classes_tuple in encoders_by_class_tuples.items():
          if isinstance(obj, classes_tuple):
              return encoder(obj)

      try:
          data = dict(obj)
      except Exception as dict_error:
          errors: list[Exception] = [dict_error]
          try:
              data = vars(obj)
          except Exception as vars_error:
              errors.append(vars_error)
              raise ValueError(errors) from vars_error
      return encode(data)