diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index e7d7c2bc..6bf1adc4 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -5,7 +5,7 @@ # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 -from hashlib import md5 as hashlib_md5 # for MD5 authentication +import hashlib include "scram.pyx" @@ -150,15 +150,28 @@ cdef class CoreProtocol: cdef _process__auth(self, char mtype): if mtype == b'R': # Authentication... - self._parse_msg_authentication() - if self.result_type != RESULT_OK: + try: + self._parse_msg_authentication() + except Exception as ex: + # Exception in authentication parsing code + # is usually either malformed authentication data + # or missing support for cryptographic primitives + # in the hashlib module. + self.result_type = RESULT_FAILED + self.result = apg_exc.InternalClientError( + f"unexpected error while performing authentication: {ex}") + self.result.__cause__ = ex self.con_status = CONNECTION_BAD self._push_result() + else: + if self.result_type != RESULT_OK: + self.con_status = CONNECTION_BAD + self._push_result() - elif self.auth_msg is not None: - # Server wants us to send auth data, so do that. - self._write(self.auth_msg) - self.auth_msg = None + elif self.auth_msg is not None: + # Server wants us to send auth data, so do that. + self._write(self.auth_msg) + self.auth_msg = None elif mtype == b'K': # BackendKeyData @@ -634,7 +647,7 @@ cdef class CoreProtocol: # 'md5' + md5(md5(password + username) + salt)) userpass = ((self.password or '') + (self.user or '')).encode('ascii') - hash = hashlib_md5(hashlib_md5(userpass).hexdigest().\ + hash = hashlib.md5(hashlib.md5(userpass).hexdigest().\ encode('ascii') + salt).hexdigest().encode('ascii') msg.write_bytestring(b'md5' + hash) diff --git a/tests/test_connect.py b/tests/test_connect.py index 34ffbb34..d90ad8a4 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -359,8 +359,13 @@ async def test_auth_password_scram_sha_256(self): await self.con.execute(alter_password) await self.con.execute("SET password_encryption = 'md5';") - async def test_auth_unsupported(self): - pass + @unittest.mock.patch('hashlib.md5', side_effect=ValueError("no md5")) + async def test_auth_md5_unsupported(self, _): + with self.assertRaisesRegex( + exceptions.InternalClientError, + ".*no md5.*", + ): + await self.connect(user='md5_user', password='correctpassword') class TestConnectParams(tb.TestCase):