Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to using registered serialization #9

Merged
merged 1 commit into from
Mar 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Pipfile.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

176 changes: 77 additions & 99 deletions scanspec/core.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
from dataclasses import fields as dataclass_fields
from types import new_class
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Iterator,
List,
Optional,
Mapping,
Type,
TypeVar,
)

import numpy as np
from apischema import deserialize, deserializer, serialize
from apischema import deserialize, deserializer, schema_ref, serialize
from apischema.conversions import (
Conversion,
LazyConversion,
dataclass_input_wrapper,
identity,
reset_deserializers,
)
from apischema.metadata.implem import ConversionMetadata
from apischema.metadata.keys import CONVERSIONS_METADATA
from apischema.conversions.converters import serializer
from apischema.metadata import conversion
from apischema.tagged_unions import Tagged, TaggedUnion, get_tagged
from typing_extensions import Annotated

#: The type of class the function will return
T = TypeVar("T")
Expand Down Expand Up @@ -61,88 +60,54 @@ def alternative_constructor(f):
"""
cls_name = f.__qualname__.split(".")[0]
_alternative_constructors.setdefault(cls_name, []).append(f)
m = staticmethod(f)
return m


def _update_serialization(parent_class: Any) -> Conversion:
"""Performs several tasks to setup (de)serialization. First,
handle alternative constructors so they are added to the TaggedUnion.
Second, calculate a tagged_union_conversion. Sub-classes are iterated
over for a second time. This time each dataclass field is checked. If
the dataclass field is of the same type as one of the serializable classes
(i.e. a child of Serializable), its dynamic converion is updated.
The tagged union is then used to register a a deserialization.
It is also returned for use in dynamic conversions."""

sub_cls: Any
namespace: Dict[str, Any] = {}
annotations: Dict[str, Type[Tagged[Any]]] = {}
for sub_cls in rec_subclasses(parent_class):
# Add tagged field for the Spec subclass
annotations[sub_cls.__name__] = Tagged[sub_cls]
# Add tagged fields for all its additional constructors
# (use class __dict__ in order to avoid inheritances of this constructors)
for constructor in _alternative_constructors.get(sub_cls.__name__, []):
# Build the alias of the field
alias = (
"".join(map(str.capitalize, constructor.__name__.split("_")))
+ sub_cls.__name__
return staticmethod(f)


def _make_tagged_union(base: Type, is_serialization: bool) -> Type[TaggedUnion]:
# base is a direct subclass of Serializable, like Spec or Region
namespace: Dict[str, Any] = dict(__annotations__={})
for cls in rec_subclasses(base):
# Add tagged field for the Serializable subclass
namespace["__annotations__"][cls.__name__] = Tagged[cls] # type: ignore
if is_serialization:
# Specify that we should use the identity serialization rather
# than our registered to_tagged_union() serializer when inside the
# tagged union
serialization = Conversion(
identity,
source=cls,
# Tagged field default serialization (to tagged union) must be
# bypassed. However, dynamic conversion discards schema_ref, so
# you must put it back manually, and do the bypass in a sub
# conversion.
target=Annotated[cls, schema_ref(cls.__name__)],
sub_conversions=identity,
)
# dataclass_input_wrapper uses get_type_hints, but the constructor
# return type is stringified and the class not defined yet,
# so it must be assigned manually
constructor.__annotations__["return"] = sub_cls
# Wraps the constructor and rename its input class
wrapper, wrapper_cls = dataclass_input_wrapper(constructor)
wrapper_cls.__name__ = alias
# Add constructor tagged field with its conversion
annotations[alias] = Tagged[sub_cls]
namespace[alias] = Tagged(deserialization=wrapper)
namespace[cls.__name__] = Tagged(conversion(serialization=serialization))
else:
# Build deserialization aliases for each alternative constructor alias
for constructor in _alternative_constructors.get(cls.__name__, []):
alias = (
"".join(map(str.capitalize, constructor.__name__.split("_")))
+ cls.__name__
)
# dataclass_input_wrapper uses get_type_hints, but the constructor
# return type is stringified and the class not defined yet,
# so it must be assigned manually
constructor.__annotations__["return"] = cls
# Wraps the constructor and rename its input class
wrapper, wrapper_cls = dataclass_input_wrapper(constructor)
wrapper_cls.__name__ = alias
# Add constructor tagged field with its conversion
namespace["__annotations__"][alias] = Tagged[cls] # type: ignore
namespace[alias] = Tagged(conversion(deserialization=wrapper))
# Create the tagged union class
namespace = dict(__annotations__=annotations, **namespace)

tagged_union = new_class(
f"Tagged{parent_class.__name__}Union",
union = new_class(
f"Tagged{base.__name__}Union",
(TaggedUnion,),
exec_body=lambda ns: ns.update(namespace),
)

tagged_union_conversion = Conversion(
lambda obj: tagged_union(**{type(obj).__name__: obj}),
source=parent_class,
target=tagged_union,
)

# Add dynamic conversions for attributes which are children of Serializable
for sub_cls in rec_subclasses(parent_class):
for field in dataclass_fields(sub_cls):
meta = field.metadata
if meta and field.type in parent_class.registered_serializable:
if field.type == parent_class:
conversion = tagged_union_conversion
else:
conversion = field.type.conversion
meta = {
**meta,
**{
CONVERSIONS_METADATA: ConversionMetadata(
serialization=conversion
)
},
}
field.metadata = meta

# Because deserializers stack, they must be reset before being reassigned
reset_deserializers(parent_class)
# Register the deserializer using get_tagged
deserializer(
Conversion(
lambda obj: get_tagged(obj)[1], source=tagged_union, target=parent_class
)
)

return tagged_union_conversion
return union


class Serializable:
Expand All @@ -151,26 +116,39 @@ class Serializable:
(de)serialize grandchild classes. Each time a grandchild class is added
conversion is updated to create a full TaggedUnion for the child class."""

conversion = ClassVar[Optional[Conversion]]
registered_serializable: ClassVar[List[Any]] = []

def __init_subclass__(cls, **kwargs):
parent_cls = cls.__bases__[0]
if parent_cls == Serializable:
cls.registered_serializable.append(cls)
else:
super().__init_subclass__(**kwargs)
parent_cls.conversion = _update_serialization(parent_cls)
super().__init_subclass__(**kwargs)
# Retrieved the base class inheriting Serializable
bases = [c for c in cls.__mro__ if Serializable in c.__bases__]
assert (
len(bases) == 1
), f"Cannot have multiple base classes inheriting Serializable {bases}"
base = bases[0]
assert issubclass(base, Serializable)
# Create the serialization tagged union class
serialization_union = _make_tagged_union(base, is_serialization=True)

# And a function that converts to it
def to_tagged_union(obj):
return serialization_union(**{obj.__class__.__name__: obj})

# Register the serializer
serializer(Conversion(to_tagged_union, source=base, target=serialization_union))
# Create the deserialization tagged union class
deserialization_union = _make_tagged_union(base, is_serialization=False)
# Because deserializers stack, they must be reset before being reassigned
reset_deserializers(base)
# Register the deserializer using get_tagged
deserializer(
Conversion(lambda obj: get_tagged(obj)[1], deserialization_union, base,)
)

def serialize(self):
def serialize(self) -> Mapping[str, Any]:
"""Serialize to a dictionary representation"""
parent_cls = self.__class__.__bases__[0]
return serialize(
self, conversions=LazyConversion(lambda: parent_cls.conversion)
)
return serialize(self)

@classmethod
def deserialize(cls: T, serialization: Dict[str, Any]) -> T:
def deserialize(cls: Type[T], serialization: Mapping[str, Any]) -> T:
"""Deserialize from a dictionary representation"""
return deserialize(cls, serialization)

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ install_requires =
# make sure a python 3.9 compatible numpy is selected
numpy>=1.19.3
click
apischema
apischema>=0.14.7
typing_extensions

[options.extras_require]
Expand Down