diff --git a/docs/changelog/fragments/318.feature.rst b/docs/changelog/fragments/318.feature.rst new file mode 100644 index 00000000..33363870 --- /dev/null +++ b/docs/changelog/fragments/318.feature.rst @@ -0,0 +1 @@ +Add support for bytes inside literal ``Literal[b"abc"]`` diff --git a/docs/loading-and-dumping/specific-types-behavior.rst b/docs/loading-and-dumping/specific-types-behavior.rst index 5bce1cff..8c544a3b 100644 --- a/docs/loading-and-dumping/specific-types-behavior.rst +++ b/docs/loading-and-dumping/specific-types-behavior.rst @@ -170,12 +170,13 @@ Literal Loader accepts only values listed in ``Literal``. If ``strict_coercion`` is enabled, the loader will distinguish equal ``bool`` and ``int`` instances, otherwise, they will be considered as same values. -``Enum`` instances will be loaded via its loaders. Enum loaders have a higher priority over others, that is, they will be applied first. +``Enum`` instances will be loaded via its loaders. ``bytes`` instances (e.g ``b"abc"``) will be loaded via its loaders as well. +Enum loaders have a higher priority over others, that is, they will be applied first. If the input value could be interpreted as several ``Literal`` members, the result will be undefined. Dumper will return value without any processing excluding ``Enum`` instances, -they will be processed via the corresponding dumper. +they will be processed via the corresponding dumper. ``bytes`` instances also will be processed via the corresponding dumper. Be careful when you use a ``0``, ``1``, ``False`` and ``True`` as ``Literal`` members. Due to type hint caching ``Literal[0, 1]`` sometimes returns ``Literal[False, True]``. diff --git a/src/adaptix/_internal/morphing/generic_provider.py b/src/adaptix/_internal/morphing/generic_provider.py index ae974b0f..26376227 100644 --- a/src/adaptix/_internal/morphing/generic_provider.py +++ b/src/adaptix/_internal/morphing/generic_provider.py @@ -92,32 +92,57 @@ def _get_enum_types(self, cases: Collection) -> Collection: return enum_types def _fetch_enum_loaders( - self, mediator: Mediator, request: LoaderRequest, enum_classes: Iterable[type[Enum]], + self, + mediator: Mediator, + request: LoaderRequest, + enum_classes: Iterable[type[Enum]], ) -> Iterable[Loader[Enum]]: - requests = [ - request.append_loc(TypeHintLoc(type=enum_cls)) - for enum_cls in enum_classes - ] + requests = [request.append_loc(TypeHintLoc(type=enum_cls)) for enum_cls in enum_classes] return mediator.mandatory_provide_by_iterable( requests, lambda: "Cannot create loaders for enum. Loader for literal cannot be created", ) + def _fetch_bytes_loader( + self, + mediator: Mediator, + request: LoaderRequest, + ) -> Loader[bytes]: + request = request.append_loc(TypeHintLoc(type=bytes)) + return mediator.mandatory_provide( + request, + lambda _: "Cannot create loader for literal. Loader for bytes cannot be created", + ) + def _fetch_enum_dumpers( - self, mediator: Mediator, request: DumperRequest, enum_classes: Iterable[type[Enum]], + self, + mediator: Mediator, + request: DumperRequest, + enum_classes: Iterable[type[Enum]], ) -> Mapping[type[Enum], Dumper[Enum]]: - requests = [ - request.append_loc(TypeHintLoc(type=enum_cls)) - for enum_cls in enum_classes - ] + requests = [request.append_loc(TypeHintLoc(type=enum_cls)) for enum_cls in enum_classes] dumpers = mediator.mandatory_provide_by_iterable( requests, lambda: "Cannot create loaders for enum. Loader for literal cannot be created", ) return dict(zip(enum_classes, dumpers)) + def _fetch_bytes_dumper( + self, + mediator: Mediator, + request: DumperRequest, + ) -> Dumper[bytes]: + request = request.append_loc(TypeHintLoc(type=bytes)) + return mediator.mandatory_provide( + request, + lambda _: "Cannot create dumper for literal. Dumper for bytes cannot be created", + ) + def _get_literal_loader_with_enum( # noqa: C901 - self, basic_loader: Loader, enum_loaders: Sequence[Loader[Enum]], allowed_values: Collection, + self, + basic_loader: Loader, + enum_loaders: Sequence[Loader[Enum]], + allowed_values: Collection, ) -> Loader: if not enum_loaders: return basic_loader @@ -150,17 +175,55 @@ def wrapped_loader_with_enums(data): return wrapped_loader_with_enums + def _get_literal_loader_with_bytes( + self, + basic_loader: Loader, + allowed_values: Collection, + bytes_loader: Loader, + ) -> Loader: + def wrapped_loader_with_bytes(data): + try: + bytes_value = bytes_loader(data) + except LoadError: + pass + else: + if bytes_value in allowed_values: + return bytes_value + return basic_loader(data) + + return wrapped_loader_with_bytes + + def _get_literal_loader_many(self, *loaders: Loader, basic_loader: Loader) -> Loader: + if len(loaders) == 1: + return loaders[0] + + def wrapped_loader_many(data): + for c, loader in enumerate(loaders): + try: + return loader(data) + except LoadError: + last_iteration = len(loaders) - 1 + if c != last_iteration: + continue + return basic_loader(data) + + return wrapped_loader_many + def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader: norm = try_normalize_type(request.last_loc.type) strict_coercion = mediator.mandatory_provide(StrictCoercionRequest(loc_stack=request.loc_stack)) enum_cases = tuple(arg for arg in norm.args if isinstance(arg, Enum)) + bytes_cases = tuple(arg for arg in norm.args if isinstance(arg, bytes)) enum_loaders = tuple(self._fetch_enum_loaders(mediator, request, self._get_enum_types(enum_cases))) + bytes_loader = self._fetch_bytes_loader(mediator, request) allowed_values_repr = self._get_allowed_values_repr(norm.args, mediator, request.loc_stack) return mediator.cached_call( self._make_loader, cases=norm.args, + bytes_cases=bytes_cases, strict_coercion=strict_coercion, enum_loaders=enum_loaders, + bytes_loader=bytes_loader, allowed_values_repr=allowed_values_repr, ) @@ -171,11 +234,10 @@ def _make_loader( strict_coercion: bool, enum_loaders: Sequence[Loader], allowed_values_repr: Collection[str], + bytes_cases: Sequence[bytes], + bytes_loader: Loader[bytes], ) -> Loader: - if strict_coercion and any( - isinstance(arg, bool) or _is_exact_zero_or_one(arg) - for arg in cases - ): + if strict_coercion and any(isinstance(arg, bool) or _is_exact_zero_or_one(arg) for arg in cases): allowed_values_with_types = self._get_allowed_values_collection( [(type(el), el) for el in cases], ) @@ -187,7 +249,9 @@ def literal_loader_sc(data): raise BadVariantLoadError(allowed_values_repr, data) return self._get_literal_loader_with_enum( - literal_loader_sc, enum_loaders, allowed_values_with_types, + literal_loader_sc, + enum_loaders, + allowed_values_with_types, ) allowed_values = self._get_allowed_values_collection(cases) @@ -197,24 +261,36 @@ def literal_loader(data): return data raise BadVariantLoadError(allowed_values_repr, data) - return self._get_literal_loader_with_enum(literal_loader, enum_loaders, allowed_values) + if bytes_cases and not enum_loaders: + return self._get_literal_loader_with_bytes(literal_loader, allowed_values, bytes_loader) + + if not bytes_cases: + return self._get_literal_loader_with_enum(literal_loader, enum_loaders, allowed_values) + + return self._get_literal_loader_many( + self._get_literal_loader_with_bytes(literal_loader, allowed_values, bytes_loader), + self._get_literal_loader_with_enum(literal_loader, enum_loaders, allowed_values), + basic_loader=literal_loader, + ) def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper: norm = try_normalize_type(request.last_loc.type) enum_cases = [arg for arg in norm.args if isinstance(arg, Enum)] + bytes_cases = tuple(arg for arg in norm.args if isinstance(arg, bytes)) - if not enum_cases: + if not enum_cases and not bytes_cases: return as_is_stub enum_dumpers = self._fetch_enum_dumpers(mediator, request, self._get_enum_types(enum_cases)) + bytes_dumper = self._fetch_bytes_dumper(mediator, request) + return mediator.cached_call( self._make_dumper, enum_dumpers_wrapper=MappingHashWrapper(enum_dumpers), + bytes_dumper=bytes_dumper, ) - def _make_dumper(self, enum_dumpers_wrapper: MappingHashWrapper[Mapping[type[Enum], Dumper[Enum]]]): - enum_dumpers = enum_dumpers_wrapper.mapping - + def _get_enum_dumper(self, enum_dumpers: Mapping[type[Enum], Dumper[Enum]]) -> Dumper: if len(enum_dumpers) == 1: enum_dumper = next(iter(enum_dumpers.values())) @@ -232,6 +308,39 @@ def literal_dumper_with_enums(data): return literal_dumper_with_enums + def _get_bytes_literal_dumper(self, bytes_dumper: Dumper[bytes]) -> Dumper: + def literal_dumper_with_bytes(data): + if isinstance(data, bytes): + return bytes_dumper(data) + return data + + return literal_dumper_with_bytes + + def _make_dumper( + self, + enum_dumpers_wrapper: MappingHashWrapper[Mapping[type[Enum], Dumper[Enum]]], + bytes_dumper: Optional[Dumper[bytes]], + ): + enum_dumpers = enum_dumpers_wrapper.mapping + + if not bytes_dumper: + return self._get_enum_dumper(enum_dumpers) + + if not enum_dumpers: + return self._get_bytes_literal_dumper(bytes_dumper) + + bytes_literal_dumper = self._get_bytes_literal_dumper(bytes_dumper) + enum_literal_dumper = self._get_enum_dumper(enum_dumpers) + + def literal_dumper_many(data): + if isinstance(data, bytes): + return bytes_literal_dumper(data) + if isinstance(data, Enum): + return enum_literal_dumper(data) + return data + + return literal_dumper_many + @for_predicate(Union) class UnionProvider(LoaderProvider, DumperProvider): @@ -367,9 +476,7 @@ def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper: return mediator.cached_call(self._get_single_optional_dumper, not_none_dumper) forbidden_origins = [ - case.source - for case in norm.args - if not self._is_class_origin(case.origin) and case.origin != Literal + case.source for case in norm.args if not self._is_class_origin(case.origin) and case.origin != Literal ] if forbidden_origins: @@ -435,9 +542,7 @@ def _get_dumper_for_literal( ) -> Optional[Dumper]: try: literal_type, literal_dumper = next( - (union_case, dumper) for union_case, dumper - in zip(norm.args, dumpers) - if union_case.origin is Literal + (union_case, dumper) for union_case, dumper in zip(norm.args, dumpers) if union_case.origin is Literal ) except StopIteration: return None diff --git a/tests/unit/morphing/generic_provider/test_literal_provider.py b/tests/unit/morphing/generic_provider/test_literal_provider.py index 320cb533..c88e80a6 100644 --- a/tests/unit/morphing/generic_provider/test_literal_provider.py +++ b/tests/unit/morphing/generic_provider/test_literal_provider.py @@ -1,11 +1,12 @@ # ruff: noqa: FBT003 from enum import Enum -from typing import Literal +from typing import Any, Iterable, Literal from uuid import uuid4 +import pytest from tests_helpers import raises_exc -from adaptix import Retort +from adaptix import P, Provider, Retort, dumper, loader from adaptix._internal.morphing.load_error import BadVariantLoadError @@ -124,6 +125,66 @@ class Enum2(Enum): ) +@pytest.mark.parametrize( + ["input_data", "recipe"], + [ + ("YWJj", []), + ("abc", [loader(P[bytes], lambda x: x.encode())]), + ], +) +def test_loader_with_bytes( + strict_coercion, + debug_trail, + input_data: Any, + recipe: Iterable[Provider], +): + retort = Retort( + recipe=recipe, + ) + + loader = retort.replace( + strict_coercion=strict_coercion, + debug_trail=debug_trail, + ).get_loader( + Literal[b"abc"], + ) + + assert loader(input_data) == b"abc" + + raises_exc( + BadVariantLoadError({b"abc"}, "YWJ"), + lambda: loader("YWJ"), + ) + + +def test_loader_with_bytes_and_enums(strict_coercion, debug_trail): + class Enum1(Enum): + CASE1 = 1 + CASE2 = 2 + + retort = Retort() + + loader = retort.replace( + strict_coercion=strict_coercion, + debug_trail=debug_trail, + ).get_loader( + Literal[b"abc", Enum1.CASE1], + ) + + assert loader("YWJj") == b"abc" + assert loader(1) == Enum1.CASE1 + + raises_exc( + BadVariantLoadError({b"abc", Enum1.CASE1.value}, "YWJ"), + lambda: loader("YWJ"), + ) + + raises_exc( + BadVariantLoadError({b"abc", Enum1.CASE1.value}, 2), + lambda: loader(2), + ) + + def test_dumper_with_enums(strict_coercion, debug_trail): retort = Retort() @@ -156,3 +217,46 @@ class Enum2(Enum): assert dumper(Enum1.CASE1) == 1 assert dumper(Enum1.CASE2) == 2 assert dumper(10) == 10 + +@pytest.mark.parametrize( + ["expected_data", "recipe"], + [ + ("YWJj", []), + ("abc", [dumper(P[bytes], lambda x: x.decode())]), + ], +) +def test_dumper_with_bytes(strict_coercion, debug_trail, expected_data: Any, recipe: Iterable[Provider]): + retort = Retort( + recipe=recipe, + ) + + dumper = retort.replace( + strict_coercion=strict_coercion, + debug_trail=debug_trail, + ).get_dumper( + Literal[b"abc"], + ) + + assert dumper(b"abc") == expected_data + + +def test_dumper_with_bytes_and_enums(strict_coercion, debug_trail): + class Enum1(Enum): + CASE1 = 1 + CASE2 = 2 + + class Enum2(Enum): + CASE1 = 1 + CASE2 = 2 + + retort = Retort() + + dumper = retort.replace( + strict_coercion=strict_coercion, + debug_trail=debug_trail, + ).get_dumper( + Literal[b"abc", Enum1.CASE1], + ) + + assert dumper(b"abc") == "YWJj" + assert dumper(Enum1.CASE1) == 1