diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index f98935f5..c09bf5e0 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -757,10 +757,9 @@ async def _connect_addr( params_input = params if callable(params.password): - if inspect.iscoroutinefunction(params.password): - password = await params.password() - else: - password = params.password() + password = params.password() + if inspect.isawaitable(password): + password = await password params = params._replace(password=password) args = (addr, loop, config, connection_class, record_class, params_input) diff --git a/tests/test_connect.py b/tests/test_connect.py index b78f4f48..34ffbb34 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -282,6 +282,25 @@ async def get_wrongpassword(): user='password_user', password=get_wrongpassword) + async def test_auth_password_cleartext_callable_awaitable(self): + async def get_correctpassword(): + return 'correctpassword' + + async def get_wrongpassword(): + return 'wrongpassword' + + conn = await self.connect( + user='password_user', + password=lambda: get_correctpassword()) + await conn.close() + + with self.assertRaisesRegex( + asyncpg.InvalidPasswordError, + 'password authentication failed for user "password_user"'): + await self._try_connect( + user='password_user', + password=lambda: get_wrongpassword()) + async def test_auth_password_md5(self): conn = await self.connect( user='md5_user', password='correctpassword')