Skip to content

Commit

Permalink
Enable hook factories to take converters
Browse files Browse the repository at this point in the history
  • Loading branch information
Tinche committed Jan 27, 2024
1 parent 436d651 commit e613d74
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 51 deletions.
92 changes: 80 additions & 12 deletions src/cattrs/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import Field
from enum import Enum
from functools import partial
from inspect import Signature
from pathlib import Path
from typing import Any, Callable, Iterable, Optional, Tuple, TypeVar, overload

Expand Down Expand Up @@ -55,6 +56,7 @@
MultiStrategyDispatch,
StructuredValue,
StructureHook,
TargetType,
UnstructuredValue,
UnstructureHook,
)
Expand Down Expand Up @@ -85,11 +87,25 @@

T = TypeVar("T")
V = TypeVar("V")

UnstructureHookFactory = TypeVar(
"UnstructureHookFactory", bound=HookFactory[UnstructureHook]
)

# The Extended factory also takes a converter.
ExtendedUnstructureHookFactory = TypeVar(
"ExtendedUnstructureHookFactory",
bound=Callable[[TargetType, "BaseConverter"], UnstructureHook],
)

StructureHookFactory = TypeVar("StructureHookFactory", bound=HookFactory[StructureHook])

# The Extended factory also takes a converter.
ExtendedStructureHookFactory = TypeVar(
"ExtendedStructureHookFactory",
bound=Callable[[TargetType, "BaseConverter"], StructureHook],
)


class UnstructureStrategy(Enum):
"""`attrs` classes unstructuring strategies."""
Expand Down Expand Up @@ -151,7 +167,9 @@ def __init__(
self._unstructure_attrs = self.unstructure_attrs_astuple
self._structure_attrs = self.structure_attrs_fromtuple

self._unstructure_func = MultiStrategyDispatch(unstructure_fallback_factory)
self._unstructure_func = MultiStrategyDispatch(
unstructure_fallback_factory, self
)
self._unstructure_func.register_cls_list(
[(bytes, identity), (str, identity), (Path, str)]
)
Expand All @@ -163,12 +181,12 @@ def __init__(
),
(
lambda t: get_final_base(t) is not None,
lambda t: self._unstructure_func.dispatch(get_final_base(t)),
lambda t: self.get_unstructure_hook(get_final_base(t)),
True,
),
(
is_type_alias,
lambda t: self._unstructure_func.dispatch(get_type_alias_base(t)),
lambda t: self.get_unstructure_hook(get_type_alias_base(t)),
True,
),
(is_mapping, self._unstructure_mapping),
Expand All @@ -185,7 +203,7 @@ def __init__(
# Per-instance register of to-attrs converters.
# Singledispatch dispatches based on the first argument, so we
# store the function and switch the arguments in self.loads.
self._structure_func = MultiStrategyDispatch(structure_fallback_factory)
self._structure_func = MultiStrategyDispatch(structure_fallback_factory, self)
self._structure_func.register_func_list(
[
(
Expand Down Expand Up @@ -308,6 +326,12 @@ def register_unstructure_hook_factory(
) -> Callable[[UnstructureHookFactory], UnstructureHookFactory]:
...

@overload
def register_unstructure_hook_factory(
self, predicate: Callable[[Any], bool]
) -> Callable[[ExtendedUnstructureHookFactory], ExtendedUnstructureHookFactory]:
...

@overload
def register_unstructure_hook_factory(
self, predicate: Callable[[Any], bool], factory: UnstructureHookFactory
Expand All @@ -325,7 +349,10 @@ def register_unstructure_hook_factory(
"""
Register a hook factory for a given predicate.
May also be used as a decorator.
May also be used as a decorator. When used as a decorator, the hook
factory may expose an additional required parameter. In this case,
the current converter will be provided to the hook factory as that
parameter.
:param predicate: A function that, given a type, returns whether the factory
can produce a hook for that type.
Expand All @@ -336,7 +363,23 @@ def register_unstructure_hook_factory(
This method may now be used as a decorator.
"""
if factory is None:
return partial(self.register_unstructure_hook_factory, predicate)

def decorator(factory):
# Is this an extended factory (takes a converter too)?
sig = signature(factory)
if (
len(sig.parameters) >= 2
and (list(sig.parameters.values())[1]).default is Signature.empty
):
self._unstructure_func.register_func_list(
[(predicate, factory, "extended")]
)
else:
self._unstructure_func.register_func_list(
[(predicate, factory, True)]
)

return decorator
self._unstructure_func.register_func_list([(predicate, factory, True)])
return factory

Expand Down Expand Up @@ -420,6 +463,12 @@ def register_structure_hook_factory(
) -> Callable[[StructureHookFactory, StructureHookFactory]]:
...

@overload
def register_structure_hook_factory(
self, predicate: Callable[[Any, bool]]
) -> Callable[[ExtendedStructureHookFactory, ExtendedStructureHookFactory]]:
...

@overload
def register_structure_hook_factory(
self, predicate: Callable[[Any], bool], factory: StructureHookFactory
Expand All @@ -434,7 +483,10 @@ def register_structure_hook_factory(
"""
Register a hook factory for a given predicate.
May also be used as a decorator.
May also be used as a decorator. When used as a decorator, the hook
factory may expose an additional required parameter. In this case,
the current converter will be provided to the hook factory as that
parameter.
:param predicate: A function that, given a type, returns whether the factory
can produce a hook for that type.
Expand All @@ -445,7 +497,23 @@ def register_structure_hook_factory(
This method may now be used as a decorator.
"""
if factory is None:
return partial(self.register_structure_hook_factory, predicate)
# Decorator use.
def decorator(factory):
# Is this an extended factory (takes a converter too)?
sig = signature(factory)
if (
len(sig.parameters) >= 2
and (list(sig.parameters.values())[1]).default is Signature.empty
):
self._structure_func.register_func_list(
[(predicate, factory, "extended")]
)
else:
self._structure_func.register_func_list(
[(predicate, factory, True)]
)

return decorator
self._structure_func.register_func_list([(predicate, factory, True)])
return factory

Expand Down Expand Up @@ -684,7 +752,7 @@ def _structure_list(self, obj: Iterable[T], cl: Any) -> list[T]:
def _structure_deque(self, obj: Iterable[T], cl: Any) -> deque[T]:
"""Convert an iterable to a potentially generic deque."""
if is_bare(cl) or cl.__args__[0] in ANIES:
res = deque(e for e in obj)
res = deque(obj)
else:
elem_type = cl.__args__[0]
handler = self._structure_func.dispatch(elem_type)
Expand Down Expand Up @@ -1048,7 +1116,7 @@ def __init__(
)
self.register_unstructure_hook_factory(
lambda t: get_newtype_base(t) is not None,
lambda t: self._unstructure_func.dispatch(get_newtype_base(t)),
lambda t: self.get_unstructure_hook(get_newtype_base(t)),
)

self.register_structure_hook_factory(is_annotated, self.gen_structure_annotated)
Expand All @@ -1070,7 +1138,7 @@ def get_structure_newtype(self, type: type[T]) -> Callable[[Any, Any], T]:

def gen_unstructure_annotated(self, type):
origin = type.__origin__
return self._unstructure_func.dispatch(origin)
return self.get_unstructure_hook(origin)

def gen_structure_annotated(self, type) -> Callable:
"""A hook factory for annotated types."""
Expand Down Expand Up @@ -1111,7 +1179,7 @@ def gen_unstructure_optional(self, cl: type[T]) -> Callable[[T], Any]:
if isinstance(other, TypeVar):
handler = self.unstructure
else:
handler = self._unstructure_func.dispatch(other)
handler = self.get_unstructure_hook(other)

def unstructure_optional(val, _handler=handler):
return None if val is None else _handler(val)
Expand Down
79 changes: 49 additions & 30 deletions src/cattrs/dispatch.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from functools import lru_cache, partial, singledispatch
from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
from __future__ import annotations

from attrs import Factory, define, field
from functools import lru_cache, singledispatch
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar

from attrs import Factory, define

from cattrs._compat import TypeAlias

if TYPE_CHECKING:
from .converters import BaseConverter

T = TypeVar("T")

TargetType: TypeAlias = Any
Expand Down Expand Up @@ -33,23 +38,25 @@ class FunctionDispatch:
objects that help determine dispatch should be instantiated objects.
"""

_handler_pairs: List[
Tuple[Callable[[Any], bool], Callable[[Any, Any], Any], bool]
_converter: BaseConverter
_handler_pairs: list[
tuple[Callable[[Any], bool], Callable[[Any, Any], Any], bool, bool]
] = Factory(list)

def register(
self,
can_handle: Callable[[Any], bool],
predicate: Callable[[Any], bool],
func: Callable[..., Any],
is_generator=False,
takes_converter=False,
) -> None:
self._handler_pairs.insert(0, (can_handle, func, is_generator))
self._handler_pairs.insert(0, (predicate, func, is_generator, takes_converter))

def dispatch(self, typ: Any) -> Optional[Callable[..., Any]]:
def dispatch(self, typ: Any) -> Callable[..., Any] | None:
"""
Return the appropriate handler for the object passed.
"""
for can_handle, handler, is_generator in self._handler_pairs:
for can_handle, handler, is_generator, takes_converter in self._handler_pairs:
# can handle could raise an exception here
# such as issubclass being called on an instance.
# it's easier to just ignore that case.
Expand All @@ -59,6 +66,8 @@ def dispatch(self, typ: Any) -> Optional[Callable[..., Any]]:
continue
if ch:
if is_generator:
if takes_converter:
return handler(typ, self._converter)
return handler(typ)

return handler
Expand All @@ -67,11 +76,11 @@ def dispatch(self, typ: Any) -> Optional[Callable[..., Any]]:
def get_num_fns(self) -> int:
return len(self._handler_pairs)

def copy_to(self, other: "FunctionDispatch", skip: int = 0) -> None:
def copy_to(self, other: FunctionDispatch, skip: int = 0) -> None:
other._handler_pairs = self._handler_pairs[:-skip] + other._handler_pairs


@define
@define(init=False)
class MultiStrategyDispatch(Generic[Hook]):
"""
MultiStrategyDispatch uses a combination of exact-match dispatch,
Expand All @@ -85,18 +94,20 @@ class MultiStrategyDispatch(Generic[Hook]):
"""

_fallback_factory: HookFactory[Hook]
_direct_dispatch: Dict[TargetType, Hook] = field(init=False, factory=dict)
_function_dispatch: FunctionDispatch = field(init=False, factory=FunctionDispatch)
_single_dispatch: Any = field(
init=False, factory=partial(singledispatch, _DispatchNotFound)
)
dispatch: Callable[[TargetType], Hook] = field(
init=False,
default=Factory(
lambda self: lru_cache(maxsize=None)(self.dispatch_without_caching),
takes_self=True,
),
)
_converter: BaseConverter
_direct_dispatch: dict[TargetType, Hook]
_function_dispatch: FunctionDispatch
_single_dispatch: Any
dispatch: Callable[[TargetType, BaseConverter], Hook]

def __init__(
self, fallback_factory: HookFactory[Hook], converter: BaseConverter
) -> None:
self._fallback_factory = fallback_factory
self._direct_dispatch = {}
self._function_dispatch = FunctionDispatch(converter)
self._single_dispatch = singledispatch(_DispatchNotFound)
self.dispatch = lru_cache(maxsize=None)(self.dispatch_without_caching)

def dispatch_without_caching(self, typ: TargetType) -> Hook:
"""Dispatch on the type but without caching the result."""
Expand Down Expand Up @@ -126,15 +137,18 @@ def register_cls_list(self, cls_and_handler, direct: bool = False) -> None:

def register_func_list(
self,
pred_and_handler: List[
Union[
Tuple[Callable[[Any], bool], Any],
Tuple[Callable[[Any], bool], Any, bool],
pred_and_handler: list[
tuple[Callable[[Any], bool], Any]
| tuple[Callable[[Any], bool], Any, bool]
| tuple[
Callable[[Any], bool],
Callable[[Any, BaseConverter], Any],
Literal["extended"],
]
],
):
"""
Register a predicate function to determine if the handle
Register a predicate function to determine if the handler
should be used for the type.
"""
for tup in pred_and_handler:
Expand All @@ -143,7 +157,12 @@ def register_func_list(
self._function_dispatch.register(func, handler)
else:
func, handler, is_gen = tup
self._function_dispatch.register(func, handler, is_generator=is_gen)
if is_gen == "extended":
self._function_dispatch.register(
func, handler, is_generator=is_gen, takes_converter=True
)
else:
self._function_dispatch.register(func, handler, is_generator=is_gen)
self.clear_direct()
self.dispatch.cache_clear()

Expand All @@ -159,7 +178,7 @@ def clear_cache(self) -> None:
def get_num_fns(self) -> int:
return self._function_dispatch.get_num_fns()

def copy_to(self, other: "MultiStrategyDispatch", skip: int = 0) -> None:
def copy_to(self, other: MultiStrategyDispatch, skip: int = 0) -> None:
self._function_dispatch.copy_to(other._function_dispatch, skip=skip)
for cls, fn in self._single_dispatch.registry.items():
other._single_dispatch.register(cls, fn)
Expand Down
4 changes: 2 additions & 2 deletions src/cattrs/gen/typeddicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def make_dict_unstructure_fn(
if nrb is not NOTHING:
t = nrb
try:
handler = converter._unstructure_func.dispatch(t)
handler = converter.get_unstructure_hook(t)
except RecursionError:
# There's a circular reference somewhere down the line
handler = converter.unstructure
Expand Down Expand Up @@ -185,7 +185,7 @@ def make_dict_unstructure_fn(
if nrb is not NOTHING:
t = nrb
try:
handler = converter._unstructure_func.dispatch(t)
handler = converter.get_unstructure_hook(t)
except RecursionError:
# There's a circular reference somewhere down the line
handler = converter.unstructure
Expand Down
2 changes: 1 addition & 1 deletion src/cattrs/preconf/orjson.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def key_handler(v):
# (For example base85 encoding for bytes.)
# In that case, we want to use the override.

kh = converter._unstructure_func.dispatch(args[0])
kh = converter.get_unstructure_hook(args[0])
if kh != identity:
key_handler = kh

Expand Down
Loading

0 comments on commit e613d74

Please sign in to comment.