From 5bd56a8e0096c3f45f4c37ff0b67ac6aa341f22b Mon Sep 17 00:00:00 2001 From: Anirudh Vegesana Date: Thu, 21 Apr 2022 16:55:43 -0700 Subject: [PATCH 1/5] Pickle inner `collections.namedtuple`s and function attributes (#448) * fix #288 nested namedtuples * Remove special case for PyPy 2.7 that doesn't exist __kwdefaults__ and __annotations__ are invalid in PyPy2.7 * Bug fix for __qualname__ on classes * Fix bug if _postproc not present and use _setitems --- dill/_dill.py | 118 ++++++++++++++++++++++++++++------------ tests/test_classdef.py | 16 +++++- tests/test_functions.py | 5 ++ 3 files changed, 102 insertions(+), 37 deletions(-) diff --git a/dill/_dill.py b/dill/_dill.py index bcd47159..cddc9cf5 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -720,9 +720,9 @@ def _create_function(fcode, fglobals, fname=None, fdefaults=None, fclosure=None, fdict=None, fkwdefaults=None): # same as FunctionType, but enable passing __dict__ to new function, # __dict__ is the storehouse for attributes added after function creation - if fdict is None: fdict = dict() func = FunctionType(fcode, fglobals or dict(), fname, fdefaults, fclosure) - func.__dict__.update(fdict) #XXX: better copy? option to copy? + if fdict is not None: + func.__dict__.update(fdict) #XXX: better copy? option to copy? if fkwdefaults is not None: func.__kwdefaults__ = fkwdefaults # 'recurse' only stores referenced modules/objects in fglobals, @@ -1001,14 +1001,23 @@ def _create_dtypemeta(scalar_type): return NumpyDType return type(NumpyDType(scalar_type)) -def _create_namedtuple(name, fieldnames, modulename): - class_ = _import_module(modulename + '.' + name, safe=True) - if class_ is not None: - return class_ - import collections - t = collections.namedtuple(name, fieldnames) - t.__module__ = modulename - return t +if OLD37: + def _create_namedtuple(name, fieldnames, modulename, defaults=None): + class_ = _import_module(modulename + '.' 
+ name, safe=True) + if class_ is not None: + return class_ + import collections + t = collections.namedtuple(name, fieldnames) + t.__module__ = modulename + return t +else: + def _create_namedtuple(name, fieldnames, modulename, defaults=None): + class_ = _import_module(modulename + '.' + name, safe=True) + if class_ is not None: + return class_ + import collections + t = collections.namedtuple(name, fieldnames, defaults=defaults, module=modulename) + return t def _getattr(objclass, name, repr_str): # hack to grab the reference directly @@ -1058,6 +1067,11 @@ def _locate_function(obj, session=False): return found is obj +def _setitems(dest, source): + for k, v in source.items(): + dest[k] = v + + def _save_with_postproc(pickler, reduction, is_pickler_dill=None, obj=Getattr.NO_DEFAULT, postproc_list=None): if obj is Getattr.NO_DEFAULT: obj = Reduce(reduction) # pragma: no cover @@ -1089,7 +1103,7 @@ def _save_with_postproc(pickler, reduction, is_pickler_dill=None, obj=Getattr.NO postproc = pickler._postproc.pop(id(obj)) # assert postproc_list == postproc, 'Stack tampered!' 
for reduction in reversed(postproc): - if reduction[0] is dict.update and type(reduction[1][0]) is dict: + if reduction[0] is _setitems: # use the internal machinery of pickle.py to speedup when # updating a dictionary in postproc dest, source = reduction[1] @@ -1719,10 +1733,14 @@ def save_type(pickler, obj, postproc_list=None): log.info("T1: %s" % obj) pickler.save_reduce(_load_type, (_typemap[obj],), obj=obj) log.info("# T1") - elif issubclass(obj, tuple) and all([hasattr(obj, attr) for attr in ('_fields','_asdict','_make','_replace')]): + elif obj.__bases__ == (tuple,) and all([hasattr(obj, attr) for attr in ('_fields','_asdict','_make','_replace')]): # special case: namedtuples log.info("T6: %s" % obj) - pickler.save_reduce(_create_namedtuple, (getattr(obj, "__qualname__", obj.__name__), obj._fields, obj.__module__), obj=obj) + if OLD37 or (not obj._field_defaults): + pickler.save_reduce(_create_namedtuple, (obj.__name__, obj._fields, obj.__module__), obj=obj) + else: + defaults = [obj._field_defaults[field] for field in obj._fields] + pickler.save_reduce(_create_namedtuple, (obj.__name__, obj._fields, obj.__module__, defaults), obj=obj) log.info("# T6") return @@ -1764,8 +1782,12 @@ def save_type(pickler, obj, postproc_list=None): #print ("%s\n%s" % (obj.__bases__, obj.__dict__)) for name in _dict.get("__slots__", []): del _dict[name] + if PY3 and obj_name != obj.__name__: + if postproc_list is None: + postproc_list = [] + postproc_list.append((setattr, (obj, '__qualname__', obj_name))) _save_with_postproc(pickler, (_create_type, ( - type(obj), obj_name, obj.__bases__, _dict + type(obj), obj.__name__, obj.__bases__, _dict )), obj=obj, postproc_list=postproc_list) log.info("# %s" % _t) else: @@ -1858,42 +1880,66 @@ def save_function(pickler, obj): glob_ids = {id(g) for g in globs_copy.itervalues()} for stack_element in _postproc: if stack_element in glob_ids: - _postproc[stack_element].append((dict.update, (globs, globs_copy))) + 
_postproc[stack_element].append((_setitems, (globs, globs_copy))) break else: - postproc_list.append((dict.update, (globs, globs_copy))) + postproc_list.append((_setitems, (globs, globs_copy))) if PY3: closure = obj.__closure__ - fkwdefaults = getattr(obj, '__kwdefaults__', None) + state_dict = {} + for fattrname in ('__doc__', '__kwdefaults__', '__annotations__'): + fattr = getattr(obj, fattrname, None) + if fattr is not None: + state_dict[fattrname] = fattr + if obj.__qualname__ != obj.__name__: + state_dict['__qualname__'] = obj.__qualname__ + if '__name__' not in globs or obj.__module__ != globs['__name__']: + state_dict['__module__'] = obj.__module__ + + state = obj.__dict__ + if type(state) is not dict: + state_dict['__dict__'] = state + state = None + if state_dict: + state = state, state_dict + _save_with_postproc(pickler, (_create_function, ( obj.__code__, globs, obj.__name__, obj.__defaults__, - closure, obj.__dict__, fkwdefaults - )), obj=obj, postproc_list=postproc_list) + closure + ), state), obj=obj, postproc_list=postproc_list) else: closure = obj.func_closure + if obj.__doc__ is not None: + postproc_list.append((setattr, (obj, '__doc__', obj.__doc__))) + if '__name__' not in globs or obj.__module__ != globs['__name__']: + postproc_list.append((setattr, (obj, '__module__', obj.__module__))) + if obj.__dict__: + postproc_list.append((setattr, (obj, '__dict__', obj.__dict__))) + _save_with_postproc(pickler, (_create_function, ( obj.func_code, globs, obj.func_name, obj.func_defaults, - closure, obj.__dict__ + closure )), obj=obj, postproc_list=postproc_list) # Lift closure cell update to earliest function (#458) - topmost_postproc = next(iter(pickler._postproc.values()), None) - if closure and topmost_postproc: - for cell in closure: - possible_postproc = (setattr, (cell, 'cell_contents', obj)) - try: - topmost_postproc.remove(possible_postproc) - except ValueError: - continue - - # Change the value of the cell - pickler.save_reduce(*possible_postproc) 
- # pop None created by calling preprocessing step off stack - if PY3: - pickler.write(bytes('0', 'UTF-8')) - else: - pickler.write('0') + if _postproc: + topmost_postproc = next(iter(_postproc.values()), None) + if closure and topmost_postproc: + for cell in closure: + possible_postproc = (setattr, (cell, 'cell_contents', obj)) + try: + topmost_postproc.remove(possible_postproc) + except ValueError: + continue + + # Change the value of the cell + pickler.save_reduce(*possible_postproc) + # pop None created by calling preprocessing step off stack + if PY3: + pickler.write(bytes('0', 'UTF-8')) + else: + pickler.write('0') log.info("# F1") else: diff --git a/tests/test_classdef.py b/tests/test_classdef.py index d59f6f66..ca6ab6ad 100644 --- a/tests/test_classdef.py +++ b/tests/test_classdef.py @@ -114,6 +114,20 @@ def test_namedtuple(): assert Bad._fields == dill.loads(dill.dumps(Bad))._fields assert tuple(Badi) == tuple(dill.loads(dill.dumps(Badi))) + class A: + class B(namedtuple("B", ["one", "two"])): + '''docstring''' + B.__module__ = 'testing' + + a = A() + assert dill.copy(a) + + assert dill.copy(A.B).__name__ == 'B' + if dill._dill.PY3: + assert dill.copy(A.B).__qualname__.endswith('..A.B') + assert dill.copy(A.B).__doc__ == 'docstring' + assert dill.copy(A.B).__module__ == 'testing' + def test_dtype(): try: import numpy as np @@ -127,7 +141,7 @@ def test_dtype(): def test_array_nested(): try: import numpy as np - + x = np.array([1]) y = (x,) dill.dumps(x) diff --git a/tests/test_functions.py b/tests/test_functions.py index c157dc7b..ec9670e2 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -28,8 +28,11 @@ def function_c(c, c1=1): def function_d(d, d1, d2=1): + """doc string""" return d + d1 + d2 +function_d.__module__ = 'a module' + if is_py3(): exec(''' @@ -63,6 +66,8 @@ def test_functions(): assert dill.loads(dumped_func_c)(1, 2) == 3 dumped_func_d = dill.dumps(function_d) + assert dill.loads(dumped_func_d).__doc__ == function_d.__doc__ 
+ assert dill.loads(dumped_func_d).__module__ == function_d.__module__ assert dill.loads(dumped_func_d)(1, 2) == 4 assert dill.loads(dumped_func_d)(1, 2, 3) == 6 assert dill.loads(dumped_func_d)(1, 2, d2=3) == 6 From 23c47455da62d4cb8582d8f98f1de9fc6e0971ad Mon Sep 17 00:00:00 2001 From: Leonardo Gama Date: Sat, 30 Apr 2022 21:24:03 -0300 Subject: [PATCH 2/5] Fixes some bugs when using `dump_session()` with `byref=True` (#463) * fix `dump_session(byref=True)` bug when no objects are imported from modules When no objects are found to be imported from external modules, `_stash_modules()` should return `main_module` unmodified, and not the pair list of objects created from it (`original`). * fix `dump_session(byref=True)` bug when the `multiprocessing` module was imported Just calling `import multiprocessing` creates a `'__mp_main__'` entry in `sys.modules` that is simply a reference to the `__main__` module. In the old version of the test for module membership, objects in the global scope are wrongly attributed to this virtual `__mp_main__` module. And therefore `load_session()` fails. Testing for object identity instead of module name resolves the issue. * Save objects imported with an alias and top level modules by reference in `dump_session(byref=True)` Currently, `dump_session(byref=True)` misses some imported objects. For example: - If the session had a statement `import numpy as np`, it may find a reference to the `numpy` named as `np` in some internal module listed in `sys.modules`. But if the module was imported with a non-canonical name, like `import numpy as nump`, it won't find it at all. Mapping the objects by id in `modmap` solves the issue. Note that just types of objects usually imported under an alias must be looked up by id, otherwise common objects like singletons may be wrongly attributed to a module, and such reference in the module could change to a different object depending on its initialization and state.
- If an object in the global scope is a top level module, like `math`, again `save_session` may find a reference to it in another module and it works. But if this module isn't referenced anywhere else, it won't be found because the function only looks for objects inside the `sys.modules` modules and not for the modules themselves. This commit introduces two new attributes to session modules saved by reference: - `__dill_imported_as`: a list with (module name, object name, object alias in session) - `__dill_imported_top_level`: a list with (module name, module alias in session) I did it this way for forwards (complete) and backwards (partial) compatibility. Oh, and I got rid of that nasty `exec()` call in `_restore_modules()`! ;) * Deal with top level functions with `dump_session()` Fixes RecursionWarning error where a function defined in the top level of the module being saved with `dump_session(byref=True)`, of which "globals" is a reference to the original module's `__dict__`, makes dill try to save this instead of the modified module's `__dict__`, triggering recursion. * Added tests for load_session() and dump_session() * fix singleton comparison, must be by identity, not by equality * split tests to different files to better test session use cases * Fix error Py2.7 and Py3.7 where there is a tuple in sys.modules for some reason * dump_session(): extra test for code coverage * dump_session and load_session: some minor improvements * dump_session(): more tests * dump_session(): dump modules other than __main__ by reference * dump_session(): minor code coverage investigation * dump_session() tests: adjustments * dump_session() tests: fix copyright notice * dump_session() tests: merge test files using subprocess to test loading in a new session * tests: Revert change.
Test files are independent, should run in any order * dump_sessio() tests: use an unpickleable object available in PyPy --- dill/_dill.py | 142 ++++++++++++++++-------- tests/__main__.py | 1 - tests/test_session.py | 243 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 338 insertions(+), 48 deletions(-) create mode 100644 tests/test_session.py diff --git a/dill/_dill.py b/dill/_dill.py index cddc9cf5..40831e66 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -397,50 +397,87 @@ def loads(str, ignore=None, **kwds): ### End: Shorthands ### ### Pickle the Interpreter Session +SESSION_IMPORTED_AS_TYPES = (ModuleType, ClassType, TypeType, Exception, + FunctionType, MethodType, BuiltinMethodType) + def _module_map(): """get map of imported modules""" - from collections import defaultdict - modmap = defaultdict(list) + from collections import defaultdict, namedtuple + modmap = namedtuple('Modmap', ['by_name', 'by_id', 'top_level']) + modmap = modmap(defaultdict(list), defaultdict(list), {}) items = 'items' if PY3 else 'iteritems' - for name, module in getattr(sys.modules, items)(): - if module is None: + for modname, module in getattr(sys.modules, items)(): + if not isinstance(module, ModuleType): continue - for objname, obj in module.__dict__.items(): - modmap[objname].append((obj, name)) + if '.' 
not in modname: + modmap.top_level[id(module)] = modname + for objname, modobj in module.__dict__.items(): + modmap.by_name[objname].append((modobj, modname)) + modmap.by_id[id(modobj)].append((modobj, objname, modname)) return modmap -def _lookup_module(modmap, name, obj, main_module): #FIXME: needs work - """lookup name if module is imported""" - for modobj, modname in modmap[name]: - if modobj is obj and modname != main_module.__name__: - return modname +def _lookup_module(modmap, name, obj, main_module): + """lookup name or id of obj if module is imported""" + for modobj, modname in modmap.by_name[name]: + if modobj is obj and sys.modules[modname] is not main_module: + return modname, name + if isinstance(obj, SESSION_IMPORTED_AS_TYPES): + for modobj, objname, modname in modmap.by_id[id(obj)]: + if sys.modules[modname] is not main_module: + return modname, objname + return None, None def _stash_modules(main_module): modmap = _module_map() + newmod = ModuleType(main_module.__name__) + imported = [] + imported_as = [] + imported_top_level = [] # keep separeted for backwards compatibility original = {} items = 'items' if PY3 else 'iteritems' for name, obj in getattr(main_module.__dict__, items)(): - source_module = _lookup_module(modmap, name, obj, main_module) + if obj is main_module: + original[name] = newmod # self-reference + continue + + # Avoid incorrectly matching a singleton value in another package (ex.: __doc__). 
+ if any(obj is singleton for singleton in (None, False, True)) or \ + isinstance(obj, ModuleType) and _is_builtin_module(obj): # always saved by ref + original[name] = obj + continue + + source_module, objname = _lookup_module(modmap, name, obj, main_module) if source_module: - imported.append((source_module, name)) + if objname == name: + imported.append((source_module, name)) + else: + imported_as.append((source_module, objname, name)) else: - original[name] = obj - if len(imported): - import types - newmod = types.ModuleType(main_module.__name__) + try: + imported_top_level.append((modmap.top_level[id(obj)], name)) + except KeyError: + original[name] = obj + + if len(original) < len(main_module.__dict__): newmod.__dict__.update(original) newmod.__dill_imported = imported + newmod.__dill_imported_as = imported_as + newmod.__dill_imported_top_level = imported_top_level return newmod else: - return original + return main_module -def _restore_modules(main_module): - if '__dill_imported' not in main_module.__dict__: - return - imports = main_module.__dict__.pop('__dill_imported') - for module, name in imports: - exec("from %s import %s" % (module, name), main_module.__dict__) +def _restore_modules(unpickler, main_module): + try: + for modname, name in main_module.__dict__.pop('__dill_imported'): + main_module.__dict__[name] = unpickler.find_class(modname, name) + for modname, objname, name in main_module.__dict__.pop('__dill_imported_as'): + main_module.__dict__[name] = unpickler.find_class(modname, objname) + for modname, name in main_module.__dict__.pop('__dill_imported_top_level'): + main_module.__dict__[name] = __import__(modname) + except KeyError: + pass #NOTE: 06/03/15 renamed main_module to main def dump_session(filename='/tmp/session.pkl', main=None, byref=False, **kwds): @@ -453,13 +490,16 @@ def dump_session(filename='/tmp/session.pkl', main=None, byref=False, **kwds): else: f = open(filename, 'wb') try: + pickler = Pickler(f, protocol, **kwds) + 
pickler._original_main = main if byref: main = _stash_modules(main) - pickler = Pickler(f, protocol, **kwds) pickler._main = main #FIXME: dill.settings are disabled pickler._byref = False # disable pickling by name reference pickler._recurse = False # disable pickling recursion for globals pickler._session = True # is best indicator of when pickling a session + pickler._first_pass = True + pickler._main_modified = main is not pickler._original_main pickler.dump(main) finally: if f is not filename: # If newly opened file @@ -480,7 +520,7 @@ def load_session(filename='/tmp/session.pkl', main=None, **kwds): module = unpickler.load() unpickler._session = False main.__dict__.update(module.__dict__) - _restore_modules(main) + _restore_modules(unpickler, main) finally: if f is not filename: # If newly opened file f.close() @@ -1060,9 +1100,11 @@ def _import_module(import_name, safe=False): return None raise -def _locate_function(obj, session=False): - if obj.__module__ in ['__main__', None]: # and session: +def _locate_function(obj, pickler=None): + if obj.__module__ in ['__main__', None] or \ + pickler and pickler._session and obj.__module__ == pickler._main.__name__: return False + found = _import_module(obj.__module__ + '.' 
+ obj.__name__, safe=True) return found is obj @@ -1177,7 +1219,8 @@ def save_code(pickler, obj): @register(dict) def save_module_dict(pickler, obj): - if is_dill(pickler, child=False) and obj == pickler._main.__dict__ and not pickler._session: + if is_dill(pickler, child=False) and obj == pickler._main.__dict__ and \ + not (pickler._session and pickler._first_pass): log.info("D1: = 3 and sys.argv[1] == '--child': + byref = sys.argv[2] == 'True' + dill.load_session(session_file % byref) + test_modules(__main__, byref) + sys.exit() + +del test_modules + + +def _clean_up_cache(module): + cached = module.__file__.split('.', 1)[0] + '.pyc' + cached = module.__cached__ if hasattr(module, '__cached__') else cached + pycache = os.path.join(os.path.dirname(module.__file__), '__pycache__') + for remove, file in [(os.remove, cached), (os.removedirs, pycache)]: + try: + remove(file) + except OSError: + pass + + +# To clean up namespace before loading the session. +original_modules = set(sys.modules.keys()) - \ + set(['json', 'urllib', 'xml.sax', 'xml.dom.minidom', 'calendar', 'cmath']) +original_objects = set(__main__.__dict__.keys()) +original_objects.add('original_objects') + + +# Create various kinds of objects to test different internal logics. + +## Modules. +import json # top-level module +import urllib as url # top-level module under alias +from xml import sax # submodule +import xml.dom.minidom as dom # submodule under alias +import test_dictviews as local_mod # non-builtin top-level module +atexit.register(_clean_up_cache, local_mod) + +## Imported objects. +from calendar import Calendar, isleap, day_name # class, function, other object +from cmath import log as complex_log # imported with alias + +## Local objects. 
+x = 17 +empty = None +names = ['Alice', 'Bob', 'Carol'] +def squared(x): return x**2 +cubed = lambda x: x**3 +class Person: + def __init__(self, name, age): + self.name = name + self.age = age +person = Person(names[0], x) +class CalendarSubclass(Calendar): + def weekdays(self): + return [day_name[i] for i in self.iterweekdays()] +cal = CalendarSubclass() +selfref = __main__ + + +def test_objects(main, copy_dict, byref): + main_dict = main.__dict__ + + try: + for obj in ('json', 'url', 'local_mod', 'sax', 'dom'): + assert main_dict[obj].__name__ == copy_dict[obj].__name__ + + #FIXME: In the second test call, 'calendar' is not included in + # sys.modules, independent of the value of byref. Tried to run garbage + # collection before with no luck. This block fails even with + # "import calendar" before it. Needed to restore the original modules + # with the 'copy_modules' object. (Moved to "test_session_{1,2}.py".) + + #for obj in ('Calendar', 'isleap'): + # assert main_dict[obj] is sys.modules['calendar'].__dict__[obj] + #assert main_dict['day_name'].__module__ == 'calendar' + #if byref: + # assert main_dict['day_name'] is sys.modules['calendar'].__dict__['day_name'] + + for obj in ('x', 'empty', 'names'): + assert main_dict[obj] == copy_dict[obj] + + globs = '__globals__' if dill._dill.PY3 else 'func_globals' + for obj in ['squared', 'cubed']: + assert getattr(main_dict[obj], globs) is main_dict + assert main_dict[obj](3) == copy_dict[obj](3) + + assert main.Person.__module__ == main.__name__ + assert isinstance(main.person, main.Person) + assert main.person.age == copy_dict['person'].age + + assert issubclass(main.CalendarSubclass, main.Calendar) + assert isinstance(main.cal, main.CalendarSubclass) + assert main.cal.weekdays() == copy_dict['cal'].weekdays() + + assert main.selfref is main + + except AssertionError: + import traceback + error_line = traceback.format_exc().splitlines()[-2].replace('[obj]', '['+repr(obj)+']') + print("Error while testing (byref=%s):" 
% byref, error_line, sep="\n", file=sys.stderr) + raise + + +if __name__ == '__main__': + + # Test dump_session() and load_session(). + for byref in (False, True): + if byref: + # Test unpickleable imported object in main. + from sys import flags + + #print(sorted(set(sys.modules.keys()) - original_modules)) + dill._test_file = dill._dill.StringIO() + try: + # For the subprocess. + dill.dump_session(session_file % byref, byref=byref) + + dill.dump_session(dill._test_file, byref=byref) + dump = dill._test_file.getvalue() + dill._test_file.close() + + import __main__ + copy_dict = __main__.__dict__.copy() + copy_modules = sys.modules.copy() + del copy_dict['dump'] + del copy_dict['__main__'] + for name in copy_dict.keys(): + if name not in original_objects: + del __main__.__dict__[name] + for module in list(sys.modules.keys()): + if module not in original_modules: + del sys.modules[module] + + dill._test_file = dill._dill.StringIO(dump) + dill.load_session(dill._test_file) + #print(sorted(set(sys.modules.keys()) - original_modules)) + + # Test session loading in a new session. + from dill.tests.__main__ import python, shell, sp + error = sp.call([python, __file__, '--child', str(byref)], shell=shell) + if error: sys.exit(error) + del python, shell, sp + + finally: + dill._test_file.close() + try: + os.remove(session_file % byref) + except OSError: + pass + + test_objects(__main__, copy_dict, byref) + __main__.__dict__.update(copy_dict) + sys.modules.update(copy_modules) + del __main__, copy_dict, copy_modules, dump + + + # This is for code coverage, tests the use case of dump_session(byref=True) + # without imported objects in the namespace. It's a contrived example because + # even dill can't be in it. 
+ from types import ModuleType + modname = '__test_main__' + main = ModuleType(modname) + main.x = 42 + + _main = dill._dill._stash_modules(main) + if _main is not main: + print("There are objects to save by referenece that shouldn't be:", + _main.__dill_imported, _main.__dill_imported_as, _main.__dill_imported_top_level, + file=sys.stderr) + + test_file = dill._dill.StringIO() + try: + dill.dump_session(test_file, main=main, byref=True) + dump = test_file.getvalue() + test_file.close() + + sys.modules[modname] = ModuleType(modname) # empty + # This should work after fixing https://github.com/uqfoundation/dill/issues/462 + test_file = dill._dill.StringIO(dump) + dill.load_session(test_file) + finally: + test_file.close() + + assert x == 42 + + + # Dump session for module that is not __main__: + import test_classdef as module + atexit.register(_clean_up_cache, module) + module.selfref = module + dict_objects = [obj for obj in module.__dict__.keys() if not obj.startswith('__')] + + test_file = dill._dill.StringIO() + try: + dill.dump_session(test_file, main=module) + dump = test_file.getvalue() + test_file.close() + + for obj in dict_objects: + del module.__dict__[obj] + + test_file = dill._dill.StringIO(dump) + dill.load_session(test_file, main=module) + finally: + test_file.close() + + assert all(obj in module.__dict__ for obj in dict_objects) + assert module.selfref is module From dc3471067815faee143c93dba83b6052543a700c Mon Sep 17 00:00:00 2001 From: mmckerns Date: Sat, 30 Apr 2022 20:48:53 -0400 Subject: [PATCH 3/5] fixes #467: remove NULL in nestedglobals --- dill/detect.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dill/detect.py b/dill/detect.py index 2ceccbe9..d28b22b9 100644 --- a/dill/detect.py +++ b/dill/detect.py @@ -173,14 +173,19 @@ def nestedglobals(func, recurse=True): """get the names of any globals found within func""" func = code(func) if func is None: return list() + import sys from .temp import capture + IS_311a7 = 
sys.hexversion == 51052711 #FIXME: for odd behavior in 3.11a7 names = set() with capture('stdout') as out: dis.dis(func) #XXX: dis.dis(None) disassembles last traceback for line in out.getvalue().splitlines(): if '_GLOBAL' in line: name = line.split('(')[-1].split(')')[0] - names.add(name) + if IS_311a7: + names.add(name.replace('NULL + ', '')) + else: + names.add(name) for co in getattr(func, 'co_consts', tuple()): if co and recurse and iscode(co): names.update(nestedglobals(co, recurse=True)) From 6cafbf10f1b22a4926df2a6701afe12475ed070e Mon Sep 17 00:00:00 2001 From: mmckerns Date: Mon, 2 May 2022 10:57:45 -0400 Subject: [PATCH 4/5] use CAN_NULL for 3.11a7 and above --- dill/detect.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dill/detect.py b/dill/detect.py index d28b22b9..41575205 100644 --- a/dill/detect.py +++ b/dill/detect.py @@ -175,14 +175,14 @@ def nestedglobals(func, recurse=True): if func is None: return list() import sys from .temp import capture - IS_311a7 = sys.hexversion == 51052711 #FIXME: for odd behavior in 3.11a7 + CAN_NULL = sys.hexversion >= 51052711 #NULL may be prepended >= 3.11a7 names = set() with capture('stdout') as out: dis.dis(func) #XXX: dis.dis(None) disassembles last traceback for line in out.getvalue().splitlines(): if '_GLOBAL' in line: name = line.split('(')[-1].split(')')[0] - if IS_311a7: + if CAN_NULL: names.add(name.replace('NULL + ', '')) else: names.add(name) From df6ab368bde5344e6ce337ae417eff6c6aab3262 Mon Sep 17 00:00:00 2001 From: mmckerns Date: Thu, 5 May 2022 20:55:22 -0400 Subject: [PATCH 5/5] fix missing is_dill guards from #463 --- dill/_dill.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/dill/_dill.py b/dill/_dill.py index 40831e66..40c01fc5 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -1102,7 +1102,7 @@ def _import_module(import_name, safe=False): def _locate_function(obj, pickler=None): if obj.__module__ in ['__main__', None] or \ - pickler and 
pickler._session and obj.__module__ == pickler._main.__name__: + pickler and is_dill(pickler, child=False) and pickler._session and obj.__module__ == pickler._main.__name__: return False found = _import_module(obj.__module__ + '.' + obj.__name__, safe=True) @@ -1893,6 +1893,7 @@ def save_function(pickler, obj): _byref = getattr(pickler, '_byref', None) _postproc = getattr(pickler, '_postproc', None) _main_modified = getattr(pickler, '_main_modified', None) + _original_main = getattr(pickler, '_original_main', __builtin__)#'None' postproc_list = [] if _recurse: # recurse to get all globals referred to by obj @@ -1907,10 +1908,11 @@ def save_function(pickler, obj): else: globs_copy = obj.__globals__ if PY3 else obj.func_globals - # If the globals is the __dict__ from the module being save as a + # If the globals is the __dict__ from the module being saved as a # session, substitute it by the dictionary being actually saved. - if _main_modified and globs_copy is pickler._original_main.__dict__: - globs = globs_copy = pickler._main.__dict__ + if _main_modified and globs_copy is _original_main.__dict__: + globs_copy = getattr(pickler, '_main', _original_main).__dict__ + globs = globs_copy # If the globals is a module __dict__, do not save it in the pickle. elif globs_copy is not None and obj.__module__ is not None and \ getattr(_import_module(obj.__module__, True), '__dict__', None) is globs_copy: