From 593b7fecd4eb3bf24480d18b942703ee9ccf8a04 Mon Sep 17 00:00:00 2001 From: Jared Deckard Date: Thu, 6 Jun 2024 11:56:56 -0500 Subject: [PATCH 1/2] Simplify tuple-keyed hook name and config dicts --- src/marshmallow/decorators.py | 24 ++++++++++--------- src/marshmallow/schema.py | 44 ++++++++++++++++------------------- src/marshmallow/types.py | 1 - 3 files changed, 33 insertions(+), 36 deletions(-) diff --git a/src/marshmallow/decorators.py b/src/marshmallow/decorators.py index dafca9539..965edb68b 100644 --- a/src/marshmallow/decorators.py +++ b/src/marshmallow/decorators.py @@ -68,6 +68,7 @@ def validate_age(self, data, **kwargs): from __future__ import annotations import functools +from collections import defaultdict from typing import Any, Callable, cast PRE_DUMP = "pre_dump" @@ -79,7 +80,7 @@ def validate_age(self, data, **kwargs): class MarshmallowHook: - __marshmallow_hook__: dict[tuple[str, bool] | str, Any] | None = None + __marshmallow_hook__: dict[str, list[tuple[bool, Any]]] | None = None def validates(field_name: str) -> Callable[..., Any]: @@ -117,7 +118,8 @@ def validates_schema( """ return set_hook( fn, - (VALIDATES_SCHEMA, pass_many), + VALIDATES_SCHEMA, + many=pass_many, pass_original=pass_original, skip_on_field_errors=skip_on_field_errors, ) @@ -136,7 +138,7 @@ def pre_dump( .. versionchanged:: 3.0.0 ``many`` is always passed as a keyword arguments to the decorated method. """ - return set_hook(fn, (PRE_DUMP, pass_many)) + return set_hook(fn, PRE_DUMP, many=pass_many) def post_dump( @@ -157,7 +159,7 @@ def post_dump( .. versionchanged:: 3.0.0 ``many`` is always passed as a keyword arguments to the decorated method. """ - return set_hook(fn, (POST_DUMP, pass_many), pass_original=pass_original) + return set_hook(fn, POST_DUMP, many=pass_many, pass_original=pass_original) def pre_load( @@ -174,7 +176,7 @@ def pre_load( ``partial`` and ``many`` are always passed as keyword arguments to the decorated method. """ - return set_hook(fn, (PRE_LOAD, pass_many)) + return set_hook(fn, PRE_LOAD, many=pass_many) def post_load( @@ -196,11 +198,11 @@ def post_load( ``partial`` and ``many`` are always passed as keyword arguments to the decorated method. """ - return set_hook(fn, (POST_LOAD, pass_many), pass_original=pass_original) + return set_hook(fn, POST_LOAD, many=pass_many, pass_original=pass_original) def set_hook( - fn: Callable[..., Any] | None, key: tuple[str, bool] | str, **kwargs: Any + fn: Callable[..., Any] | None, tag: str, many: bool = False, **kwargs: Any ) -> Callable[..., Any]: """Mark decorated function as a hook to be picked up later. You should not need to use this method directly. @@ -214,7 +216,7 @@ def set_hook( """ # Allow using this as either a decorator or a decorator factory. if fn is None: - return functools.partial(set_hook, key=key, **kwargs) + return functools.partial(set_hook, tag=tag, many=many, **kwargs) # Set a __marshmallow_hook__ attribute instead of wrapping in some class, # because I still want this to end up as a normal (unbound) method. @@ -222,10 +224,10 @@ def set_hook( try: hook_config = function.__marshmallow_hook__ except AttributeError: - function.__marshmallow_hook__ = hook_config = {} + function.__marshmallow_hook__ = hook_config = defaultdict(list) # Also save the kwargs for the tagged function on - # __marshmallow_hook__, keyed by (, ) + # __marshmallow_hook__, keyed by if hook_config is not None: - hook_config[key] = kwargs + hook_config[tag].append((many, kwargs)) return fn diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index 23b43c470..ffdc3e04e 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -148,7 +148,7 @@ def __init__(cls, name, bases, attrs): class_registry.register(name, cls) cls._hooks = cls.resolve_hooks() - def resolve_hooks(cls) -> dict[types.Tag, list[str]]: + def resolve_hooks(cls) -> dict[str, list[tuple[str, bool, dict]]]: """Add in the decorated processors By doing this after constructing the class, we let standard inheritance @@ -156,7 +156,7 @@ def resolve_hooks(cls) -> dict[types.Tag, list[str]]: """ mro = inspect.getmro(cls) - hooks = defaultdict(list) # type: typing.Dict[types.Tag, typing.List[str]] + hooks = defaultdict(list) # type: typing.Dict[str, typing.List[typing.Tuple[str, bool, dict]]] for attr_name in dir(cls): # Need to look up the actual descriptor, not whatever might be @@ -176,14 +176,16 @@ def resolve_hooks(cls) -> dict[types.Tag, list[str]]: continue try: - hook_config = attr.__marshmallow_hook__ + hook_config = attr.__marshmallow_hook__ # type: typing.Dict[str, typing.List[typing.Tuple[bool, dict]]] except AttributeError: pass else: - for key in hook_config.keys(): + for tag, config in hook_config.items(): # Use name here so we can get the bound method later, in # case the processor was a descriptor or something. - hooks[key].append(attr_name) + hooks[tag].extend( + (attr_name, many, kwargs) for many, kwargs in config + ) return hooks @@ -319,7 +321,7 @@ class AlbumSchema(Schema): # These get set by SchemaMeta opts = None # type: SchemaOpts _declared_fields = {} # type: typing.Dict[str, ma_fields.Field] - _hooks = {} # type: typing.Dict[types.Tag, typing.List[str]] + _hooks = {} # type: typing.Dict[str, typing.List[typing.Tuple[str, bool, dict]]] class Meta: """Options object for a Schema. @@ -539,7 +541,7 @@ def dump(self, obj: typing.Any, *, many: bool | None = None): Validation no longer occurs upon serialization. """ many = self.many if many is None else bool(many) - if self._has_processors(PRE_DUMP): + if self._hooks[PRE_DUMP]: processed_obj = self._invoke_dump_processors( PRE_DUMP, obj, many=many, original_data=obj ) @@ -548,7 +550,7 @@ def dump(self, obj: typing.Any, *, many: bool | None = None): result = self._serialize(processed_obj, many=many) - if self._has_processors(POST_DUMP): + if self._hooks[POST_DUMP]: result = self._invoke_dump_processors( POST_DUMP, result, many=many, original_data=obj ) @@ -846,7 +848,7 @@ def _do_load( if partial is None: partial = self.partial # Run preprocessors - if self._has_processors(PRE_LOAD): + if self._hooks[PRE_LOAD]: try: processed_data = self._invoke_load_processors( PRE_LOAD, data, many=many, original_data=data, partial=partial @@ -870,7 +872,7 @@ def _do_load( error_store=error_store, data=result, many=many ) # Run schema-level validation - if self._has_processors(VALIDATES_SCHEMA): + if self._hooks[VALIDATES_SCHEMA]: field_errors = bool(error_store.errors) self._invoke_schema_validators( error_store=error_store, @@ -892,7 +894,7 @@ def _do_load( ) errors = error_store.errors # Run post processors - if not errors and postprocess and self._has_processors(POST_LOAD): + if not errors and postprocess and self._hooks[POST_LOAD]: try: result = self._invoke_load_processors( POST_LOAD, @@ -1055,9 +1057,6 @@ def _bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None: raise error self.on_bind_field(field_name, field_obj) - def _has_processors(self, tag) -> bool: - return bool(self._hooks[(tag, True)] or self._hooks[(tag, False)]) - def _invoke_dump_processors( self, tag: str, data, *, many: bool, original_data=None ): @@ -1102,9 +1101,8 @@ def _invoke_load_processors( return data def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool): - for attr_name in self._hooks[VALIDATES]: + for attr_name, _, validator_kwargs in self._hooks[VALIDATES]: validator = getattr(self, attr_name) - validator_kwargs = validator.__marshmallow_hook__[VALIDATES] field_name = validator_kwargs["field_name"] try: @@ -1159,11 +1157,10 @@ def _invoke_schema_validators( partial: bool | types.StrSequenceOrSet | None, field_errors: bool = False, ): - for attr_name in self._hooks[(VALIDATES_SCHEMA, pass_many)]: + for attr_name, hook_many, validator_kwargs in self._hooks[VALIDATES_SCHEMA]: + if hook_many != pass_many: + continue validator = getattr(self, attr_name) - validator_kwargs = validator.__marshmallow_hook__[ - (VALIDATES_SCHEMA, pass_many) - ] if field_errors and validator_kwargs["skip_on_field_errors"]: continue pass_original = validator_kwargs.get("pass_original", False) @@ -1201,12 +1198,11 @@ def _invoke_processors( original_data=None, **kwargs, ): - key = (tag, pass_many) - for attr_name in self._hooks[key]: + for attr_name, hook_many, processor_kwargs in self._hooks[tag]: + if hook_many != pass_many: + continue # This will be a bound method. processor = getattr(self, attr_name) - - processor_kwargs = processor.__marshmallow_hook__[key] pass_original = processor_kwargs.get("pass_original", False) if many and not pass_many: diff --git a/src/marshmallow/types.py b/src/marshmallow/types.py index ce31c0508..8352afeb3 100644 --- a/src/marshmallow/types.py +++ b/src/marshmallow/types.py @@ -8,5 +8,4 @@ import typing StrSequenceOrSet = typing.Union[typing.Sequence[str], typing.AbstractSet[str]] -Tag = typing.Union[str, typing.Tuple[str, bool]] Validator = typing.Callable[[typing.Any], typing.Any] From 7b62f3e1f21e263ca932321a15d03c5cb28d63f1 Mon Sep 17 00:00:00 2001 From: Jared Deckard Date: Thu, 6 Jun 2024 12:12:55 -0500 Subject: [PATCH 2/2] Normalize benchmark per dump for object count --- performance/benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/performance/benchmark.py b/performance/benchmark.py index a7b8eb8c7..010f753d5 100644 --- a/performance/benchmark.py +++ b/performance/benchmark.py @@ -100,7 +100,7 @@ def run_timeit(quotes, iterations, repeat, profile=False): profile.disable() profile.dump_stats("marshmallow.pprof") - usec = best * 1e6 / iterations + usec = best * 1e6 / iterations / len(quotes) return usec