Skip to content

Commit

Permalink
Support pickling dynamic classes subclassing typing.Generic instanc…
Browse files Browse the repository at this point in the history
…es on 3.7+ (#351)
  • Loading branch information
valtron authored Mar 15, 2020
1 parent 215d3dd commit 3e80b26
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 5 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
and expand the support for pickling `TypeVar` instances (dynamic or non-dynamic)
to Python 3.5-3.6 ([PR #350](https://github.com/cloudpipe/cloudpickle/pull/350))

- Add support for pickling dynamic classes subclassing `typing.Generic`
instances on Python 3.7+
([PR #351](https://github.com/cloudpipe/cloudpickle/pull/351))

1.3.0
=====

Expand Down
19 changes: 16 additions & 3 deletions cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def dump(self, obj):
raise

def save_typevar(self, obj):
self.save_reduce(*_typevar_reduce(obj))
self.save_reduce(*_typevar_reduce(obj), obj=obj)

dispatch[typing.TypeVar] = save_typevar

Expand Down Expand Up @@ -645,7 +645,7 @@ def save_dynamic_class(self, obj):
# "Regular" class definition:
tp = type(obj)
self.save_reduce(_make_skeleton_class,
(tp, obj.__name__, obj.__bases__, type_kwargs,
(tp, obj.__name__, _get_bases(obj), type_kwargs,
_ensure_tracking(obj), None),
obj=obj)

Expand Down Expand Up @@ -1163,7 +1163,10 @@ class id will also reuse this class definition.
The "extra" variable is meant to be a dict (or None) that can be used for
forward compatibility shall the need arise.
"""
skeleton_class = type_constructor(name, bases, type_kwargs)
skeleton_class = types.new_class(
name, bases, {'metaclass': type_constructor},
lambda ns: ns.update(type_kwargs)
)
return _lookup_class_or_track(class_tracker_id, skeleton_class)


Expand Down Expand Up @@ -1268,3 +1271,13 @@ def _typevar_reduce(obj):
if module_and_name is None:
return (_make_typevar, _decompose_typevar(obj))
return (getattr, module_and_name)


def _get_bases(typ):
if hasattr(typ, '__orig_bases__'):
# For generic types (see PEP 560)
bases_attr = '__orig_bases__'
else:
# For regular class objects
bases_attr = '__bases__'
return getattr(typ, bases_attr)
4 changes: 2 additions & 2 deletions cloudpickle/cloudpickle_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
_is_dynamic, _extract_code_globals, _BUILTIN_TYPE_NAMES, DEFAULT_PROTOCOL,
_find_imported_submodules, _get_cell_contents, _is_importable_by_name, _builtin_type,
Enum, _ensure_tracking, _make_skeleton_class, _make_skeleton_enum,
_extract_class_dict, dynamic_subimport, subimport, _typevar_reduce,
_extract_class_dict, dynamic_subimport, subimport, _typevar_reduce, _get_bases,
)

load, loads = _pickle.load, _pickle.loads
Expand Down Expand Up @@ -76,7 +76,7 @@ def _class_getnewargs(obj):
if isinstance(__dict__, property):
type_kwargs['__dict__'] = __dict__

return (type(obj), obj.__name__, obj.__bases__, type_kwargs,
return (type(obj), obj.__name__, _get_bases(obj), type_kwargs,
_ensure_tracking(obj), None)


Expand Down
85 changes: 85 additions & 0 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

from .testutils import subprocess_pickle_echo
from .testutils import assert_run_python_script
from .testutils import subprocess_worker


_TEST_GLOBAL_VARIABLE = "default_value"
Expand Down Expand Up @@ -2121,6 +2122,12 @@ def test_pickle_dynamic_typevar(self):
for attr in attr_list:
assert getattr(T, attr) == getattr(depickled_T, attr)

def test_pickle_dynamic_typevar_memoization(self):
T = typing.TypeVar('T')
depickled_T1, depickled_T2 = pickle_depickle((T, T),
protocol=self.protocol)
assert depickled_T1 is depickled_T2

def test_pickle_importable_typevar(self):
from .mypkg import T
T1 = pickle_depickle(T, protocol=self.protocol)
Expand All @@ -2130,6 +2137,61 @@ def test_pickle_importable_typevar(self):
from typing import AnyStr
assert AnyStr is pickle_depickle(AnyStr, protocol=self.protocol)

@unittest.skipIf(sys.version_info < (3, 7),
"Pickling generics not supported below py37")
def test_generic_type(self):
T = typing.TypeVar('T')

class C(typing.Generic[T]):
pass

assert pickle_depickle(C, protocol=self.protocol) is C
assert pickle_depickle(C[int], protocol=self.protocol) is C[int]

with subprocess_worker(protocol=self.protocol) as worker:

def check_generic(generic, origin, type_value):
assert generic.__origin__ is origin
assert len(generic.__args__) == 1
assert generic.__args__[0] is type_value

assert len(origin.__orig_bases__) == 1
ob = origin.__orig_bases__[0]
assert ob.__origin__ is typing.Generic
assert len(ob.__parameters__) == 1

return "ok"

assert check_generic(C[int], C, int) == "ok"
assert worker.run(check_generic, C[int], C, int) == "ok"

@unittest.skipIf(sys.version_info < (3, 7),
"Pickling type hints not supported below py37")
def test_locally_defined_class_with_type_hints(self):
with subprocess_worker(protocol=self.protocol) as worker:
for type_ in _all_types_to_test():
# The type annotation syntax causes a SyntaxError on Python 3.5
code = textwrap.dedent("""\
class MyClass:
attribute: type_
def method(self, arg: type_) -> type_:
return arg
""")
ns = {"type_": type_}
exec(code, ns)
MyClass = ns["MyClass"]

def check_annotations(obj, expected_type):
assert obj.__annotations__["attribute"] is expected_type
assert obj.method.__annotations__["arg"] is expected_type
assert obj.method.__annotations__["return"] is expected_type
return "ok"

obj = MyClass()
assert check_annotations(obj, type_) == "ok"
assert worker.run(check_annotations, obj, type_) == "ok"


class Protocol2CloudPickleTest(CloudPickleTest):

Expand Down Expand Up @@ -2161,5 +2223,28 @@ def test_lookup_module_and_qualname_stdlib_typevar():
assert name == 'AnyStr'


def _all_types_to_test():
T = typing.TypeVar('T')

class C(typing.Generic[T]):
pass

return [
C, C[int],
T, typing.Any, typing.NoReturn, typing.Optional,
typing.Generic, typing.Union, typing.ClassVar,
typing.Optional[int],
typing.Generic[T],
typing.Callable[[int], typing.Any],
typing.Callable[..., typing.Any],
typing.Callable[[], typing.Any],
typing.Tuple[int, ...],
typing.Tuple[int, C[int]],
typing.ClassVar[C[int]],
typing.List[int],
typing.Dict[int, str],
]


if __name__ == '__main__':
unittest.main()

0 comments on commit 3e80b26

Please sign in to comment.