Skip to content

Commit

Permalink
some fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
zhPavel committed Aug 10, 2024
1 parent c9e2a1e commit 9de9901
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,11 @@ def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
lambda: "Cannot create dumper for tuple. Dumpers for some elements cannot be created",
)
debug_trail = mediator.mandatory_provide(DebugTrailRequest(loc_stack=request.loc_stack))
return mediator.cached_call(self._make_dumper,
dumpers=tuple(dumpers),
debug_trail=debug_trail)
return mediator.cached_call(
self._make_dumper,
dumpers=tuple(dumpers),
debug_trail=debug_trail,
)

def _make_dumper(self, dumpers: Collection[Dumper], debug_trail: DebugTrail):
if debug_trail == DebugTrail.DISABLE:
Expand Down
6 changes: 2 additions & 4 deletions src/adaptix/_internal/morphing/dict_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,7 @@ def defaultdict_loader(data):
def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
key, value = self._extract_key_value(request)
dict_type_hint = Dict[key.source, value.source] # type: ignore[misc, name-defined]

return mediator.cached_call(
self._DICT_PROVIDER.provide_dumper,
mediator=mediator,
return self._DICT_PROVIDER.provide_dumper(
mediator,
request=replace(request, loc_stack=request.loc_stack.replace_last_type(dict_type_hint)),
)
72 changes: 21 additions & 51 deletions src/adaptix/_internal/morphing/generic_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from enum import Enum
from os import PathLike
from pathlib import Path
from typing import Any, Collection, Dict, Iterable, Literal, Optional, Sequence, Set, Type, TypeVar, Union
from typing import Any, Collection, Iterable, Literal, Mapping, Optional, Sequence, Set, Type, TypeVar, Union

from ..common import Dumper, Loader, TypeHint
from ..compat import CompatExceptionGroup
Expand All @@ -16,6 +16,7 @@
from ..provider.location import GenericParamLoc, TypeHintLoc
from ..special_cases_optimization import as_is_stub
from ..type_tools import BaseNormType, NormTypeAlias, is_new_type, is_subclass_soft, strip_tags
from ..utils import MappingHashWrapper
from .load_error import BadVariantLoadError, LoadError, TypeLoadError, UnionLoadError
from .provider_template import DumperProvider, LoaderProvider
from .request_cls import DebugTrailRequest, DumperRequest, LoaderRequest, StrictCoercionRequest
Expand Down Expand Up @@ -104,7 +105,7 @@ def _fetch_enum_loaders(

def _fetch_enum_dumpers(
self, mediator: Mediator, request: DumperRequest, enum_classes: Iterable[Type[Enum]],
) -> Dict[Type[Enum], Dumper[Enum]]:
) -> Mapping[Type[Enum], Dumper[Enum]]:
requests = [
request.append_loc(TypeHintLoc(type=enum_cls))
for enum_cls in enum_classes
Expand Down Expand Up @@ -152,8 +153,7 @@ def wrapped_loader_with_enums(data):
def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
norm = try_normalize_type(request.last_loc.type)
strict_coercion = mediator.mandatory_provide(StrictCoercionRequest(loc_stack=request.loc_stack))

enum_cases = [arg for arg in norm.args if isinstance(arg, Enum)]
enum_cases = tuple(arg for arg in norm.args if isinstance(arg, Enum))
enum_loaders = tuple(self._fetch_enum_loaders(mediator, request, self._get_enum_types(enum_cases)))
allowed_values_repr = self._get_allowed_values_repr(norm.args, mediator, request.loc_stack)
return mediator.cached_call(
Expand All @@ -166,6 +166,7 @@ def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:

def _make_loader(
self,
*,
cases: Sequence[Any],
strict_coercion: bool,
enum_loaders: Sequence[Loader],
Expand Down Expand Up @@ -196,29 +197,23 @@ def literal_loader(data):
return data
raise BadVariantLoadError(allowed_values_repr, data)

return mediator.cached_call(
self._get_literal_loader_with_enum,
basic_loader=literal_loader,
enum_loaders=enum_loaders,
allowed_values=allowed_values,
)
return self._get_literal_loader_with_enum(literal_loader, enum_loaders, allowed_values)

def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
norm = try_normalize_type(request.last_loc.type)
return mediator.cached_call(
self._make_dumper,
norm=norm,
mediator=mediator,
request=request,
)

def _make_dumper(self, norm: BaseNormType, mediator: Mediator, request: DumperRequest):
enum_cases = [arg for arg in norm.args if isinstance(arg, Enum)]

if not enum_cases:
return as_is_stub

enum_dumpers = self._fetch_enum_dumpers(mediator, request, self._get_enum_types(enum_cases))
return mediator.cached_call(
self._make_dumper,
enum_dumpers_wrapper=MappingHashWrapper(enum_dumpers),
)

def _make_dumper(self, enum_dumpers_wrapper: MappingHashWrapper[Mapping[Type[Enum], Dumper[Enum]]]):
enum_dumpers = enum_dumpers_wrapper.mapping

if len(enum_dumpers) == 1:
enum_dumper = next(iter(enum_dumpers.values()))
Expand All @@ -237,21 +232,13 @@ def literal_dumper_with_enums(data):

return literal_dumper_with_enums


@for_predicate(Union)
class UnionProvider(LoaderProvider, DumperProvider):
def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
norm = try_normalize_type(request.last_loc.type)
debug_trail = mediator.mandatory_provide(DebugTrailRequest(loc_stack=request.loc_stack))

return mediator.cached_call(
self._make_loader,
norm=norm,
debug_trail=debug_trail,
mediator=mediator,
request=request,
)

def _make_loader(self, norm: BaseNormType, debug_trail: DebugTrail, mediator: Mediator, request: LoaderRequest):
if self._is_single_optional(norm):
not_none = next(case for case in norm.args if case.origin is not None)
not_none_loader = mediator.mandatory_provide(
Expand All @@ -264,14 +251,9 @@ def _make_loader(self, norm: BaseNormType, debug_trail: DebugTrail, mediator: Me
lambda x: "Cannot create loader for union. Loaders for some union cases cannot be created",
)
if debug_trail in (DebugTrail.ALL, DebugTrail.FIRST):
return self._single_optional_dt_loader(
tp=norm.source,
loader=not_none_loader,
)
return mediator.cached_call(self._single_optional_dt_loader, norm.source, not_none_loader)
if debug_trail == DebugTrail.DISABLE:
return self._single_optional_dt_disable_loader(
loader=not_none_loader,
)
return mediator.cached_call(self._single_optional_dt_disable_loader, not_none_loader)
raise ValueError

loaders = mediator.mandatory_provide_by_iterable(
Expand All @@ -287,11 +269,11 @@ def _make_loader(self, norm: BaseNormType, debug_trail: DebugTrail, mediator: Me
lambda: "Cannot create loader for union. Loaders for some union cases cannot be created",
)
if debug_trail == DebugTrail.DISABLE:
return self._get_loader_dt_disable(loader_iter=tuple(loaders))
return mediator.cached_call(self._get_loader_dt_disable, tuple(loaders))
if debug_trail == DebugTrail.FIRST:
return self._get_loader_dt_first(tp=norm.source, loader_iter=tuple(loaders))
return mediator.cached_call(self._get_loader_dt_first, norm.source, tuple(loaders))
if debug_trail == DebugTrail.ALL:
return self._get_loader_dt_all(tp=norm.source, loader_iter=tuple(loaders))
return mediator.cached_call(self._get_loader_dt_all, norm.source, tuple(loaders))
raise ValueError

def _single_optional_dt_disable_loader(self, loader: Loader) -> Loader:
Expand Down Expand Up @@ -369,14 +351,6 @@ def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
request_type = request.last_loc.type
norm = try_normalize_type(request_type)

return mediator.cached_call(
self._make_dumper,
norm=norm,
mediator=mediator,
request=request,
)

def _make_dumper(self, norm: BaseNormType, mediator: Mediator, request: DumperRequest):
if self._is_single_optional(norm):
not_none = next(case for case in norm.args if case.origin is not None)
not_none_dumper = mediator.mandatory_provide(
Expand All @@ -390,9 +364,7 @@ def _make_dumper(self, norm: BaseNormType, mediator: Mediator, request: DumperRe
)
if not_none_dumper == as_is_stub:
return as_is_stub
return self._get_single_optional_dumper(
dumper=not_none_dumper,
)
return mediator.cached_call(self._get_single_optional_dumper, not_none_dumper)

forbidden_origins = [
case.source
Expand Down Expand Up @@ -434,9 +406,7 @@ def _make_dumper(self, norm: BaseNormType, dumpers: Iterable[Dumper]) -> Dumper:
if literal_dumper:
return literal_dumper

return self._produce_dumper(
dumper_type_dispatcher=dumper_type_dispatcher,
)
return self._produce_dumper(dumper_type_dispatcher)

def _produce_dumper(self, dumper_type_dispatcher: ClassDispatcher[Any, Dumper]) -> Dumper:
def union_dumper(data):
Expand Down
5 changes: 2 additions & 3 deletions src/adaptix/_internal/morphing/name_layout/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from ...provider.located_request import LocatedRequest
from ...provider.overlay_schema import Overlay, Schema, provide_schema
from ...retort.operating_retort import OperatingRetort
from ...retort.searching_retort import ProviderNotFoundError
from ...special_cases_optimization import with_default_clause
from ...utils import Omittable, get_prefix_groups
from ..model.crown_definitions import (
Expand Down Expand Up @@ -105,7 +104,7 @@ def apply_lsc(

class NameMappingRetort(OperatingRetort):
def provide_name_mapping(self, request: NameMappingRequest) -> Optional[KeyPath]:
return self._facade_provide(request, error_message="")
return self._provide_from_recipe(request)


class BuiltinStructureMaker(StructureMaker):
Expand Down Expand Up @@ -146,7 +145,7 @@ def _map_fields(
loc_stack=request.loc_stack.append_with(field_to_loc(field)),
),
)
except ProviderNotFoundError:
except CannotProvide:
path = (generated_key, )

if path is None:
Expand Down
4 changes: 2 additions & 2 deletions src/adaptix/_internal/provider/shape_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ def _get_shape(self, tp) -> Shape:

@method_handler
def _provide_input_shape(self, mediator: Mediator, request: InputShapeRequest) -> InputShape:
shape = self._get_shape(request.last_loc.type)
shape = mediator.cached_call(self._get_shape, request.last_loc.type)
if shape.input is None:
raise CannotProvide
return shape.input

@method_handler
def _provide_output_shape(self, mediator: Mediator, request: OutputShapeRequest) -> OutputShape:
shape = self._get_shape(request.last_loc.type)
shape = mediator.cached_call(self._get_shape, request.last_loc.type)
if shape.output is None:
raise CannotProvide
return shape.output
Expand Down
2 changes: 2 additions & 0 deletions src/adaptix/_internal/retort/builtin_mediator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def __init__(
self._no_request_bus_error_maker = no_request_bus_error_maker
self._call_cache = call_cache

__hash__ = None # type: ignore[assignment]

def provide(self, request: Request[T]) -> T:
try:
request_bus = self._request_buses[type(request)]
Expand Down
3 changes: 1 addition & 2 deletions src/adaptix/_internal/retort/operating_retort.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
from typing import Any, Callable, Dict, Generic, Iterable, Optional, Sequence, Type, TypeVar

from ... import TypeHint
from ..conversion.request_cls import CoercerRequest, LinkingRequest
from ..morphing.json_schema.definitions import JSONSchema
from ..morphing.json_schema.request_cls import InlineJSONSchemaRequest, JSONSchemaRefRequest, JSONSchemaRequest
from ..morphing.request_cls import DumperRequest, LoaderRequest
from ..provider.essential import Mediator, Provider, Request
from ..provider.loc_stack_tools import format_loc_stack
from ..provider.located_request import LocatedRequest, LocatedRequestMethodsProvider
from ..provider.location import AnyLoc
from ..provider.methods_provider import method_handler
from .request_bus import ErrorRepresentor, RecursionResolver, RequestRouter
from .routers import CheckerAndHandler, SimpleRouter, create_router_for_located_request
from .searching_retort import SearchingRetort
from ... import TypeHint


class FuncWrapper:
Expand Down
2 changes: 1 addition & 1 deletion src/adaptix/_internal/retort/searching_retort.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class SearchingRetort(BaseRetort, Provider, ABC):
"""A retort that can operate as Retort but have no predefined providers and no high-level user interface"""

def _provide_from_recipe(self, request: Request[T]) -> T:
return self._create_mediator(request).provide_from_next()
return self._create_mediator(request).provide(request)

def get_request_handlers(self) -> Sequence[Tuple[Type[Request], RequestChecker, RequestHandler]]:
def retort_request_handler(mediator, request):
Expand Down

0 comments on commit 9de9901

Please sign in to comment.