123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228 |
- import dataclasses
- import datetime
- from collections import defaultdict, deque
- from collections.abc import Callable
- from decimal import Decimal
- from enum import Enum
- from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
- from pathlib import Path, PurePath
- from re import Pattern
- from types import GeneratorType
- from typing import Any, Optional, Union
- from uuid import UUID
- from pydantic import BaseModel
- from pydantic.color import Color
- from pydantic.networks import AnyUrl, NameEmail
- from pydantic.types import SecretBytes, SecretStr
- from ._compat import PYDANTIC_V2, Url, _model_dump
# Taken from Pydantic v1 as is
def isoformat(o: Union[datetime.date, datetime.time]) -> str:
    """Return ``o`` rendered by its own ``isoformat()`` method (ISO 8601 text)."""
    render = o.isoformat
    return render()
# Adapted from Pydantic v1 (non-finite values no longer raise TypeError)
# TODO: pv2 should this return strings instead?
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
    """
    Encode a Decimal as int if there's no fractional exponent, otherwise float.

    This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
    where an integer (but not int typed) is used. Encoding this as a float
    results in failed round-tripping between encode and parse.
    Our Id type is a prime example of this.

    For NaN and +/-Infinity, ``Decimal.as_tuple().exponent`` is a string
    marker rather than an int. Those values are encoded as the matching
    float. A signalling NaN still raises ValueError from ``float()``.

    >>> decimal_encoder(Decimal("1.0"))
    1.0
    >>> decimal_encoder(Decimal("1"))
    1
    >>> decimal_encoder(Decimal("NaN"))
    nan
    """
    exponent = dec_value.as_tuple().exponent
    # Comparing the special-value markers ('n', 'N', 'F') with 0 would raise
    # TypeError, so only integral exponents take the int path.
    if isinstance(exponent, int) and exponent >= 0:
        return int(dec_value)
    return float(dec_value)
# Fallback encoders consulted by jsonable_encoder after its explicit checks.
# jsonable_encoder first looks up the exact ``type(obj)`` here, then tries
# ``encoders_by_class_tuples`` (built from this dict) with isinstance().
# Because that grouping keeps first-occurrence order, the order of entries
# here decides which encoder wins for subclasses of several registered types.
ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = {
    # bytes.decode() defaults to strict UTF-8; non-UTF-8 bytes raise
    # UnicodeDecodeError.
    bytes: lambda o: o.decode(),
    Color: str,
    datetime.date: isoformat,
    datetime.datetime: isoformat,
    datetime.time: isoformat,
    # Duration as a float number of seconds.
    datetime.timedelta: lambda td: td.total_seconds(),
    # NOTE(review): jsonable_encoder handles Decimal itself (fixed-point string)
    # before consulting this table, so this entry is not reached from there.
    Decimal: decimal_encoder,
    Enum: lambda o: o.value,
    # Containers become plain lists; their items are NOT encoded by this table.
    frozenset: list,
    deque: list,
    GeneratorType: list,
    IPv4Address: str,
    IPv4Interface: str,
    IPv4Network: str,
    IPv6Address: str,
    IPv6Interface: str,
    IPv6Network: str,
    NameEmail: str,
    Path: str,
    # Compiled regex -> its source pattern.
    Pattern: lambda o: o.pattern,
    # NOTE(review): str() of Pydantic secret types is presumably the masked
    # placeholder, not the secret value - confirm against the Pydantic version.
    SecretBytes: str,
    SecretStr: str,
    set: list,
    UUID: str,
    Url: str,
    AnyUrl: str,
}
def generate_encoders_by_class_tuples(
    type_encoder_map: dict[Any, Callable[[Any], Any]]
) -> dict[Callable[[Any], Any], tuple[Any, ...]]:
    """
    Invert a ``type -> encoder`` mapping into ``encoder -> tuple of types``.

    Encoders appear in the order of their first occurrence in
    ``type_encoder_map``. Each tuple lists its types in mapping order, so it
    can be passed straight to ``isinstance``. The result is a
    ``defaultdict(tuple)``: looking up an unknown encoder yields ``()``.
    """
    grouped: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(tuple)
    for registered_type, type_encoder in type_encoder_map.items():
        grouped[type_encoder] = (*grouped[type_encoder], registered_type)
    return grouped
# encoder -> tuple of the ENCODERS_BY_TYPE types it handles. jsonable_encoder
# uses it as an isinstance() fallback for subclasses of the registered types.
# Computed once at import time: entries added to ENCODERS_BY_TYPE later are only
# seen by the exact-type lookup, not by this isinstance fallback.
encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
def jsonable_encoder(
    obj: Any,
    by_alias: bool = True,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
    custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None,
    sqlalchemy_safe: bool = True,
) -> Any:
    """
    Convert ``obj`` into a value built from JSON-compatible types.

    Rules are tried in this order; the first one that matches wins:

    1. ``custom_encoder``: an exact ``type(obj)`` key, then the first key
       (in mapping order) for which ``isinstance(obj, key)`` holds.
    2. Pydantic ``BaseModel``: dumped via ``_model_dump(mode="json")``, then
       the result is encoded again.
    3. Dataclass: converted with ``dataclasses.asdict``, then encoded again.
    4. ``Enum`` -> ``obj.value`` (not encoded further); ``PurePath`` -> ``str``.
    5. ``str``, ``int`` (including ``bool``), ``float``, ``None`` -> unchanged.
    6. ``Decimal`` -> fixed-point string, ``format(obj, "f")``.
    7. ``dict`` -> new dict with keys and values encoded recursively.
    8. list, set, frozenset, tuple, deque, generator -> list of encoded items.
    9. ``ENCODERS_BY_TYPE`` by exact type, then ``encoders_by_class_tuples``
       by ``isinstance``.
    10. ``dict(obj)``, else ``vars(obj)``, then encoded again.

    Args:
        obj: The value to encode.
        by_alias: Forwarded to model dumps (use field aliases as keys).
        exclude_unset: Forwarded to model dumps, including nested ones.
        exclude_defaults: Forwarded to model dumps, including models nested
            inside dicts, lists and dataclasses.
        exclude_none: Forwarded to model dumps; also drops dict entries whose
            value is ``None``.
        custom_encoder: Extra ``type -> callable`` encoders. They are checked
            before every built-in rule and passed down to nested containers.
            For Pydantic v1 models they are merged over a copy of the model's
            ``Config.json_encoders``. For Pydantic v2 models (already dumped in
            JSON mode) they are not applied to the dump.
        sqlalchemy_safe: Drop dict entries whose key is a ``str`` starting
            with ``"_sa"`` (SQLAlchemy's internal-state prefix, e.g.
            ``_sa_instance_state``).

    Returns:
        The encoded value.

    Raises:
        ValueError: No rule matched and both ``dict(obj)`` and ``vars(obj)``
            failed; its single argument is the list of those two exceptions.
    """
    custom_encoder = custom_encoder or {}
    if custom_encoder:
        if type(obj) in custom_encoder:
            return custom_encoder[type(obj)](obj)
        for encoder_type, encoder_instance in custom_encoder.items():
            if isinstance(obj, encoder_type):
                return encoder_instance(obj)

    if isinstance(obj, BaseModel):
        # TODO: remove when deprecating Pydantic v1
        encoders: dict[Any, Any] = {}
        if not PYDANTIC_V2:
            # Copy: the config's json_encoders dict is shared by the model
            # class, so updating it in place would leak custom_encoder into
            # every later encode of that model.
            encoders = dict(getattr(obj.__config__, "json_encoders", {}))  # type: ignore[attr-defined]
            if custom_encoder:
                encoders.update(custom_encoder)
        obj_dict = _model_dump(
            obj,
            mode="json",
            include=None,
            exclude=None,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_none=exclude_none,
            exclude_defaults=exclude_defaults,
        )
        # Custom-root models dump as {"__root__": value}. The dump is not
        # always a dict (e.g. a v2 RootModel over a scalar): `in` on an int
        # raises TypeError, and on a str it is a substring test.
        if isinstance(obj_dict, dict) and "__root__" in obj_dict:
            obj_dict = obj_dict["__root__"]
        # by_alias / exclude_unset were already applied by the dump above.
        return jsonable_encoder(
            obj_dict,
            exclude_none=exclude_none,
            exclude_defaults=exclude_defaults,
            # TODO: remove when deprecating Pydantic v1
            custom_encoder=encoders,
            sqlalchemy_safe=sqlalchemy_safe,
        )

    if dataclasses.is_dataclass(obj):
        obj_dict = dataclasses.asdict(obj)
        return jsonable_encoder(
            obj_dict,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            custom_encoder=custom_encoder,
            sqlalchemy_safe=sqlalchemy_safe,
        )
    if isinstance(obj, Enum):
        return obj.value
    if isinstance(obj, PurePath):
        return str(obj)
    if isinstance(obj, (str, int, float, type(None))):
        return obj
    if isinstance(obj, Decimal):
        # NOTE(review): returns a fixed-point string and takes precedence over
        # the Decimal entry of ENCODERS_BY_TYPE (decimal_encoder), which is
        # therefore never reached from here. Confirm this is intended.
        return format(obj, "f")

    if isinstance(obj, dict):
        encoded_dict = {}
        for key, value in obj.items():
            if sqlalchemy_safe and isinstance(key, str) and key.startswith("_sa"):
                continue
            if value is None and exclude_none:
                continue
            encoded_key = jsonable_encoder(
                key,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
                custom_encoder=custom_encoder,
                sqlalchemy_safe=sqlalchemy_safe,
            )
            encoded_value = jsonable_encoder(
                value,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
                custom_encoder=custom_encoder,
                sqlalchemy_safe=sqlalchemy_safe,
            )
            encoded_dict[encoded_key] = encoded_value
        return encoded_dict

    if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)):
        return [
            jsonable_encoder(
                item,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
                custom_encoder=custom_encoder,
                sqlalchemy_safe=sqlalchemy_safe,
            )
            for item in obj
        ]

    if type(obj) in ENCODERS_BY_TYPE:
        return ENCODERS_BY_TYPE[type(obj)](obj)
    for encoder, classes_tuple in encoders_by_class_tuples.items():
        if isinstance(obj, classes_tuple):
            return encoder(obj)

    # Last resort: treat obj as a mapping / iterable of pairs, else fall back
    # to its instance __dict__.
    try:
        data = dict(obj)
    except Exception as dict_error:
        errors: list[Exception] = [dict_error]
        try:
            data = vars(obj)
        except Exception as vars_error:
            errors.append(vars_error)
            raise ValueError(errors) from vars_error
    return jsonable_encoder(
        data,
        by_alias=by_alias,
        exclude_unset=exclude_unset,
        exclude_defaults=exclude_defaults,
        exclude_none=exclude_none,
        custom_encoder=custom_encoder,
        sqlalchemy_safe=sqlalchemy_safe,
    )
|