Skip to content

Commit

Permalink
Merge pull request #2279 from deckar01/simplify-hooks
Browse files Browse the repository at this point in the history
Simplify Hooks
  • Loading branch information
lafrech committed Aug 20, 2024
2 parents bbbd7af + 7b62f3e commit 7609530
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 37 deletions.
2 changes: 1 addition & 1 deletion performance/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def run_timeit(quotes, iterations, repeat, profile=False):
profile.disable()
profile.dump_stats("marshmallow.pprof")

usec = best * 1e6 / iterations
usec = best * 1e6 / iterations / len(quotes)
return usec


Expand Down
24 changes: 13 additions & 11 deletions src/marshmallow/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def validate_age(self, data, **kwargs):
from __future__ import annotations

import functools
from collections import defaultdict
from typing import Any, Callable, cast

PRE_DUMP = "pre_dump"
Expand All @@ -79,7 +80,7 @@ def validate_age(self, data, **kwargs):


class MarshmallowHook:
__marshmallow_hook__: dict[tuple[str, bool] | str, Any] | None = None
__marshmallow_hook__: dict[str, list[tuple[bool, Any]]] | None = None


def validates(field_name: str) -> Callable[..., Any]:
Expand Down Expand Up @@ -117,7 +118,8 @@ def validates_schema(
"""
return set_hook(
fn,
(VALIDATES_SCHEMA, pass_many),
VALIDATES_SCHEMA,
many=pass_many,
pass_original=pass_original,
skip_on_field_errors=skip_on_field_errors,
)
Expand All @@ -136,7 +138,7 @@ def pre_dump(
.. versionchanged:: 3.0.0
``many`` is always passed as a keyword arguments to the decorated method.
"""
return set_hook(fn, (PRE_DUMP, pass_many))
return set_hook(fn, PRE_DUMP, many=pass_many)


def post_dump(
Expand All @@ -157,7 +159,7 @@ def post_dump(
.. versionchanged:: 3.0.0
``many`` is always passed as a keyword arguments to the decorated method.
"""
return set_hook(fn, (POST_DUMP, pass_many), pass_original=pass_original)
return set_hook(fn, POST_DUMP, many=pass_many, pass_original=pass_original)


def pre_load(
Expand All @@ -174,7 +176,7 @@ def pre_load(
``partial`` and ``many`` are always passed as keyword arguments to
the decorated method.
"""
return set_hook(fn, (PRE_LOAD, pass_many))
return set_hook(fn, PRE_LOAD, many=pass_many)


def post_load(
Expand All @@ -196,11 +198,11 @@ def post_load(
``partial`` and ``many`` are always passed as keyword arguments to
the decorated method.
"""
return set_hook(fn, (POST_LOAD, pass_many), pass_original=pass_original)
return set_hook(fn, POST_LOAD, many=pass_many, pass_original=pass_original)


def set_hook(
fn: Callable[..., Any] | None, key: tuple[str, bool] | str, **kwargs: Any
fn: Callable[..., Any] | None, tag: str, many: bool = False, **kwargs: Any
) -> Callable[..., Any]:
"""Mark decorated function as a hook to be picked up later.
You should not need to use this method directly.
Expand All @@ -214,18 +216,18 @@ def set_hook(
"""
# Allow using this as either a decorator or a decorator factory.
if fn is None:
return functools.partial(set_hook, key=key, **kwargs)
return functools.partial(set_hook, tag=tag, many=many, **kwargs)

# Set a __marshmallow_hook__ attribute instead of wrapping in some class,
# because I still want this to end up as a normal (unbound) method.
function = cast(MarshmallowHook, fn)
try:
hook_config = function.__marshmallow_hook__
except AttributeError:
function.__marshmallow_hook__ = hook_config = {}
function.__marshmallow_hook__ = hook_config = defaultdict(list)
# Also save the kwargs for the tagged function on
# __marshmallow_hook__, keyed by (<tag>, <pass_many>)
# __marshmallow_hook__, keyed by <tag>
if hook_config is not None:
hook_config[key] = kwargs
hook_config[tag].append((many, kwargs))

return fn
44 changes: 20 additions & 24 deletions src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,15 @@ def __init__(cls, name, bases, attrs):
class_registry.register(name, cls)
cls._hooks = cls.resolve_hooks()

def resolve_hooks(cls) -> dict[types.Tag, list[str]]:
def resolve_hooks(cls) -> dict[str, list[tuple[str, bool, dict]]]:
"""Add in the decorated processors
By doing this after constructing the class, we let standard inheritance
do all the hard work.
"""
mro = inspect.getmro(cls)

hooks = defaultdict(list) # type: typing.Dict[types.Tag, typing.List[str]]
hooks = defaultdict(list) # type: typing.Dict[str, typing.List[typing.Tuple[str, bool, dict]]]

for attr_name in dir(cls):
# Need to look up the actual descriptor, not whatever might be
Expand All @@ -176,14 +176,16 @@ def resolve_hooks(cls) -> dict[types.Tag, list[str]]:
continue

try:
hook_config = attr.__marshmallow_hook__
hook_config = attr.__marshmallow_hook__ # type: typing.Dict[str, typing.List[typing.Tuple[bool, dict]]]
except AttributeError:
pass
else:
for key in hook_config.keys():
for tag, config in hook_config.items():
# Use name here so we can get the bound method later, in
# case the processor was a descriptor or something.
hooks[key].append(attr_name)
hooks[tag].extend(
(attr_name, many, kwargs) for many, kwargs in config
)

return hooks

Expand Down Expand Up @@ -319,7 +321,7 @@ class AlbumSchema(Schema):
# These get set by SchemaMeta
opts = None # type: SchemaOpts
_declared_fields = {} # type: typing.Dict[str, ma_fields.Field]
_hooks = {} # type: typing.Dict[types.Tag, typing.List[str]]
_hooks = {} # type: typing.Dict[str, typing.List[typing.Tuple[str, bool, dict]]]

class Meta:
"""Options object for a Schema.
Expand Down Expand Up @@ -539,7 +541,7 @@ def dump(self, obj: typing.Any, *, many: bool | None = None):
Validation no longer occurs upon serialization.
"""
many = self.many if many is None else bool(many)
if self._has_processors(PRE_DUMP):
if self._hooks[PRE_DUMP]:
processed_obj = self._invoke_dump_processors(
PRE_DUMP, obj, many=many, original_data=obj
)
Expand All @@ -548,7 +550,7 @@ def dump(self, obj: typing.Any, *, many: bool | None = None):

result = self._serialize(processed_obj, many=many)

if self._has_processors(POST_DUMP):
if self._hooks[POST_DUMP]:
result = self._invoke_dump_processors(
POST_DUMP, result, many=many, original_data=obj
)
Expand Down Expand Up @@ -846,7 +848,7 @@ def _do_load(
if partial is None:
partial = self.partial
# Run preprocessors
if self._has_processors(PRE_LOAD):
if self._hooks[PRE_LOAD]:
try:
processed_data = self._invoke_load_processors(
PRE_LOAD, data, many=many, original_data=data, partial=partial
Expand All @@ -870,7 +872,7 @@ def _do_load(
error_store=error_store, data=result, many=many
)
# Run schema-level validation
if self._has_processors(VALIDATES_SCHEMA):
if self._hooks[VALIDATES_SCHEMA]:
field_errors = bool(error_store.errors)
self._invoke_schema_validators(
error_store=error_store,
Expand All @@ -892,7 +894,7 @@ def _do_load(
)
errors = error_store.errors
# Run post processors
if not errors and postprocess and self._has_processors(POST_LOAD):
if not errors and postprocess and self._hooks[POST_LOAD]:
try:
result = self._invoke_load_processors(
POST_LOAD,
Expand Down Expand Up @@ -1055,9 +1057,6 @@ def _bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None:
raise error
self.on_bind_field(field_name, field_obj)

def _has_processors(self, tag) -> bool:
return bool(self._hooks[(tag, True)] or self._hooks[(tag, False)])

def _invoke_dump_processors(
self, tag: str, data, *, many: bool, original_data=None
):
Expand Down Expand Up @@ -1102,9 +1101,8 @@ def _invoke_load_processors(
return data

def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool):
for attr_name in self._hooks[VALIDATES]:
for attr_name, _, validator_kwargs in self._hooks[VALIDATES]:
validator = getattr(self, attr_name)
validator_kwargs = validator.__marshmallow_hook__[VALIDATES]
field_name = validator_kwargs["field_name"]

try:
Expand Down Expand Up @@ -1159,11 +1157,10 @@ def _invoke_schema_validators(
partial: bool | types.StrSequenceOrSet | None,
field_errors: bool = False,
):
for attr_name in self._hooks[(VALIDATES_SCHEMA, pass_many)]:
for attr_name, hook_many, validator_kwargs in self._hooks[VALIDATES_SCHEMA]:
if hook_many != pass_many:
continue
validator = getattr(self, attr_name)
validator_kwargs = validator.__marshmallow_hook__[
(VALIDATES_SCHEMA, pass_many)
]
if field_errors and validator_kwargs["skip_on_field_errors"]:
continue
pass_original = validator_kwargs.get("pass_original", False)
Expand Down Expand Up @@ -1201,12 +1198,11 @@ def _invoke_processors(
original_data=None,
**kwargs,
):
key = (tag, pass_many)
for attr_name in self._hooks[key]:
for attr_name, hook_many, processor_kwargs in self._hooks[tag]:
if hook_many != pass_many:
continue
# This will be a bound method.
processor = getattr(self, attr_name)

processor_kwargs = processor.__marshmallow_hook__[key]
pass_original = processor_kwargs.get("pass_original", False)

if many and not pass_many:
Expand Down
1 change: 0 additions & 1 deletion src/marshmallow/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,4 @@
import typing

StrSequenceOrSet = typing.Union[typing.Sequence[str], typing.AbstractSet[str]]
Tag = typing.Union[str, typing.Tuple[str, bool]]
Validator = typing.Callable[[typing.Any], typing.Any]

0 comments on commit 7609530

Please sign in to comment.