diff --git a/dill/_dill.py b/dill/_dill.py index f4532c1b..3205c371 100644 --- a/dill/_dill.py +++ b/dill/_dill.py @@ -405,50 +405,87 @@ def loads(str, ignore=None, **kwds): ### End: Shorthands ### ### Pickle the Interpreter Session +SESSION_IMPORTED_AS_TYPES = (ModuleType, ClassType, TypeType, Exception, + FunctionType, MethodType, BuiltinMethodType) + def _module_map(): """get map of imported modules""" - from collections import defaultdict - modmap = defaultdict(list) + from collections import defaultdict, namedtuple + modmap = namedtuple('Modmap', ['by_name', 'by_id', 'top_level']) + modmap = modmap(defaultdict(list), defaultdict(list), {}) items = 'items' if PY3 else 'iteritems' - for name, module in getattr(sys.modules, items)(): - if module is None: + for modname, module in getattr(sys.modules, items)(): + if not isinstance(module, ModuleType): continue - for objname, obj in module.__dict__.items(): - modmap[objname].append((obj, name)) + if '.' not in modname: + modmap.top_level[id(module)] = modname + for objname, modobj in module.__dict__.items(): + modmap.by_name[objname].append((modobj, modname)) + modmap.by_id[id(modobj)].append((modobj, objname, modname)) return modmap -def _lookup_module(modmap, name, obj, main_module): #FIXME: needs work - """lookup name if module is imported""" - for modobj, modname in modmap[name]: - if modobj is obj and modname != main_module.__name__: - return modname +def _lookup_module(modmap, name, obj, main_module): + """lookup name or id of obj if module is imported""" + for modobj, modname in modmap.by_name[name]: + if modobj is obj and sys.modules[modname] is not main_module: + return modname, name + if isinstance(obj, SESSION_IMPORTED_AS_TYPES): + for modobj, objname, modname in modmap.by_id[id(obj)]: + if sys.modules[modname] is not main_module: + return modname, objname + return None, None def _stash_modules(main_module): modmap = _module_map() + newmod = ModuleType(main_module.__name__) + imported = [] + imported_as = [] + imported_top_level = [] # keep separeted for backwards compatibility original = {} items = 'items' if PY3 else 'iteritems' for name, obj in getattr(main_module.__dict__, items)(): - source_module = _lookup_module(modmap, name, obj, main_module) + if obj is main_module: + original[name] = newmod # self-reference + continue + + # Avoid incorrectly matching a singleton value in another package (ex.: __doc__). + if any(obj is singleton for singleton in (None, False, True)) or \ + isinstance(obj, ModuleType) and _is_builtin_module(obj): # always saved by ref + original[name] = obj + continue + + source_module, objname = _lookup_module(modmap, name, obj, main_module) if source_module: - imported.append((source_module, name)) + if objname == name: + imported.append((source_module, name)) + else: + imported_as.append((source_module, objname, name)) else: - original[name] = obj - if len(imported): - import types - newmod = types.ModuleType(main_module.__name__) + try: + imported_top_level.append((modmap.top_level[id(obj)], name)) + except KeyError: + original[name] = obj + + if len(original) < len(main_module.__dict__): newmod.__dict__.update(original) newmod.__dill_imported = imported + newmod.__dill_imported_as = imported_as + newmod.__dill_imported_top_level = imported_top_level return newmod else: - return original + return main_module -def _restore_modules(main_module): - if '__dill_imported' not in main_module.__dict__: - return - imports = main_module.__dict__.pop('__dill_imported') - for module, name in imports: - exec("from %s import %s" % (module, name), main_module.__dict__) +def _restore_modules(unpickler, main_module): + try: + for modname, name in main_module.__dict__.pop('__dill_imported'): + main_module.__dict__[name] = unpickler.find_class(modname, name) + for modname, objname, name in main_module.__dict__.pop('__dill_imported_as'): + main_module.__dict__[name] = unpickler.find_class(modname, objname) + for modname, name in main_module.__dict__.pop('__dill_imported_top_level'): + main_module.__dict__[name] = __import__(modname) + except KeyError: + pass #NOTE: 06/03/15 renamed main_module to main def dump_session(filename='/tmp/session.pkl', main=None, byref=False, **kwds): @@ -461,13 +498,16 @@ def dump_session(filename='/tmp/session.pkl', main=None, byref=False, **kwds): else: f = open(filename, 'wb') try: + pickler = Pickler(f, protocol, **kwds) + pickler._original_main = main if byref: main = _stash_modules(main) - pickler = Pickler(f, protocol, **kwds) pickler._main = main #FIXME: dill.settings are disabled pickler._byref = False # disable pickling by name reference pickler._recurse = False # disable pickling recursion for globals pickler._session = True # is best indicator of when pickling a session + pickler._first_pass = True + pickler._main_modified = main is not pickler._original_main pickler.dump(main) finally: if f is not filename: # If newly opened file @@ -488,7 +528,7 @@ def load_session(filename='/tmp/session.pkl', main=None, **kwds): module = unpickler.load() unpickler._session = False main.__dict__.update(module.__dict__) - _restore_modules(main) + _restore_modules(unpickler, main) finally: if f is not filename: # If newly opened file f.close() @@ -728,9 +768,9 @@ def _create_function(fcode, fglobals, fname=None, fdefaults=None, fclosure=None, fdict=None, fkwdefaults=None): # same as FunctionType, but enable passing __dict__ to new function, # __dict__ is the storehouse for attributes added after function creation - if fdict is None: fdict = dict() func = FunctionType(fcode, fglobals or dict(), fname, fdefaults, fclosure) - func.__dict__.update(fdict) #XXX: better copy? option to copy? + if fdict is not None: + func.__dict__.update(fdict) #XXX: better copy? option to copy? if fkwdefaults is not None: func.__kwdefaults__ = fkwdefaults # 'recurse' only stores referenced modules/objects in fglobals, @@ -1009,14 +1049,23 @@ def _create_dtypemeta(scalar_type): return NumpyDType return type(NumpyDType(scalar_type)) -def _create_namedtuple(name, fieldnames, modulename): - class_ = _import_module(modulename + '.' + name, safe=True) - if class_ is not None: - return class_ - import collections - t = collections.namedtuple(name, fieldnames) - t.__module__ = modulename - return t +if OLD37: + def _create_namedtuple(name, fieldnames, modulename, defaults=None): + class_ = _import_module(modulename + '.' + name, safe=True) + if class_ is not None: + return class_ + import collections + t = collections.namedtuple(name, fieldnames) + t.__module__ = modulename + return t +else: + def _create_namedtuple(name, fieldnames, modulename, defaults=None): + class_ = _import_module(modulename + '.' + name, safe=True) + if class_ is not None: + return class_ + import collections + t = collections.namedtuple(name, fieldnames, defaults=defaults, module=modulename) + return t def _getattr(objclass, name, repr_str): # hack to grab the reference directly @@ -1074,9 +1123,10 @@ def _getattribute(obj, name): .format(name, obj)) return obj, parent -def _locate_function(obj, session=False): +def _locate_function(obj, pickler=None): module_name = getattr(obj, '__module__', None) - if module_name in ['__main__', None]: # and session: + if module_name in ['__main__', None] or \ + pickler and is_dill(pickler, child=False) and pickler._session and module_name == pickler._main.__name__: return False if hasattr(obj, '__qualname__'): module = _import_module(module_name, safe=True) @@ -1089,11 +1139,15 @@ def _locate_function(obj, session=False): found = _import_module(module_name + '.' + obj.__name__, safe=True) return found is obj +def _setitems(dest, source): + for k, v in source.items(): + dest[k] = v def _setitems(dest, source): for k, v in source.items(): dest[k] = v + def _save_with_postproc(pickler, reduction, is_pickler_dill=None, obj=Getattr.NO_DEFAULT, postproc_list=None): if obj is Getattr.NO_DEFAULT: obj = Reduce(reduction) # pragma: no cover @@ -1199,7 +1253,8 @@ def save_code(pickler, obj): @register(dict) def save_module_dict(pickler, obj): - if is_dill(pickler, child=False) and obj == pickler._main.__dict__ and not pickler._session: + if is_dill(pickler, child=False) and obj == pickler._main.__dict__ and \ + not (pickler._session and pickler._first_pass): log.info("D1: = 51052711 #NULL may be prepended >= 3.11a7 names = set() with capture('stdout') as out: dis.dis(func) #XXX: dis.dis(None) disassembles last traceback for line in out.getvalue().splitlines(): if '_GLOBAL' in line: name = line.split('(')[-1].split(')')[0] - names.add(name) + if CAN_NULL: + names.add(name.replace('NULL + ', '')) + else: + names.add(name) for co in getattr(func, 'co_consts', tuple()): if co and recurse and iscode(co): names.update(nestedglobals(co, recurse=True)) diff --git a/tests/__main__.py b/tests/__main__.py index c64168d1..e82993c6 100644 --- a/tests/__main__.py +++ b/tests/__main__.py @@ -28,4 +28,3 @@ if not p: print('.', end='') print('') - diff --git a/tests/test_classdef.py b/tests/test_classdef.py index 2d3781c6..d9ec1cf4 100644 --- a/tests/test_classdef.py +++ b/tests/test_classdef.py @@ -122,6 +122,20 @@ def test_namedtuple(): assert Bad._fields == dill.loads(dill.dumps(Bad))._fields assert tuple(Badi) == tuple(dill.loads(dill.dumps(Badi))) + class A: + class B(namedtuple("C", ["one", "two"])): + '''docstring''' + B.__module__ = 'testing' + + a = A() + assert dill.copy(a) + + assert dill.copy(A.B).__name__ == 'B' + if dill._dill.PY3: + assert dill.copy(A.B).__qualname__.endswith('..A.B') + assert dill.copy(A.B).__doc__ == 'docstring' + assert dill.copy(A.B).__module__ == 'testing' + def test_dtype(): try: import numpy as np diff --git a/tests/test_functions.py b/tests/test_functions.py index c157dc7b..ec9670e2 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -28,8 +28,11 @@ def function_c(c, c1=1): def function_d(d, d1, d2=1): + """doc string""" return d + d1 + d2 +function_d.__module__ = 'a module' + if is_py3(): exec(''' @@ -63,6 +66,8 @@ def test_functions(): assert dill.loads(dumped_func_c)(1, 2) == 3 dumped_func_d = dill.dumps(function_d) + assert dill.loads(dumped_func_d).__doc__ == function_d.__doc__ + assert dill.loads(dumped_func_d).__module__ == function_d.__module__ assert dill.loads(dumped_func_d)(1, 2) == 4 assert dill.loads(dumped_func_d)(1, 2, 3) == 6 assert dill.loads(dumped_func_d)(1, 2, d2=3) == 6 diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 00000000..fd71ea05 --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python + +# Author: Leonardo Gama (@leogama) +# Copyright (c) 2022 The Uncertainty Quantification Foundation. +# License: 3-clause BSD. The full license text is available at: +# - https://github.com/uqfoundation/dill/blob/master/LICENSE + +from __future__ import print_function +import atexit, dill, os, sys, __main__ + +session_file = os.path.join(os.path.dirname(__file__), 'session-byref-%s.pkl') + +def test_modules(main, byref): + main_dict = main.__dict__ + + try: + for obj in ('json', 'url', 'local_mod', 'sax', 'dom'): + assert main_dict[obj].__name__ in sys.modules + + for obj in ('Calendar', 'isleap'): + assert main_dict[obj] is sys.modules['calendar'].__dict__[obj] + assert main.day_name.__module__ == 'calendar' + if byref: + assert main.day_name is sys.modules['calendar'].__dict__['day_name'] + + assert main.complex_log is sys.modules['cmath'].__dict__['log'] + + except AssertionError: + import traceback + error_line = traceback.format_exc().splitlines()[-2].replace('[obj]', '['+repr(obj)+']') + print("Error while testing (byref=%s):" % byref, error_line, sep="\n", file=sys.stderr) + raise + + +# Test session loading in a fresh interpreter session. +if __name__ == '__main__' and len(sys.argv) >= 3 and sys.argv[1] == '--child': + byref = sys.argv[2] == 'True' + dill.load_session(session_file % byref) + test_modules(__main__, byref) + sys.exit() + +del test_modules + + +def _clean_up_cache(module): + cached = module.__file__.split('.', 1)[0] + '.pyc' + cached = module.__cached__ if hasattr(module, '__cached__') else cached + pycache = os.path.join(os.path.dirname(module.__file__), '__pycache__') + for remove, file in [(os.remove, cached), (os.removedirs, pycache)]: + try: + remove(file) + except OSError: + pass + + +# To clean up namespace before loading the session. +original_modules = set(sys.modules.keys()) - \ + set(['json', 'urllib', 'xml.sax', 'xml.dom.minidom', 'calendar', 'cmath']) +original_objects = set(__main__.__dict__.keys()) +original_objects.add('original_objects') + + +# Create various kinds of objects to test different internal logics. + +## Modules. +import json # top-level module +import urllib as url # top-level module under alias +from xml import sax # submodule +import xml.dom.minidom as dom # submodule under alias +import test_dictviews as local_mod # non-builtin top-level module +atexit.register(_clean_up_cache, local_mod) + +## Imported objects. +from calendar import Calendar, isleap, day_name # class, function, other object +from cmath import log as complex_log # imported with alias + +## Local objects. +x = 17 +empty = None +names = ['Alice', 'Bob', 'Carol'] +def squared(x): return x**2 +cubed = lambda x: x**3 +class Person: + def __init__(self, name, age): + self.name = name + self.age = age +person = Person(names[0], x) +class CalendarSubclass(Calendar): + def weekdays(self): + return [day_name[i] for i in self.iterweekdays()] +cal = CalendarSubclass() +selfref = __main__ + + +def test_objects(main, copy_dict, byref): + main_dict = main.__dict__ + + try: + for obj in ('json', 'url', 'local_mod', 'sax', 'dom'): + assert main_dict[obj].__name__ == copy_dict[obj].__name__ + + #FIXME: In the second test call, 'calendar' is not included in + # sys.modules, independent of the value of byref. Tried to run garbage + # collection before with no luck. This block fails even with + # "import calendar" before it. Needed to restore the original modules + # with the 'copy_modules' object. (Moved to "test_session_{1,2}.py".) + + #for obj in ('Calendar', 'isleap'): + # assert main_dict[obj] is sys.modules['calendar'].__dict__[obj] + #assert main_dict['day_name'].__module__ == 'calendar' + #if byref: + # assert main_dict['day_name'] is sys.modules['calendar'].__dict__['day_name'] + + for obj in ('x', 'empty', 'names'): + assert main_dict[obj] == copy_dict[obj] + + globs = '__globals__' if dill._dill.PY3 else 'func_globals' + for obj in ['squared', 'cubed']: + assert getattr(main_dict[obj], globs) is main_dict + assert main_dict[obj](3) == copy_dict[obj](3) + + assert main.Person.__module__ == main.__name__ + assert isinstance(main.person, main.Person) + assert main.person.age == copy_dict['person'].age + + assert issubclass(main.CalendarSubclass, main.Calendar) + assert isinstance(main.cal, main.CalendarSubclass) + assert main.cal.weekdays() == copy_dict['cal'].weekdays() + + assert main.selfref is main + + except AssertionError: + import traceback + error_line = traceback.format_exc().splitlines()[-2].replace('[obj]', '['+repr(obj)+']') + print("Error while testing (byref=%s):" % byref, error_line, sep="\n", file=sys.stderr) + raise + + +if __name__ == '__main__': + + # Test dump_session() and load_session(). + for byref in (False, True): + if byref: + # Test unpickleable imported object in main. + from sys import flags + + #print(sorted(set(sys.modules.keys()) - original_modules)) + dill._test_file = dill._dill.StringIO() + try: + # For the subprocess. + dill.dump_session(session_file % byref, byref=byref) + + dill.dump_session(dill._test_file, byref=byref) + dump = dill._test_file.getvalue() + dill._test_file.close() + + import __main__ + copy_dict = __main__.__dict__.copy() + copy_modules = sys.modules.copy() + del copy_dict['dump'] + del copy_dict['__main__'] + for name in copy_dict.keys(): + if name not in original_objects: + del __main__.__dict__[name] + for module in list(sys.modules.keys()): + if module not in original_modules: + del sys.modules[module] + + dill._test_file = dill._dill.StringIO(dump) + dill.load_session(dill._test_file) + #print(sorted(set(sys.modules.keys()) - original_modules)) + + # Test session loading in a new session. + from dill.tests.__main__ import python, shell, sp + error = sp.call([python, __file__, '--child', str(byref)], shell=shell) + if error: sys.exit(error) + del python, shell, sp + + finally: + dill._test_file.close() + try: + os.remove(session_file % byref) + except OSError: + pass + + test_objects(__main__, copy_dict, byref) + __main__.__dict__.update(copy_dict) + sys.modules.update(copy_modules) + del __main__, copy_dict, copy_modules, dump + + + # This is for code coverage, tests the use case of dump_session(byref=True) + # without imported objects in the namespace. It's a contrived example because + # even dill can't be in it. + from types import ModuleType + modname = '__test_main__' + main = ModuleType(modname) + main.x = 42 + + _main = dill._dill._stash_modules(main) + if _main is not main: + print("There are objects to save by referenece that shouldn't be:", + _main.__dill_imported, _main.__dill_imported_as, _main.__dill_imported_top_level, + file=sys.stderr) + + test_file = dill._dill.StringIO() + try: + dill.dump_session(test_file, main=main, byref=True) + dump = test_file.getvalue() + test_file.close() + + sys.modules[modname] = ModuleType(modname) # empty + # This should work after fixing https://github.com/uqfoundation/dill/issues/462 + test_file = dill._dill.StringIO(dump) + dill.load_session(test_file) + finally: + test_file.close() + + assert x == 42 + + + # Dump session for module that is not __main__: + import test_classdef as module + atexit.register(_clean_up_cache, module) + module.selfref = module + dict_objects = [obj for obj in module.__dict__.keys() if not obj.startswith('__')] + + test_file = dill._dill.StringIO() + try: + dill.dump_session(test_file, main=module) + dump = test_file.getvalue() + test_file.close() + + for obj in dict_objects: + del module.__dict__[obj] + + test_file = dill._dill.StringIO(dump) + dill.load_session(test_file, main=module) + finally: + test_file.close() + + assert all(obj in module.__dict__ for obj in dict_objects) + assert module.selfref is module