diff --git a/CHANGELOG.md b/CHANGELOG.md index fa25423..171bf8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Python-RSA changelog +## Version 4.7.1 - in development + +- Fix threading issue introduced in 4.7 ([#173](https://github.com/sybrenstuvel/python-rsa/issues/173) + ## Version 4.7 - released 2021-01-10 - Fix [#165](https://github.com/sybrenstuvel/python-rsa/issues/165): diff --git a/rsa/key.py b/rsa/key.py index e0e7b11..d84ae05 100644 --- a/rsa/key.py +++ b/rsa/key.py @@ -32,6 +32,7 @@ """ import logging +import threading import typing import warnings @@ -49,7 +50,7 @@ class AbstractKey: """Abstract superclass for private and public keys.""" - __slots__ = ('n', 'e', 'blindfac', 'blindfac_inverse') + __slots__ = ('n', 'e', 'blindfac', 'blindfac_inverse', 'mutex') def __init__(self, n: int, e: int) -> None: self.n = n @@ -58,6 +59,10 @@ def __init__(self, n: int, e: int) -> None: # These will be computed properly on the first call to blind(). self.blindfac = self.blindfac_inverse = -1 + # Used to protect updates to the blinding factor in multi-threaded + # environments. + self.mutex = threading.Lock() + @classmethod def _load_pkcs1_pem(cls, keyfile: bytes) -> 'AbstractKey': """Loads a key in PKCS#1 PEM format, implement in a subclass. @@ -148,36 +153,33 @@ def save_pkcs1(self, format: str = 'PEM') -> bytes: method = self._assert_format_exists(format, methods) return method() - def blind(self, message: int) -> int: - """Performs blinding on the message using random number 'r'. + def blind(self, message: int) -> typing.Tuple[int, int]: + """Performs blinding on the message. :param message: the message, as integer, to blind. - :type message: int :param r: the random number to blind with. - :type r: int - :return: the blinded message. - :rtype: int + :return: tuple (the blinded message, the inverse of the used blinding factor) The blinding is such that message = unblind(decrypt(blind(encrypt(message))). See https://en.wikipedia.org/wiki/Blinding_%28cryptography%29 """ - self._update_blinding_factor() - return (message * pow(self.blindfac, self.e, self.n)) % self.n + blindfac, blindfac_inverse = self._update_blinding_factor() + blinded = (message * pow(blindfac, self.e, self.n)) % self.n + return blinded, blindfac_inverse - def unblind(self, blinded: int) -> int: - """Performs blinding on the message using random number 'r'. + def unblind(self, blinded: int, blindfac_inverse: int) -> int: + """Performs blinding on the message using random number 'blindfac_inverse'. :param blinded: the blinded message, as integer, to unblind. - :param r: the random number to unblind with. + :param blindfac: the factor to unblind with. :return: the original message. The blinding is such that message = unblind(decrypt(blind(encrypt(message))). See https://en.wikipedia.org/wiki/Blinding_%28cryptography%29 """ - - return (self.blindfac_inverse * blinded) % self.n + return (blindfac_inverse * blinded) % self.n def _initial_blinding_factor(self) -> int: for _ in range(1000): @@ -186,18 +188,29 @@ def _initial_blinding_factor(self) -> int: return blind_r raise RuntimeError('unable to find blinding factor') - def _update_blinding_factor(self): - if self.blindfac < 0: - # Compute initial blinding factor, which is rather slow to do. - self.blindfac = self._initial_blinding_factor() - self.blindfac_inverse = rsa.common.inverse(self.blindfac, self.n) - else: - # Reuse previous blinding factor as per section 9 of 'A Timing - # Attack against RSA with the Chinese Remainder Theorem' by Werner - # Schindler. - # See https://tls.mbed.org/public/WSchindler-RSA_Timing_Attack.pdf - self.blindfac = pow(self.blindfac, 2, self.n) - self.blindfac_inverse = pow(self.blindfac_inverse, 2, self.n) + def _update_blinding_factor(self) -> typing.Tuple[int, int]: + """Update blinding factors. + + Computing a blinding factor is expensive, so instead this function + does this once, then updates the blinding factor as per section 9 + of 'A Timing Attack against RSA with the Chinese Remainder Theorem' + by Werner Schindler. + See https://tls.mbed.org/public/WSchindler-RSA_Timing_Attack.pdf + + :return: the new blinding factor and its inverse. + """ + + with self.mutex: + if self.blindfac < 0: + # Compute initial blinding factor, which is rather slow to do. + self.blindfac = self._initial_blinding_factor() + self.blindfac_inverse = rsa.common.inverse(self.blindfac, self.n) + else: + # Reuse previous blinding factor. + self.blindfac = pow(self.blindfac, 2, self.n) + self.blindfac_inverse = pow(self.blindfac_inverse, 2, self.n) + + return self.blindfac, self.blindfac_inverse class PublicKey(AbstractKey): """Represents a public RSA key. @@ -446,9 +459,10 @@ def blinded_decrypt(self, encrypted: int) -> int: :rtype: int """ - blinded = self.blind(encrypted) # blind before decrypting + # Blinding and un-blinding should be using the same factor + blinded, blindfac_inverse = self.blind(encrypted) decrypted = rsa.core.decrypt_int(blinded, self.d, self.n) - return self.unblind(decrypted) + return self.unblind(decrypted, blindfac_inverse) def blinded_encrypt(self, message: int) -> int: """Encrypts the message using blinding to prevent side-channel attacks. @@ -460,9 +474,9 @@ def blinded_encrypt(self, message: int) -> int: :rtype: int """ - blinded = self.blind(message) # blind before encrypting + blinded, blindfac_inverse = self.blind(message) encrypted = rsa.core.encrypt_int(blinded, self.d, self.n) - return self.unblind(encrypted) + return self.unblind(encrypted, blindfac_inverse) @classmethod def _load_pkcs1_der(cls, keyfile: bytes) -> 'PrivateKey': diff --git a/tests/test_key.py b/tests/test_key.py index b00e26d..75e6e12 100644 --- a/tests/test_key.py +++ b/tests/test_key.py @@ -21,19 +21,19 @@ def test_blinding(self): message = 12345 encrypted = rsa.core.encrypt_int(message, pk.e, pk.n) - blinded_1 = pk.blind(encrypted) # blind before decrypting + blinded_1, unblind_1 = pk.blind(encrypted) # blind before decrypting decrypted = rsa.core.decrypt_int(blinded_1, pk.d, pk.n) - unblinded_1 = pk.unblind(decrypted) + unblinded_1 = pk.unblind(decrypted, unblind_1) self.assertEqual(unblinded_1, message) # Re-blinding should use a different blinding factor. - blinded_2 = pk.blind(encrypted) # blind before decrypting + blinded_2, unblind_2 = pk.blind(encrypted) # blind before decrypting self.assertNotEqual(blinded_1, blinded_2) # The unblinding should still work, though. decrypted = rsa.core.decrypt_int(blinded_2, pk.d, pk.n) - unblinded_2 = pk.unblind(decrypted) + unblinded_2 = pk.unblind(decrypted, unblind_2) self.assertEqual(unblinded_2, message) @@ -69,10 +69,9 @@ def getprime(_): # This exponent will cause two other primes to be generated. exponent = 136407 - (p, q, e, d) = rsa.key.gen_keys(64, - accurate=False, - getprime_func=getprime, - exponent=exponent) + (p, q, e, d) = rsa.key.gen_keys( + 64, accurate=False, getprime_func=getprime, exponent=exponent + ) self.assertEqual(39317, p) self.assertEqual(33107, q)