From 716a20fc79c98e250c90a3d3e9f2218bec181a8d Mon Sep 17 00:00:00 2001 From: Michael Merickel Date: Sun, 16 Nov 2014 23:11:15 -0600 Subject: [PATCH] use hmac.compare_digest if available --- CHANGES.txt | 5 +++++ pyramid/tests/test_util.py | 43 ++++++++++++++++++++++++++++++++++++++ pyramid/util.py | 32 +++++++++++++++++++++------- 3 files changed, 72 insertions(+), 8 deletions(-) diff --git a/CHANGES.txt b/CHANGES.txt index a893ebae4d..bbaa6739e8 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -33,6 +33,11 @@ Features - Greatly improve the readability of the ``pcreate`` shell script output. See https://github.com/Pylons/pyramid/pull/1453 +- Improve robustness to timing attacks in the ``AuthTktCookieHelper`` and + the ``SignedCookieSessionFactory`` classes by using the stdlib's + ``hmac.compare_digest`` if it is available (such as Python 2.7.7+ and 3.3+). + See https://github.com/Pylons/pyramid/pull/1457 + Bug Fixes --------- diff --git a/pyramid/tests/test_util.py b/pyramid/tests/test_util.py index 2ca4c4a668..a18fa8d168 100644 --- a/pyramid/tests/test_util.py +++ b/pyramid/tests/test_util.py @@ -217,6 +217,49 @@ def test_empty(self): self.assertEqual(list(wos), []) self.assertEqual(wos.last, None) +class Test_strings_differ(unittest.TestCase): + def _callFUT(self, *args, **kw): + from pyramid.util import strings_differ + return strings_differ(*args, **kw) + + def test_it(self): + self.assertFalse(self._callFUT(b'foo', b'foo')) + self.assertTrue(self._callFUT(b'123', b'345')) + self.assertTrue(self._callFUT(b'1234', b'123')) + self.assertTrue(self._callFUT(b'123', b'1234')) + + def test_it_with_internal_comparator(self): + result = self._callFUT(b'foo', b'foo', compare_digest=None) + self.assertFalse(result) + + result = self._callFUT(b'123', b'abc', compare_digest=None) + self.assertTrue(result) + + def test_it_with_external_comparator(self): + class DummyComparator(object): + called = False + def __init__(self, ret_val): + self.ret_val = ret_val + + def __call__(self, a, b): + self.called = True + return self.ret_val + + dummy_compare = DummyComparator(True) + result = self._callFUT(b'foo', b'foo', compare_digest=dummy_compare) + self.assertTrue(dummy_compare.called) + self.assertFalse(result) + + dummy_compare = DummyComparator(False) + result = self._callFUT(b'123', b'345', compare_digest=dummy_compare) + self.assertTrue(dummy_compare.called) + self.assertTrue(result) + + dummy_compare = DummyComparator(False) + result = self._callFUT(b'abc', b'abc', compare_digest=dummy_compare) + self.assertTrue(dummy_compare.called) + self.assertTrue(result) + class Test_object_description(unittest.TestCase): def _callFUT(self, object): from pyramid.util import object_description diff --git a/pyramid/util.py b/pyramid/util.py index 6b92f17fc5..6de53d559a 100644 --- a/pyramid/util.py +++ b/pyramid/util.py @@ -1,4 +1,9 @@ import functools +try: + # py2.7.7+ and py3.3+ have native comparison support + from hmac import compare_digest +except ImportError: # pragma: nocover + compare_digest = None import inspect import traceback import weakref @@ -227,7 +232,7 @@ def last(self): oid = self._order[-1] return self._items[oid]() -def strings_differ(string1, string2): +def strings_differ(string1, string2, compare_digest=compare_digest): """Check whether two strings differ while avoiding timing attacks. This function returns True if the given strings differ and False @@ -237,14 +242,25 @@ def strings_differ(string1, string2): http://seb.dbzteam.org/crypto/python-oauth-timing-hmac.pdf - """ - if len(string1) != len(string2): - return True - - invalid_bits = 0 - for a, b in zip(string1, string2): - invalid_bits += a != b + .. versionchanged:: 1.6 + Support :func:`hmac.compare_digest` if it is available (Python 2.7.7+ + and Python 3.3+). + """ + len_eq = len(string1) == len(string2) + if len_eq: + invalid_bits = 0 + left = string1 + else: + invalid_bits = 1 + left = string2 + right = string2 + + if compare_digest is not None: + invalid_bits += not compare_digest(left, right) + else: + for a, b in zip(left, right): + invalid_bits += a != b return invalid_bits != 0 def object_description(object):