Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pickling dynamic classes subclassing typing.Generic instances on 3.7+ #351

Merged
merged 11 commits into from
Mar 15, 2020
20 changes: 18 additions & 2 deletions cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,8 +644,9 @@ def save_dynamic_class(self, obj):
else:
# "Regular" class definition:
tp = type(obj)
bases = _get_bases(obj)
valtron marked this conversation as resolved.
Show resolved Hide resolved
self.save_reduce(_make_skeleton_class,
(tp, obj.__name__, obj.__bases__, type_kwargs,
(tp, obj.__name__, bases, type_kwargs,
valtron marked this conversation as resolved.
Show resolved Hide resolved
_ensure_tracking(obj), None),
obj=obj)

Expand Down Expand Up @@ -1163,10 +1164,17 @@ class id will also reuse this class definition.
The "extra" variable is meant to be a dict (or None) that can be used for
forward compatibility shall the need arise.
"""
skeleton_class = type_constructor(name, bases, type_kwargs)
skeleton_class = _make_new_class(type_constructor, name, bases, type_kwargs)
return _lookup_class_or_track(class_tracker_id, skeleton_class)


def _make_new_class(type_constructor, name, bases, type_kwargs):
valtron marked this conversation as resolved.
Show resolved Hide resolved
return types.new_class(
name, bases, {'metaclass': type_constructor},
lambda ns: ns.update(type_kwargs)
)


def _rehydrate_skeleton_class(skeleton_class, class_dict):
"""Put attributes from `class_dict` back on `skeleton_class`.

Expand Down Expand Up @@ -1268,3 +1276,11 @@ def _typevar_reduce(obj):
if module_and_name is None:
return (_make_typevar, _decompose_typevar(obj))
return (getattr, module_and_name)


def _get_bases(typ):
if hasattr(typ, '__orig_bases__'):
bases_attr = '__orig_bases__'
valtron marked this conversation as resolved.
Show resolved Hide resolved
valtron marked this conversation as resolved.
Show resolved Hide resolved
else:
bases_attr = '__bases__'
valtron marked this conversation as resolved.
Show resolved Hide resolved
return getattr(typ, bases_attr)
4 changes: 2 additions & 2 deletions cloudpickle/cloudpickle_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
_is_dynamic, _extract_code_globals, _BUILTIN_TYPE_NAMES, DEFAULT_PROTOCOL,
_find_imported_submodules, _get_cell_contents, _is_importable_by_name, _builtin_type,
Enum, _ensure_tracking, _make_skeleton_class, _make_skeleton_enum,
_extract_class_dict, dynamic_subimport, subimport, _typevar_reduce,
_extract_class_dict, dynamic_subimport, subimport, _typevar_reduce, _get_bases,
)

load, loads = _pickle.load, _pickle.loads
Expand Down Expand Up @@ -76,7 +76,7 @@ def _class_getnewargs(obj):
if isinstance(__dict__, property):
type_kwargs['__dict__'] = __dict__

return (type(obj), obj.__name__, obj.__bases__, type_kwargs,
return (type(obj), obj.__name__, _get_bases(obj), type_kwargs,
_ensure_tracking(obj), None)


Expand Down
47 changes: 47 additions & 0 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2130,6 +2130,53 @@ def test_pickle_importable_typevar(self):
from typing import AnyStr
assert AnyStr is pickle_depickle(AnyStr, protocol=self.protocol)

@unittest.skipIf(sys.version_info < (3, 7),
"Pickling generics not supported below py37")
def test_generic(self):
valtron marked this conversation as resolved.
Show resolved Hide resolved
valtron marked this conversation as resolved.
Show resolved Hide resolved
from typing import (
Optional, TypeVar, Generic, Tuple, Callable,
Dict, Any, ClassVar, NoReturn, Union, List,
)

T = TypeVar('T')

class C(Generic[T]):
pass

objs = [
valtron marked this conversation as resolved.
Show resolved Hide resolved
C, C[int],
T, Any, NoReturn, Optional, Generic,
Union, ClassVar,
Optional[int],
Generic[T],
Callable[[int], Any],
Callable[..., Any],
Callable[[], Any],
Tuple[int, ...],
Tuple[int, C[int]],
ClassVar[C[int]],
List[int],
Dict[int, str],
]

for obj in objs:
_ = pickle_depickle(obj, protocol=self.protocol)
valtron marked this conversation as resolved.
Show resolved Hide resolved

@unittest.skipIf(sys.version_info < (3, 7),
"Pickling generics not supported below py37")
def test_generic_extensions(self):
valtron marked this conversation as resolved.
Show resolved Hide resolved
typing_extensions = pytest.importorskip('typing_extensions')

objs = [
typing_extensions.Literal,
typing_extensions.Final,
typing_extensions.Literal['a'],
typing_extensions.Final[int],
]

for obj in objs:
_ = pickle_depickle(obj, protocol=self.protocol)


class Protocol2CloudPickleTest(CloudPickleTest):

Expand Down