Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mediator.cached_call to other providers #326

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/adaptix/_internal/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)

from .common import VarTuple
from .utils import MappingHashWrapper

K = TypeVar("K", bound=Hashable)
V = TypeVar("V")
Expand Down Expand Up @@ -114,6 +115,9 @@ def __eq__(self, other):
return self._mapping == other._mapping
return NotImplemented

def __hash__(self):
return hash(MappingHashWrapper(self._mapping))


# It's not a KeysView because __iter__ of KeysView must returns an Iterator[K_co]
# but there is no inverse of Type[]
Expand Down
77 changes: 64 additions & 13 deletions src/adaptix/_internal/morphing/concrete_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def __repr__(self):
return f"{type(self)}(cls={self._cls})"

def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
return mediator.cached_call(self._make_loader)

def _make_loader(self):
raw_loader = self._cls.fromisoformat

def isoformat_loader(data):
Expand All @@ -48,10 +51,12 @@ def isoformat_loader(data):
raise TypeLoadError(str, data)
except ValueError:
raise ValueLoadError("Invalid isoformat string", data)

return isoformat_loader

def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
return mediator.cached_call(self._make_dumper)

def _make_dumper(self):
return self._cls.isoformat

def _generate_json_schema(self, mediator: Mediator, request: JSONSchemaRequest) -> JSONSchema:
Expand All @@ -67,6 +72,9 @@ def __repr__(self):
return f"{type(self)}(fmt={self._fmt})"

def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
return mediator.cached_call(self._make_loader)

def _make_loader(self):
fmt = self._fmt

def datetime_format_loader(data):
Expand All @@ -80,6 +88,9 @@ def datetime_format_loader(data):
return datetime_format_loader

def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
return mediator.cached_call(self._make_dumper)

def _make_dumper(self):
fmt = self._fmt

def datetime_format_dumper(data: datetime):
Expand All @@ -97,6 +108,9 @@ def __init__(self, tz: Optional[timezone]):
self._tz = tz

def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
return mediator.cached_call(self._make_loader)

def _make_loader(self):
tz = self._tz

def datetime_timestamp_loader(data):
Expand All @@ -115,9 +129,11 @@ def datetime_timestamp_loader(data):
return datetime_timestamp_loader

def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
return mediator.cached_call(self._make_dumper)

def _make_dumper(self):
def datetime_timestamp_dumper(data: datetime):
return data.timestamp()

return datetime_timestamp_dumper

def _generate_json_schema(self, mediator: Mediator, request: JSONSchemaRequest) -> JSONSchema:
Expand All @@ -138,6 +154,9 @@ def _is_pydatetime(self) -> bool:
return False

def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
return mediator.cached_call(self._make_loader)

def _make_loader(self):
def date_timestamp_loader(data):
try:
# Pure-Python implementation and C-extension implementation
Expand Down Expand Up @@ -171,6 +190,9 @@ def pydate_timestamp_loader(data):
return pydate_timestamp_loader if self._is_pydatetime() else date_timestamp_loader

def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
return mediator.cached_call(self._make_dumper)

def _make_dumper(self):
def date_timestamp_dumper(data: date):
dt = datetime(
year=data.year,
Expand All @@ -179,7 +201,6 @@ def date_timestamp_dumper(data: date):
tzinfo=timezone.utc,
)
return dt.timestamp()

return date_timestamp_dumper

def _generate_json_schema(self, mediator: Mediator, request: JSONSchemaRequest) -> JSONSchema:
Expand All @@ -191,6 +212,9 @@ class SecondsTimedeltaProvider(MorphingProvider):
_OK_TYPES = (int, float, Decimal)

def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
return mediator.cached_call(self._make_loader)

def _make_loader(self):
ok_types = self._OK_TYPES

def timedelta_loader(data):
Expand All @@ -201,6 +225,9 @@ def timedelta_loader(data):
return timedelta_loader

def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
return mediator.cached_call(self._make_dumper)

def _make_dumper(self):
return timedelta.total_seconds

def _generate_json_schema(self, mediator: Mediator, request: JSONSchemaRequest) -> JSONSchema:
Expand All @@ -227,12 +254,13 @@ def _generate_json_schema(self, mediator: Mediator, request: JSONSchemaRequest)

class _Base64DumperMixin(DumperProvider):
def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
return mediator.cached_call(self._make_dumper)

def _make_dumper(self):
def bytes_base64_dumper(data):
return b2a_base64(data, newline=False).decode("ascii")

return bytes_base64_dumper


class _Base64JSONSchemaMixin(JSONSchemaProvider):
def _generate_json_schema(self, mediator: Mediator, request: JSONSchemaRequest) -> JSONSchema:
return JSONSchema(type=JSONSchemaType.STRING, content_encoding="base64")
Expand All @@ -244,6 +272,9 @@ def _generate_json_schema(self, mediator: Mediator, request: JSONSchemaRequest)
@for_predicate(bytes)
class BytesBase64Provider(_Base64DumperMixin, _Base64JSONSchemaMixin, MorphingProvider):
def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
return mediator.cached_call(self._make_loader)

def _make_loader(self):
def bytes_base64_loader(data):
try:
encoded = data.encode("ascii")
Expand All @@ -257,7 +288,6 @@ def bytes_base64_loader(data):
return a2b_base64(encoded)
except binascii.Error as e:
raise ValueLoadError(str(e), data)

return bytes_base64_loader


Expand All @@ -266,29 +296,36 @@ class BytesIOBase64Provider(_Base64JSONSchemaMixin, MorphingProvider):
_BYTES_PROVIDER = BytesBase64Provider()

def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
bytes_base64_loader = self._BYTES_PROVIDER.provide_loader(mediator, request)
return mediator.cached_call(
self._make_loader,
loader=self._BYTES_PROVIDER.provide_loader(mediator, request),
)

def _make_loader(self, loader: Loader):
def bytes_io_base64_loader(data):
return BytesIO(bytes_base64_loader(data))

return BytesIO(loader(data))
return bytes_io_base64_loader

def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
return mediator.cached_call(self._make_dumper)

def _make_dumper(self):
def bytes_io_base64_dumper(data: BytesIO):
return b2a_base64(data.getvalue(), newline=False).decode("ascii")

return bytes_io_base64_dumper


@for_predicate(typing.IO[bytes])
class IOBytesBase64Provider(BytesIOBase64Provider, _Base64JSONSchemaMixin, MorphingProvider):
def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
return mediator.cached_call(self._make_dumper)

def _make_dumper(self):
def io_bytes_base64_dumper(data: typing.IO[bytes]):
if data.seekable():
data.seek(0)

return b2a_base64(data.read(), newline=False).decode("ascii")

return io_bytes_base64_dumper


Expand All @@ -302,9 +339,14 @@ def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
replace(request, loc_stack=request.loc_stack.replace_last_type(bytes)),
)

def bytearray_base64_loader(data):
return bytearray(bytes_loader(data))
return mediator.cached_call(
self._make_loader,
loader=bytes_loader,
)

def _make_loader(self, loader: Loader):
def bytearray_base64_loader(data):
return bytearray(loader(data))
return bytearray_base64_loader


Expand All @@ -318,6 +360,9 @@ def __init__(self, flags: re.RegexFlag = re.RegexFlag(0)):
self.flags = flags

def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
return mediator.cached_call(self._make_loader)

def _make_loader(self):
flags = self.flags
re_compile = re.compile

Expand Down Expand Up @@ -360,6 +405,12 @@ def __init__(

def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
strict_coercion = mediator.mandatory_provide(StrictCoercionRequest(loc_stack=request.loc_stack))
return mediator.cached_call(
self._make_loader,
strict_coercion=strict_coercion,
)

def _make_loader(self, *, strict_coercion: bool):
return self._strict_coercion_loader if strict_coercion else self._lax_coercion_loader

def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
Expand Down
11 changes: 9 additions & 2 deletions src/adaptix/_internal/morphing/constant_length_tuple_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,12 @@ def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
)
strict_coercion = mediator.mandatory_provide(StrictCoercionRequest(loc_stack=request.loc_stack))
debug_trail = mediator.mandatory_provide(DebugTrailRequest(loc_stack=request.loc_stack))
return self._make_loader(tuple(loaders), strict_coercion=strict_coercion, debug_trail=debug_trail)
return mediator.cached_call(
self._make_loader,
loaders=tuple(loaders),
strict_coercion=strict_coercion,
debug_trail=debug_trail,
)

def _make_loader(self, loaders: Collection[Loader], *, strict_coercion: bool, debug_trail: DebugTrail):
if debug_trail == DebugTrail.DISABLE:
Expand Down Expand Up @@ -225,7 +230,9 @@ def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
lambda: "Cannot create dumper for tuple. Dumpers for some elements cannot be created",
)
debug_trail = mediator.mandatory_provide(DebugTrailRequest(loc_stack=request.loc_stack))
return self._make_dumper(tuple(dumpers), debug_trail)
return mediator.cached_call(self._make_dumper,
dumpers=tuple(dumpers),
debug_trail=debug_trail)

def _make_dumper(self, dumpers: Collection[Dumper], debug_trail: DebugTrail):
if debug_trail == DebugTrail.DISABLE:
Expand Down
22 changes: 16 additions & 6 deletions src/adaptix/_internal/morphing/dict_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
debug_trail = mediator.mandatory_provide(
DebugTrailRequest(loc_stack=request.loc_stack),
)
return self._make_loader(
return mediator.cached_call(
self._make_loader,
key_loader=key_loader,
value_loader=value_loader,
debug_trail=debug_trail,
Expand Down Expand Up @@ -159,7 +160,8 @@ def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
debug_trail = mediator.mandatory_provide(
DebugTrailRequest(loc_stack=request.loc_stack),
)
return self._make_dumper(
return mediator.cached_call(
self._make_dumper,
key_dumper=key_dumper,
value_dumper=value_dumper,
debug_trail=debug_trail,
Expand Down Expand Up @@ -261,18 +263,26 @@ def provide_loader(self, mediator: Mediator, request: LoaderRequest) -> Loader:
mediator,
replace(request, loc_stack=request.loc_stack.replace_last_type(dict_type_hint)),
)

return mediator.cached_call(
self._make_loader,
loader=dict_loader,
)

def _make_loader(self, loader: Loader):
default_factory = self.default_factory

def defaultdict_loader(data):
return defaultdict(default_factory, dict_loader(data))
return defaultdict(default_factory, loader(data))

return defaultdict_loader

def provide_dumper(self, mediator: Mediator, request: DumperRequest) -> Dumper:
key, value = self._extract_key_value(request)
dict_type_hint = Dict[key.source, value.source] # type: ignore[misc, name-defined]

return self._DICT_PROVIDER.provide_dumper(
mediator,
replace(request, loc_stack=request.loc_stack.replace_last_type(dict_type_hint)),
return mediator.cached_call(
self._DICT_PROVIDER.provide_dumper,
mediator=mediator,
request=replace(request, loc_stack=request.loc_stack.replace_last_type(dict_type_hint)),
)
Loading
Loading