From e2831d04ae64d6dc8686a3c5ed1eb703dc748d43 Mon Sep 17 00:00:00 2001 From: Anirudh Vegesana Date: Tue, 19 Apr 2022 04:28:12 -0700 Subject: [PATCH] Lift closure cell update to earliest function (#461) --- dill/_dill.py | 32 +++++++++++++++++++++++++++----- tests/test_recursive.py | 14 ++++++++++++++ 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/dill/_dill.py b/dill/_dill.py index 6e78b831..ab05672d 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -518,7 +518,8 @@ def __init__(self, *args, **kwds): self._strictio = False #_strictio self._fmode = settings['fmode'] if _fmode is None else _fmode self._recurse = settings['recurse'] if _recurse is None else _recurse - self._postproc = {} + from collections import OrderedDict + self._postproc = OrderedDict() def dump(self, obj): #NOTE: if settings change, need to update attributes # register if the object is a numpy ufunc @@ -1424,14 +1425,14 @@ def save_cell(pickler, obj): log.info("# Ce3") return if is_dill(pickler, child=True): - postproc = pickler._postproc.get(id(f)) + postproc = next(iter(pickler._postproc.values()), None) if postproc is not None: log.info("Ce2: %s" % obj) # _CELL_REF is defined in _shims.py to support older versions of # dill. When breaking changes are made to dill, (_CELL_REF,) can # be replaced by () - postproc.append((_shims._setattr, (obj, 'cell_contents', f))) pickler.save_reduce(_create_cell, (_CELL_REF,), obj=obj) + postproc.append((_shims._setattr, (obj, 'cell_contents', f))) log.info("# Ce2") return log.info("Ce1: %s" % obj) @@ -1748,16 +1749,37 @@ def save_function(pickler, obj): postproc_list.append((dict.update, (globs, globs_copy))) if PY3: + closure = obj.__closure__ fkwdefaults = getattr(obj, '__kwdefaults__', None) _save_with_postproc(pickler, (_create_function, ( obj.__code__, globs, obj.__name__, obj.__defaults__, - obj.__closure__, obj.__dict__, fkwdefaults + closure, obj.__dict__, fkwdefaults )), obj=obj, postproc_list=postproc_list) else: + closure = obj.func_closure _save_with_postproc(pickler, (_create_function, ( obj.func_code, globs, obj.func_name, obj.func_defaults, - obj.func_closure, obj.__dict__ + closure, obj.__dict__ )), obj=obj, postproc_list=postproc_list) + + # Lift closure cell update to earliest function (#458) + topmost_postproc = next(iter(pickler._postproc.values()), None) + if closure and topmost_postproc: + for cell in closure: + possible_postproc = (setattr, (cell, 'cell_contents', obj)) + try: + topmost_postproc.remove(possible_postproc) + except ValueError: + continue + + # Change the value of the cell + pickler.save_reduce(*possible_postproc) + # pop None created by calling preprocessing step off stack + if PY3: + pickler.write(bytes('0', 'UTF-8')) + else: + pickler.write('0') + log.info("# F1") else: log.info("F2: %s" % obj) diff --git a/tests/test_recursive.py b/tests/test_recursive.py index ef779fe8..ee71a688 100644 --- a/tests/test_recursive.py +++ b/tests/test_recursive.py @@ -154,6 +154,19 @@ def test_recursive_function(): fib = fib4 +def collection_function_recursion(): + d = {} + def g(): + return d + d['g'] = g + return g + + +def test_collection_function_recursion(): + g = copy(collection_function_recursion()) + assert g()['g'] is g + + if __name__ == '__main__': with warnings.catch_warnings(): warnings.simplefilter('error') @@ -163,3 +176,4 @@ def test_recursive_function(): test_circular_reference() test_function_cells() test_recursive_function() + test_collection_function_recursion()