Skip to content

Commit

Permalink
Merge pull request #334 from lubaskinc0de/literal-bytes
Browse files Browse the repository at this point in the history
Literal[b"abc"] loader
  • Loading branch information
zhPavel authored Aug 31, 2024
2 parents 0600de6 + 9429a79 commit 418ef28
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 31 deletions.
1 change: 1 addition & 0 deletions docs/changelog/fragments/318.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for bytes inside literal ``Literal[b"abc"]``
5 changes: 3 additions & 2 deletions docs/loading-and-dumping/specific-types-behavior.rst
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,13 @@ Literal
Loader accepts only values listed in ``Literal``.
If ``strict_coercion`` is enabled, the loader will distinguish equal ``bool`` and ``int`` instances,
otherwise, they will be considered as same values.
``Enum`` instances will be loaded via its loaders. Enum loaders have a higher priority over others, that is, they will be applied first.
``Enum`` instances will be loaded via its loaders. ``bytes`` instances (e.g ``b"abc"``) will be loaded via its loaders as well.
Enum loaders have a higher priority over others, that is, they will be applied first.

If the input value could be interpreted as several ``Literal`` members, the result will be undefined.

Dumper will return value without any processing excluding ``Enum`` instances,
they will be processed via the corresponding dumper.
they will be processed via the corresponding dumper. ``bytes`` instances also will be processed via the corresponding dumper.

Be careful when you use a ``0``, ``1``, ``False`` and ``True`` as ``Literal`` members.
Due to type hint caching ``Literal[0, 1]`` sometimes returns ``Literal[False, True]``.
Expand Down
159 changes: 132 additions & 27 deletions src/adaptix/_internal/morphing/generic_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,32 +92,57 @@ def _get_enum_types(self, cases: Collection) -> Collection:
return enum_types

def _fetch_enum_loaders(
self, mediator: Mediator, request: LoaderRequest, enum_classes: Iterable[type[Enum]],
self,
mediator: Mediator,
request: LoaderRequest,
enum_classes: Iterable[type[Enum]],
) -> Iterable[Loader[Enum]]:
requests = [
request.append_loc(TypeHintLoc(type=enum_cls))
for enum_cls in enum_classes
]
requests = [request.append_loc(TypeHintLoc(type=enum_cls)) for enum_cls in enum_classes]
return mediator.mandatory_provide_by_iterable(
requests,
lambda: "Cannot create loaders for enum. Loader for literal cannot be created",
)

def _fetch_bytes_loader(
self,
mediator: Mediator,
request: LoaderRequest,
) -> Loader[bytes]:
request = request.append_loc(TypeHintLoc(type=bytes))
return mediator.mandatory_provide(
request,
lambda _: "Cannot create loader for literal. Loader for bytes cannot be created",
)

def _fetch_enum_dumpers(
self, mediator: Mediator, request: DumperRequest, enum_classes: Iterable[type[Enum]],
self,
mediator: Mediator,
request: DumperRequest,
enum_classes: Iterable[type[Enum]],
) -> Mapping[type[Enum], Dumper[Enum]]:
requests = [
request.append_loc(TypeHintLoc(type=enum_cls))
for enum_cls in enum_classes
]
requests = [request.append_loc(TypeHintLoc(type=enum_cls)) for enum_cls in enum_classes]
dumpers = mediator.mandatory_provide_by_iterable(
requests,
lambda: "Cannot create loaders for enum. Loader for literal cannot be created",
)
return dict(zip(enum_classes, dumpers))

def _fetch_bytes_dumper(
self,
mediator: Mediator,
request: DumperRequest,
) -> Dumper[bytes]:
request = request.append_loc(TypeHintLoc(type=bytes))
return mediator.mandatory_provide(
request,
lambda _: "Cannot create dumper for literal. Dumper for bytes cannot be created",
)

def _get_literal_loader_with_enum( # noqa: C901
self, basic_loader: Loader, enum_loaders: Sequence[Loader[Enum]], allowed_values: Collection,
self,
basic_loader: Loader,
enum_loaders: Sequence[Loader[Enum]],
allowed_values: Collection,
) -> Loader:
if not enum_loaders:
return basic_loader
Expand Down Expand Up @@ -150,17 +175,55 @@ def wrapped_loader_with_enums(data):

return wrapped_loader_with_enums

def _get_literal_loader_with_bytes(
self,
basic_loader: Loader,
allowed_values: Collection,
bytes_loader: Loader,
) -> Loader:
def wrapped_loader_with_bytes(data):
try:
bytes_value = bytes_loader(data)
except LoadError:
pass
else:
if bytes_value in allowed_values:
return bytes_value
return basic_loader(data)

return wrapped_loader_with_bytes

def _get_literal_loader_many(self, *loaders: Loader, basic_loader: Loader) -> Loader:
if len(loaders) == 1:
return loaders[0]

def wrapped_loader_many(data):
for c, loader in enumerate(loaders):
try:
return loader(data)
except LoadError:
last_iteration = len(loaders) - 1
if c != last_iteration:
continue
return basic_loader(data)

return wrapped_loader_many

def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
norm = try_normalize_type(request.last_loc.type)
strict_coercion = mediator.mandatory_provide(StrictCoercionRequest(loc_stack=request.loc_stack))
enum_cases = tuple(arg for arg in norm.args if isinstance(arg, Enum))
bytes_cases = tuple(arg for arg in norm.args if isinstance(arg, bytes))
enum_loaders = tuple(self._fetch_enum_loaders(mediator, request, self._get_enum_types(enum_cases)))
bytes_loader = self._fetch_bytes_loader(mediator, request)
allowed_values_repr = self._get_allowed_values_repr(norm.args, mediator, request.loc_stack)
return mediator.cached_call(
self._make_loader,
cases=norm.args,
bytes_cases=bytes_cases,
strict_coercion=strict_coercion,
enum_loaders=enum_loaders,
bytes_loader=bytes_loader,
allowed_values_repr=allowed_values_repr,
)

Expand All @@ -171,11 +234,10 @@ def _make_loader(
strict_coercion: bool,
enum_loaders: Sequence[Loader],
allowed_values_repr: Collection[str],
bytes_cases: Sequence[bytes],
bytes_loader: Loader[bytes],
) -> Loader:
if strict_coercion and any(
isinstance(arg, bool) or _is_exact_zero_or_one(arg)
for arg in cases
):
if strict_coercion and any(isinstance(arg, bool) or _is_exact_zero_or_one(arg) for arg in cases):
allowed_values_with_types = self._get_allowed_values_collection(
[(type(el), el) for el in cases],
)
Expand All @@ -187,7 +249,9 @@ def literal_loader_sc(data):
raise BadVariantLoadError(allowed_values_repr, data)

return self._get_literal_loader_with_enum(
literal_loader_sc, enum_loaders, allowed_values_with_types,
literal_loader_sc,
enum_loaders,
allowed_values_with_types,
)

allowed_values = self._get_allowed_values_collection(cases)
Expand All @@ -197,24 +261,36 @@ def literal_loader(data):
return data
raise BadVariantLoadError(allowed_values_repr, data)

return self._get_literal_loader_with_enum(literal_loader, enum_loaders, allowed_values)
if bytes_cases and not enum_loaders:
return self._get_literal_loader_with_bytes(literal_loader, allowed_values, bytes_loader)

if not bytes_cases:
return self._get_literal_loader_with_enum(literal_loader, enum_loaders, allowed_values)

return self._get_literal_loader_many(
self._get_literal_loader_with_bytes(literal_loader, allowed_values, bytes_loader),
self._get_literal_loader_with_enum(literal_loader, enum_loaders, allowed_values),
basic_loader=literal_loader,
)

def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
norm = try_normalize_type(request.last_loc.type)
enum_cases = [arg for arg in norm.args if isinstance(arg, Enum)]
bytes_cases = tuple(arg for arg in norm.args if isinstance(arg, bytes))

if not enum_cases:
if not enum_cases and not bytes_cases:
return as_is_stub

enum_dumpers = self._fetch_enum_dumpers(mediator, request, self._get_enum_types(enum_cases))
bytes_dumper = self._fetch_bytes_dumper(mediator, request)

return mediator.cached_call(
self._make_dumper,
enum_dumpers_wrapper=MappingHashWrapper(enum_dumpers),
bytes_dumper=bytes_dumper,
)

def _make_dumper(self, enum_dumpers_wrapper: MappingHashWrapper[Mapping[type[Enum], Dumper[Enum]]]):
enum_dumpers = enum_dumpers_wrapper.mapping

def _get_enum_dumper(self, enum_dumpers: Mapping[type[Enum], Dumper[Enum]]) -> Dumper:
if len(enum_dumpers) == 1:
enum_dumper = next(iter(enum_dumpers.values()))

Expand All @@ -232,6 +308,39 @@ def literal_dumper_with_enums(data):

return literal_dumper_with_enums

def _get_bytes_literal_dumper(self, bytes_dumper: Dumper[bytes]) -> Dumper:
def literal_dumper_with_bytes(data):
if isinstance(data, bytes):
return bytes_dumper(data)
return data

return literal_dumper_with_bytes

def _make_dumper(
self,
enum_dumpers_wrapper: MappingHashWrapper[Mapping[type[Enum], Dumper[Enum]]],
bytes_dumper: Optional[Dumper[bytes]],
):
enum_dumpers = enum_dumpers_wrapper.mapping

if not bytes_dumper:
return self._get_enum_dumper(enum_dumpers)

if not enum_dumpers:
return self._get_bytes_literal_dumper(bytes_dumper)

bytes_literal_dumper = self._get_bytes_literal_dumper(bytes_dumper)
enum_literal_dumper = self._get_enum_dumper(enum_dumpers)

def literal_dumper_many(data):
if isinstance(data, bytes):
return bytes_literal_dumper(data)
if isinstance(data, Enum):
return enum_literal_dumper(data)
return data

return literal_dumper_many


@for_predicate(Union)
class UnionProvider(LoaderProvider, DumperProvider):
Expand Down Expand Up @@ -367,9 +476,7 @@ def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
return mediator.cached_call(self._get_single_optional_dumper, not_none_dumper)

forbidden_origins = [
case.source
for case in norm.args
if not self._is_class_origin(case.origin) and case.origin != Literal
case.source for case in norm.args if not self._is_class_origin(case.origin) and case.origin != Literal
]

if forbidden_origins:
Expand Down Expand Up @@ -435,9 +542,7 @@ def _get_dumper_for_literal(
) -> Optional[Dumper]:
try:
literal_type, literal_dumper = next(
(union_case, dumper) for union_case, dumper
in zip(norm.args, dumpers)
if union_case.origin is Literal
(union_case, dumper) for union_case, dumper in zip(norm.args, dumpers) if union_case.origin is Literal
)
except StopIteration:
return None
Expand Down
Loading

0 comments on commit 418ef28

Please sign in to comment.