diff --git a/six.py b/six.py index d0aece89f..7e1e64240 100644 --- a/six.py +++ b/six.py @@ -597,6 +597,8 @@ def iterlists(d, **kw): viewitems = operator.methodcaller("items") else: + import collections as _collections + def iterkeys(d, **kw): return d.iterkeys(**kw) @@ -609,11 +611,27 @@ def iteritems(d, **kw): def iterlists(d, **kw): return d.iterlists(**kw) - viewkeys = operator.methodcaller("viewkeys") - - viewvalues = operator.methodcaller("viewvalues") + def viewkeys(d): + return ( + _collections.KeysView(d) + if isinstance(d, types.DictProxyType) + else d.viewkeys() + ) + + def viewvalues(d): + return ( + _collections.ValuesView(d) + if isinstance(d, types.DictProxyType) + else d.viewvalues() + ) + + def viewitems(d): + return ( + _collections.ItemsView(d) + if isinstance(d, types.DictProxyType) + else d.viewitems() + ) - viewitems = operator.methodcaller("viewitems") _add_doc(iterkeys, "Return an iterator over the keys of a dictionary.") _add_doc(itervalues, "Return an iterator over the values of a dictionary.") diff --git a/test_six.py b/test_six.py index 0b7206741..a0cd27d40 100644 --- a/test_six.py +++ b/test_six.py @@ -434,6 +434,33 @@ def stock_method_name(viewwhat): view = meth(d) assert set(view) == set(getattr(d, name)()) +@pytest.mark.skipif("sys.version_info[:2] < (2, 7)", + reason="view methods on dictionaries only available on 2.7+") +def test_dictproxy_views(): + class Ham(object): + pass + + dictproxy = vars(Ham) + dictcopy = dict(dictproxy) + + fns = [six.viewkeys, six.viewvalues, six.viewitems] + dictproxy_views = [fn(dictproxy) for fn in fns] + + # test dictproxy six.view*s work the same as per a regular dict + assert ( + [set(v) for v in dictproxy_views] + == [set(fn(dictcopy)) for fn in fns] + ) + + # test that dictproxy mutations are also exposed on the relevent + # six.view*s results + eggs = object() + setattr(Ham, "spam", eggs) + keysview, valuesview, itemsview = dictproxy_views + assert "spam" in keysview + assert eggs in valuesview + assert ("spam", eggs) in itemsview + def test_advance_iterator(): assert six.next is six.advance_iterator