Skip to content

Commit

Permalink
Fix pickling errors thrown when saving some Stdlib modules (#529)
Browse files Browse the repository at this point in the history
* fix KeyError when pickling type with '__dict__' or '__weakref__' in '__slots__'

* fix KeyError when pickling a type where '__slots__' is a string
  • Loading branch information
leogama authored Aug 1, 2022
1 parent b2fa04d commit 87b8541
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 19 deletions.
28 changes: 10 additions & 18 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,13 +1012,6 @@ def _get_attr(self, name):
# stop recursive pickling
return getattr(self, name, None) or getattr(__builtin__, name)

def _dict_from_dictproxy(dictproxy):
_dict = dictproxy.copy() # convert dictproxy to dict
_dict.pop('__dict__', None)
_dict.pop('__weakref__', None)
_dict.pop('__prepare__', None)
return _dict

def _import_module(import_name, safe=False):
try:
if import_name.startswith('__runtime__.'):
Expand Down Expand Up @@ -1712,28 +1705,27 @@ def save_type(pickler, obj, postproc_list=None):
obj_recursive = id(obj) in getattr(pickler, '_postproc', ())
incorrectly_named = not _locate_function(obj, pickler)
if not _byref and not obj_recursive and incorrectly_named: # not a function, but the name was held over
if issubclass(type(obj), type):
# thanks to Tom Stepleton pointing out pickler._session unneeded
_t = 'T2'
logger.trace(pickler, "%s: %s", _t, obj)
_dict = _dict_from_dictproxy(obj.__dict__)
else:
_t = 'T3'
logger.trace(pickler, "%s: %s", _t, obj)
_dict = obj.__dict__
# thanks to Tom Stepleton pointing out pickler._session unneeded
logger.trace(pickler, "T2: %s", obj)
_dict = obj.__dict__.copy() # convert dictproxy to dict
#print (_dict)
#print ("%s\n%s" % (type(obj), obj.__name__))
#print ("%s\n%s" % (obj.__bases__, obj.__dict__))
for name in _dict.get("__slots__", []):
slots = _dict.get('__slots__', ())
if type(slots) == str: slots = (slots,) # __slots__ accepts a single string
for name in slots:
del _dict[name]
_dict.pop('__dict__', None)
_dict.pop('__weakref__', None)
_dict.pop('__prepare__', None)
if obj_name != obj.__name__:
if postproc_list is None:
postproc_list = []
postproc_list.append((setattr, (obj, '__qualname__', obj_name)))
_save_with_postproc(pickler, (_create_type, (
type(obj), obj.__name__, obj.__bases__, _dict
)), obj=obj, postproc_list=postproc_list)
logger.trace(pickler, "# %s", _t)
logger.trace(pickler, "# T2")
else:
logger.trace(pickler, "T4: %s", obj)
if incorrectly_named:
Expand Down
2 changes: 1 addition & 1 deletion dill/tests/test_classdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def test(cls):

# test slots
class Y(object):
__slots__ = ['y']
__slots__ = ('y', '__weakref__')
def __init__(self, y):
self.y = y

Expand Down

0 comments on commit 87b8541

Please sign in to comment.