Skip to content

Commit

Permalink
fix uqfoundation#288 nested namedtuples
Browse files Browse the repository at this point in the history
  • Loading branch information
anivegesana committed Jan 27, 2022
1 parent 0392e14 commit 9450a7b
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 17 deletions.
66 changes: 49 additions & 17 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,9 +707,14 @@ def _create_function(fcode, fglobals, fname=None, fdefaults=None,
fclosure=None, fdict=None, fkwdefaults=None):
# same as FunctionType, but enable passing __dict__ to new function,
# __dict__ is the storehouse for attributes added after function creation
if fdict is None: fdict = dict()
func = FunctionType(fcode, fglobals or dict(), fname, fdefaults, fclosure)
func.__dict__.update(fdict) #XXX: better copy? option to copy?
if fdict is not None:
func.__dict__.update(fdict) #XXX: better copy? option to copy?
elif IS_PYPY2:
# __reduce__ crashes in PyPy2
# __setstate__ is not used in any PyPy2 code and is removed in PyPy3
# so is likely an artifact
func.__setstate__ = None
if fkwdefaults is not None:
func.__kwdefaults__ = fkwdefaults
# 'recurse' only stores referenced modules/objects in fglobals,
Expand Down Expand Up @@ -971,14 +976,23 @@ def _create_dtypemeta(scalar_type):
return NumpyDType
return type(NumpyDType(scalar_type))

def _create_namedtuple(name, fieldnames, modulename):
class_ = _import_module(modulename + '.' + name, safe=True)
if class_ is not None:
return class_
import collections
t = collections.namedtuple(name, fieldnames)
t.__module__ = modulename
return t
if OLD37:
def _create_namedtuple(name, fieldnames, modulename, defaults=None):
class_ = _import_module(modulename + '.' + name, safe=True)
if class_ is not None:
return class_
import collections
t = collections.namedtuple(name, fieldnames)
t.__module__ = modulename
return t
else:
def _create_namedtuple(name, fieldnames, modulename, defaults=None):
class_ = _import_module(modulename + '.' + name, safe=True)
if class_ is not None:
return class_
import collections
t = collections.namedtuple(name, fieldnames, defaults=defaults, module=modulename)
return t

def _getattr(objclass, name, repr_str):
# hack to grab the reference directly
Expand Down Expand Up @@ -1603,10 +1617,14 @@ def save_type(pickler, obj, postproc_list=None):
log.info("T1: %s" % obj)
pickler.save_reduce(_load_type, (_typemap[obj],), obj=obj)
log.info("# T1")
elif issubclass(obj, tuple) and all([hasattr(obj, attr) for attr in ('_fields','_asdict','_make','_replace')]):
elif obj.__bases__ == (tuple,) and all([hasattr(obj, attr) for attr in ('_fields','_asdict','_make','_replace')]):
# special case: namedtuples
log.info("T6: %s" % obj)
pickler.save_reduce(_create_namedtuple, (getattr(obj, "__qualname__", obj.__name__), obj._fields, obj.__module__), obj=obj)
if OLD37 or (not obj._field_defaults):
pickler.save_reduce(_create_namedtuple, (obj.__name__, obj._fields, obj.__module__), obj=obj)
else:
defaults = [obj._field_defaults[field] for field in obj._fields]
pickler.save_reduce(_create_namedtuple, (obj.__name__, obj._fields, obj.__module__, defaults), obj=obj)
log.info("# T6")
return

Expand Down Expand Up @@ -1748,16 +1766,30 @@ def save_function(pickler, obj):
postproc_list.append((dict.update, (globs, globs_copy)))

if PY3:
fkwdefaults = getattr(obj, '__kwdefaults__', None)
state_dict = {}
for fattrname in ('__kwdefaults__', '__annotations__'):
fattr = getattr(obj, fattrname, None)
if fattr is not None:
state_dict[fattrname] = fattr
if obj.__qualname__ != obj.__name__:
state_dict['__qualname__'] = obj.__qualname__

state = obj.__dict__
if type(state) is not dict:
state_dict['__dict__'] = state
state = None
if state_dict:
state = state, state_dict

_save_with_postproc(pickler, (_create_function, (
obj.__code__, globs, obj.__name__, obj.__defaults__,
obj.__closure__, obj.__dict__, fkwdefaults
)), obj=obj, postproc_list=postproc_list)
obj.__closure__
), state), obj=obj, postproc_list=postproc_list)
else:
_save_with_postproc(pickler, (_create_function, (
obj.func_code, globs, obj.func_name, obj.func_defaults,
obj.func_closure, obj.__dict__
)), obj=obj, postproc_list=postproc_list)
obj.func_closure
), obj.__dict__), obj=obj, postproc_list=postproc_list)
log.info("# F1")
else:
log.info("F2: %s" % obj)
Expand Down
7 changes: 7 additions & 0 deletions tests/test_classdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ def test_namedtuple():
assert Bad._fields == dill.loads(dill.dumps(Bad))._fields
assert tuple(Badi) == tuple(dill.loads(dill.dumps(Badi)))

class A:
class B(namedtuple("B", ["one", "two"])):
pass

a = A()
assert dill.copy(a)

def test_dtype():
try:
import numpy as np
Expand Down

0 comments on commit 9450a7b

Please sign in to comment.