Skip to content

Commit

Permalink
Support dict views and functools.lru_cache (#449)
Browse files Browse the repository at this point in the history
* Support dictionary views

* Remove some unnecessary functions and add LRU cache
  • Loading branch information
anivegesana authored Apr 19, 2022
1 parent e2831d0 commit 914d47f
Show file tree
Hide file tree
Showing 3 changed files with 226 additions and 58 deletions.
231 changes: 173 additions & 58 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def _trace(boolean):
OLDER = (PY3 and sys.hexversion < 0x3040000) or (sys.hexversion < 0x2070ab1)
OLD33 = (sys.hexversion < 0x3030000)
OLD37 = (sys.hexversion < 0x3070000)
OLD39 = (sys.hexversion < 0x3090000)
OLD310 = (sys.hexversion < 0x30a0000)
PY34 = (0x3040000 <= sys.hexversion < 0x3050000)
if PY3: #XXX: get types from .objtypes ?
import builtins as __builtin__
Expand Down Expand Up @@ -219,6 +221,14 @@ class _member(object):
ItemGetterType = type(itemgetter(0))
AttrGetterType = type(attrgetter('__repr__'))

try:
from functools import _lru_cache_wrapper as LRUCacheType
except:
LRUCacheType = None

if not isinstance(LRUCacheType, type):
LRUCacheType = None

def get_file_type(*args, **kwargs):
open = kwargs.pop("open", __builtin__.open)
f = open(os.devnull, *args, **kwargs)
Expand Down Expand Up @@ -262,6 +272,8 @@ def get_file_type(*args, **kwargs):
except NameError: ExitType = None
singletontypes = []

from collections import OrderedDict

import inspect

### Shims for different versions of Python and dill
Expand Down Expand Up @@ -914,6 +926,23 @@ def __getattribute__(self, attr):
attrs[index] = ".".join([attrs[index], attr])
return type(self)(attrs, index)

class _dictproxy_helper(dict):
def __ror__(self, a):
return a

_dictproxy_helper_instance = _dictproxy_helper()

__d = {}
try:
# In CPython 3.9 and later, this trick can be used to exploit the
# implementation of the __or__ function of MappingProxyType to get the true
# mapping referenced by the proxy. It may work for other implementations,
# but is not guaranteed.
MAPPING_PROXY_TRICK = __d is (DictProxyType(__d) | _dictproxy_helper_instance)
except:
MAPPING_PROXY_TRICK = False
del __d

# _CELL_REF and _CELL_EMPTY are used to stay compatible with versions of dill
# whose _create_cell functions do not have a default value.
# _CELL_REF can be safely removed entirely (replaced by empty tuples for calls
Expand Down Expand Up @@ -1166,6 +1195,62 @@ def save_module_dict(pickler, obj):
log.info("# D2")
return


if not OLD310 and MAPPING_PROXY_TRICK:
def save_dict_view(dicttype):
def save_dict_view_for_function(func):
def _save_dict_view(pickler, obj):
log.info("Dkvi: <%s>" % (obj,))
mapping = obj.mapping | _dictproxy_helper_instance
pickler.save_reduce(func, (mapping,), obj=obj)
log.info("# Dkvi")
return _save_dict_view
return [
(funcname, save_dict_view_for_function(getattr(dicttype, funcname)))
for funcname in ('keys', 'values', 'items')
]
else:
# The following functions are based on 'cloudpickle'
# https://github.com/cloudpipe/cloudpickle/blob/5d89947288a18029672596a4d719093cc6d5a412/cloudpickle/cloudpickle.py#L922-L940
# Copyright (c) 2012, Regents of the University of California.
# Copyright (c) 2009 `PiCloud, Inc. <http://www.picloud.com>`_.
# License: https://github.com/cloudpipe/cloudpickle/blob/master/LICENSE
def save_dict_view(dicttype):
def save_dict_keys(pickler, obj):
log.info("Dk: <%s>" % (obj,))
dict_constructor = _shims.Reduce(dicttype.fromkeys, (list(obj),))
pickler.save_reduce(dicttype.keys, (dict_constructor,), obj=obj)
log.info("# Dk")

def save_dict_values(pickler, obj):
log.info("Dv: <%s>" % (obj,))
dict_constructor = _shims.Reduce(dicttype, (enumerate(obj),))
pickler.save_reduce(dicttype.values, (dict_constructor,), obj=obj)
log.info("# Dv")

def save_dict_items(pickler, obj):
log.info("Di: <%s>" % (obj,))
pickler.save_reduce(dicttype.items, (dicttype(obj),), obj=obj)
log.info("# Di")

return (
('keys', save_dict_keys),
('values', save_dict_values),
('items', save_dict_items)
)

for __dicttype in (
dict,
OrderedDict
):
__obj = __dicttype()
for __funcname, __savefunc in save_dict_view(__dicttype):
__tview = type(getattr(__obj, __funcname)())
if __tview not in Pickler.dispatch:
Pickler.dispatch[__tview] = __savefunc
del __dicttype, __obj, __funcname, __tview, __savefunc


@register(ClassType)
def save_classobj(pickler, obj): #FIXME: enable pickler._byref
if obj.__module__ == '__main__': #XXX: use _main_module.__name__ everywhere?
Expand Down Expand Up @@ -1206,24 +1291,25 @@ def save_socket(pickler, obj):
log.info("# So")
return

@register(ItemGetterType)
def save_itemgetter(pickler, obj):
log.info("Ig: %s" % obj)
helper = _itemgetter_helper()
obj(helper)
pickler.save_reduce(type(obj), tuple(helper.items), obj=obj)
log.info("# Ig")
return
if sys.hexversion <= 0x3050000:
@register(ItemGetterType)
def save_itemgetter(pickler, obj):
log.info("Ig: %s" % obj)
helper = _itemgetter_helper()
obj(helper)
pickler.save_reduce(type(obj), tuple(helper.items), obj=obj)
log.info("# Ig")
return

@register(AttrGetterType)
def save_attrgetter(pickler, obj):
log.info("Ag: %s" % obj)
attrs = []
helper = _attrgetter_helper(attrs)
obj(helper)
pickler.save_reduce(type(obj), tuple(attrs), obj=obj)
log.info("# Ag")
return
@register(AttrGetterType)
def save_attrgetter(pickler, obj):
log.info("Ag: %s" % obj)
attrs = []
helper = _attrgetter_helper(attrs)
obj(helper)
pickler.save_reduce(type(obj), tuple(attrs), obj=obj)
log.info("# Ag")
return

def _save_file(pickler, obj, open_):
if obj.closed:
Expand Down Expand Up @@ -1303,13 +1389,33 @@ def save_stringo(pickler, obj):
log.info("# Io")
return

@register(PartialType)
def save_functor(pickler, obj):
log.info("Fu: %s" % obj)
pickler.save_reduce(_create_ftype, (type(obj), obj.func, obj.args,
obj.keywords), obj=obj)
log.info("# Fu")
return
if 0x2050000 <= sys.hexversion < 0x3010000:
@register(PartialType)
def save_functor(pickler, obj):
log.info("Fu: %s" % obj)
pickler.save_reduce(_create_ftype, (type(obj), obj.func, obj.args,
obj.keywords), obj=obj)
log.info("# Fu")
return

if LRUCacheType is not None:
from functools import lru_cache
@register(LRUCacheType)
def save_lru_cache(pickler, obj):
log.info("LRU: %s" % obj)
if OLD39:
kwargs = obj.cache_info()
args = (kwargs.maxsize,)
else:
kwargs = obj.cache_parameters()
args = (kwargs['maxsize'], kwargs['typed'])
if args != lru_cache.__defaults__:
wrapper = Reduce(lru_cache, args, is_callable=True)
else:
wrapper = lru_cache
pickler.save_reduce(wrapper, (obj.__wrapped__,), obj=obj)
log.info("# LRU")
return

@register(SuperType)
def save_super(pickler, obj):
Expand All @@ -1318,41 +1424,42 @@ def save_super(pickler, obj):
log.info("# Su")
return

@register(BuiltinMethodType)
def save_builtin_method(pickler, obj):
if obj.__self__ is not None:
if obj.__self__ is __builtin__:
module = 'builtins' if PY3 else '__builtin__'
_t = "B1"
log.info("%s: %s" % (_t, obj))
if OLDER or not PY3:
@register(BuiltinMethodType)
def save_builtin_method(pickler, obj):
if obj.__self__ is not None:
if obj.__self__ is __builtin__:
module = 'builtins' if PY3 else '__builtin__'
_t = "B1"
log.info("%s: %s" % (_t, obj))
else:
module = obj.__self__
_t = "B3"
log.info("%s: %s" % (_t, obj))
if is_dill(pickler, child=True):
_recurse = pickler._recurse
pickler._recurse = False
pickler.save_reduce(_get_attr, (module, obj.__name__), obj=obj)
if is_dill(pickler, child=True):
pickler._recurse = _recurse
log.info("# %s" % _t)
else:
module = obj.__self__
_t = "B3"
log.info("%s: %s" % (_t, obj))
if is_dill(pickler, child=True):
_recurse = pickler._recurse
pickler._recurse = False
pickler.save_reduce(_get_attr, (module, obj.__name__), obj=obj)
if is_dill(pickler, child=True):
pickler._recurse = _recurse
log.info("# %s" % _t)
else:
log.info("B2: %s" % obj)
name = getattr(obj, '__qualname__', getattr(obj, '__name__', None))
StockPickler.save_global(pickler, obj, name=name)
log.info("# B2")
return
log.info("B2: %s" % obj)
name = getattr(obj, '__qualname__', getattr(obj, '__name__', None))
StockPickler.save_global(pickler, obj, name=name)
log.info("# B2")
return

@register(MethodType) #FIXME: fails for 'hidden' or 'name-mangled' classes
def save_instancemethod0(pickler, obj):# example: cStringIO.StringI
log.info("Me: %s" % obj) #XXX: obj.__dict__ handled elsewhere?
if PY3:
pickler.save_reduce(MethodType, (obj.__func__, obj.__self__), obj=obj)
else:
pickler.save_reduce(MethodType, (obj.im_func, obj.im_self,
obj.im_class), obj=obj)
log.info("# Me")
return
@register(MethodType) #FIXME: fails for 'hidden' or 'name-mangled' classes
def save_instancemethod0(pickler, obj):# example: cStringIO.StringI
log.info("Me: %s" % obj) #XXX: obj.__dict__ handled elsewhere?
if PY3:
pickler.save_reduce(MethodType, (obj.__func__, obj.__self__), obj=obj)
else:
pickler.save_reduce(MethodType, (obj.im_func, obj.im_self,
obj.im_class), obj=obj)
log.info("# Me")
return

if sys.hexversion >= 0x20500f0:
if not IS_PYPY:
Expand Down Expand Up @@ -1440,7 +1547,15 @@ def save_cell(pickler, obj):
log.info("# Ce1")
return

if not IS_PYPY:
if MAPPING_PROXY_TRICK:
@register(DictProxyType)
def save_dictproxy(pickler, obj):
log.info("Mp: %s" % obj)
mapping = obj | _dictproxy_helper_instance
pickler.save_reduce(DictProxyType, (mapping,), obj=obj)
log.info("# Mp")
return
elif not IS_PYPY:
if not OLD33:
@register(DictProxyType)
def save_dictproxy(pickler, obj):
Expand Down
35 changes: 35 additions & 0 deletions tests/test_dictviews.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/usr/bin/env python
#
# Author: Mike McKerns (mmckerns @caltech and @uqfoundation)
# Copyright (c) 2008-2016 California Institute of Technology.
# Copyright (c) 2016-2021 The Uncertainty Quantification Foundation.
# License: 3-clause BSD. The full license text is available at:
# - https://github.com/uqfoundation/dill/blob/master/LICENSE

import dill
from dill._dill import OLD310, MAPPING_PROXY_TRICK

def test_dictviews():
x = {'a': 1}
assert dill.copy(x.keys())
assert dill.copy(x.values())
assert dill.copy(x.items())

def test_dictproxy_trick():
if not OLD310 and MAPPING_PROXY_TRICK:
x = {'a': 1}
all_views = (x.values(), x.items(), x.keys(), x)
seperate_views = dill.copy(all_views)
new_x = seperate_views[-1]
new_x['b'] = 2
new_x['c'] = 1
assert len(new_x) == 3 and len(x) == 1
assert len(seperate_views[0]) == 3 and len(all_views[0]) == 1
assert len(seperate_views[1]) == 3 and len(all_views[1]) == 1
assert len(seperate_views[2]) == 3 and len(all_views[2]) == 1
assert dict(all_views[1]) == x
assert dict(seperate_views[1]) == new_x

if __name__ == '__main__':
test_dictviews()
test_dictproxy_trick()
18 changes: 18 additions & 0 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# License: 3-clause BSD. The full license text is available at:
# - https://github.com/uqfoundation/dill/blob/master/LICENSE

import functools
import dill
import sys
dill.settings['recurse'] = True
Expand Down Expand Up @@ -35,6 +36,14 @@ def function_d(d, d1, d2=1):
def function_e(e, *e1, e2=1, e3=2):
return e + sum(e1) + e2 + e3''')

globalvar = 0

@functools.lru_cache(None)
def function_with_cache(x):
global globalvar
globalvar += x
return globalvar


def function_with_unassigned_variable():
if False:
Expand All @@ -58,6 +67,15 @@ def test_functions():
assert dill.loads(dumped_func_d)(1, 2, 3) == 6
assert dill.loads(dumped_func_d)(1, 2, d2=3) == 6

if is_py3():
function_with_cache(1)
globalvar = 0
dumped_func_cache = dill.dumps(function_with_cache)
assert function_with_cache(2) == 3
assert function_with_cache(1) == 1
assert function_with_cache(3) == 6
assert function_with_cache(2) == 3

empty_cell = function_with_unassigned_variable()
cell_copy = dill.loads(dill.dumps(empty_cell))
assert 'empty' in str(cell_copy.__closure__[0])
Expand Down

0 comments on commit 914d47f

Please sign in to comment.