"""Preconfigured converters for msgspec."""
from __future__ import annotations
from base64 import b64decode
from datetime import date, datetime
from enum import Enum
from functools import partial
from typing import Any, Callable, TypeVar, Union, get_type_hints
from attrs import has as attrs_has
from attrs import resolve_types
from msgspec import Struct, convert, to_builtins
from msgspec.json import Encoder, decode
from cattrs._compat import (
fields,
get_args,
get_origin,
has,
is_bare,
is_mapping,
is_sequence,
)
from cattrs.dispatch import UnstructureHook
from cattrs.fns import identity
from ..converters import BaseConverter, Converter
from ..gen import make_hetero_tuple_unstructure_fn
from ..strategies import configure_union_passthrough
from ..tuples import is_namedtuple
from . import wrap
# Type variable tying the class passed to `loads` to the type it returns.
T = TypeVar("T")
# Names that make up this module's public API.
__all__ = ["MsgspecJsonConverter", "configure_converter", "make_converter"]
class MsgspecJsonConverter(Converter):
    """A converter specialized for the _msgspec_ library."""

    #: The msgspec encoder for dumping.
    encoder: Encoder = Encoder()

    def dumps(self, obj: Any, unstructure_as: Any = None, **kwargs: Any) -> bytes:
        """Unstructure and encode `obj` into JSON bytes.

        :param obj: The object to serialize.
        :param unstructure_as: Forwarded to `unstructure` as the type to
            unstructure `obj` as.
        :param kwargs: Forwarded verbatim to the encoder's `encode` method.
        """
        return self.encoder.encode(
            self.unstructure(obj, unstructure_as=unstructure_as), **kwargs
        )

    def get_dumps_hook(
        self, unstructure_as: Any, **kwargs: Any
    ) -> Callable[[Any], bytes]:
        """Produce a `dumps` hook for the given type.

        The unstructure hook for `unstructure_as` is resolved once, up front.
        If it is a passthrough (`identity` or `to_builtins`), the encoder's
        `encode` method is returned directly. Otherwise the returned callable
        applies the resolved hook and then encodes the result, so values are
        always unstructured as `unstructure_as` rather than as their runtime
        class.

        :param unstructure_as: The type whose unstructure hook should be used.
        :param kwargs: Accepted for interface compatibility; not used.
        """
        unstruct_hook = self.get_unstructure_hook(unstructure_as)
        encode = self.encoder.encode
        if unstruct_hook in (identity, to_builtins):
            return encode
        # Compose the type-specific hook with the encoder; returning
        # `self.dumps` here would silently drop `unstructure_as`.
        return lambda value: encode(unstruct_hook(value))

    def loads(self, data: bytes, cl: type[T], **kwargs: Any) -> T:
        """Decode and structure `cl` from the provided JSON bytes.

        :param data: The JSON payload.
        :param cl: The class to structure the decoded payload into.
        :param kwargs: Forwarded verbatim to `msgspec.json.decode`.
        """
        return self.structure(decode(data, **kwargs), cl)

    def get_loads_hook(self, cl: type[T]) -> Callable[[bytes], T]:
        """Produce a `loads` hook for the given type.

        The hook is `loads` with `cl` pre-bound as a keyword argument.
        """
        return partial(self.loads, cl=cl)
@wrap(MsgspecJsonConverter)
def make_converter(*args: Any, **kwargs: Any) -> MsgspecJsonConverter:
    """Create a `MsgspecJsonConverter` and run `configure_converter` on it.

    All positional and keyword arguments are forwarded to the
    `MsgspecJsonConverter` constructor.
    """
    converter = MsgspecJsonConverter(*args, **kwargs)
    configure_converter(converter)
    return converter
def configure_passthroughs(converter: Converter) -> None:
    """Configure optimizing passthroughs.

    A passthrough is when we let msgspec handle something automatically.

    `bytes` is registered to unstructure via `to_builtins`; the hook
    factories for mappings, sequences, attrs-style classes and namedtuples
    are then registered in that order.
    """
    converter.register_unstructure_hook(bytes, to_builtins)

    # (predicate, factory) pairs; the registration order is kept as-is.
    hook_factories = (
        (is_mapping, mapping_unstructure_factory),
        (is_sequence, seq_unstructure_factory),
        (has, attrs_unstructure_factory),
        (is_namedtuple, namedtuple_unstructure_factory),
    )
    for predicate, factory in hook_factories:
        converter.register_unstructure_hook_factory(predicate, factory)
def seq_unstructure_factory(type, converter: Converter) -> UnstructureHook:
    """The msgspec unstructure hook factory for sequences.

    Looks up the unstructure hook for the element type (`Any` for a bare
    sequence, otherwise the first type argument). If that hook is a
    passthrough (`identity` or `to_builtins`), it is returned for the whole
    sequence; otherwise the converter generates an iterable hook.
    """
    element_type = Any if is_bare(type) else get_args(type)[0]
    element_hook = converter.get_unstructure_hook(element_type, cache_result=False)
    if element_hook in (identity, to_builtins):
        return element_hook
    return converter.gen_unstructure_iterable(type)
def mapping_unstructure_factory(type, converter: BaseConverter) -> UnstructureHook:
    """The msgspec unstructure hook factory for mappings.

    Determines the key and value types of `type` and looks up their
    unstructure hooks. If both are passthroughs (`identity` or
    `to_builtins`), `to_builtins` is returned for the whole mapping;
    otherwise the converter generates a mapping hook.

    Bare mappings use `Any` for both keys and values. Mappings with a
    single type argument (probably a `Counter`) use that argument as the
    key type and `Any` as the value type.
    """
    if is_bare(type):
        key_arg = val_arg = Any
    else:
        args = get_args(type)
        if len(args) == 2:
            key_arg, val_arg = args
        else:
            # Probably a Counter: its only type argument is the key type.
            # Use the argument itself, not the enclosing `args` tuple.
            key_arg = args[0] if args else Any
            val_arg = Any

    key_handler = converter.get_unstructure_hook(key_arg, cache_result=False)
    value_handler = converter.get_unstructure_hook(val_arg, cache_result=False)

    passthrough_hooks = (identity, to_builtins)
    if key_handler in passthrough_hooks and value_handler in passthrough_hooks:
        return to_builtins
    return converter.gen_unstructure_mapping(type)
def attrs_unstructure_factory(type: Any, converter: Converter) -> UnstructureHook:
    """Choose whether to use msgspec handling or our own.

    Returns `to_builtins` only when no field name starts with an underscore
    and every field's unstructure hook is itself a passthrough (`identity`
    or `to_builtins`). Otherwise a hook generated by
    `converter.gen_unstructure_attrs_fromdict` is returned.
    """
    passthrough_hooks = (identity, to_builtins)
    cls = get_origin(type) or type
    attribs = fields(cls)
    if attrs_has(type) and any(isinstance(a.type, str) for a in attribs):
        # Some annotations are still strings: resolve them, then re-read the
        # fields so the hook lookups below see the resolved types.
        resolve_types(type)
        attribs = fields(cls)

    for attr in attribs:
        if attr.name.startswith("_"):
            return converter.gen_unstructure_attrs_fromdict(type)
        field_hook = converter.get_unstructure_hook(attr.type, cache_result=False)
        if field_hook not in passthrough_hooks:
            return converter.gen_unstructure_attrs_fromdict(type)
    return to_builtins
def namedtuple_unstructure_factory(
    type: type[tuple], converter: BaseConverter
) -> UnstructureHook:
    """A hook factory for unstructuring namedtuples, modified for msgspec.

    If the unstructure hook of every annotated field is a passthrough
    (`identity` or `to_builtins`), `identity` is returned. Otherwise a
    heterogeneous-tuple hook is generated that unstructures each field by
    its annotated type and produces a plain `tuple`.
    """
    # Resolve the field annotations once; they are needed both for the
    # passthrough check and for generating the fallback hook.
    field_types = tuple(get_type_hints(type).values())

    if all(
        converter.get_unstructure_hook(t) in (identity, to_builtins)
        for t in field_types
    ):
        return identity

    return make_hetero_tuple_unstructure_fn(
        type, converter, unstructure_to=tuple, type_args=field_types
    )