diff --git a/asyncpg/compat.py b/asyncpg/compat.py index 29b8e16e..b9b13fa5 100644 --- a/asyncpg/compat.py +++ b/asyncpg/compat.py @@ -8,6 +8,7 @@ import asyncio import pathlib import platform +import typing SYSTEM = platform.uname().system @@ -18,7 +19,7 @@ CSIDL_APPDATA = 0x001a - def get_pg_home_directory() -> pathlib.Path: + def get_pg_home_directory() -> typing.Optional[pathlib.Path]: # We cannot simply use expanduser() as that returns the user's # home directory, whereas Postgres stores its config in # %AppData% on Windows. @@ -30,8 +31,11 @@ def get_pg_home_directory() -> pathlib.Path: return pathlib.Path(buf.value) / 'postgresql' else: - def get_pg_home_directory() -> pathlib.Path: - return pathlib.Path.home() + def get_pg_home_directory() -> typing.Optional[pathlib.Path]: + try: + return pathlib.Path.home() + except (RuntimeError, KeyError): + return None async def wait_closed(stream): diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index e0c10442..8b29c0fc 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -251,8 +251,13 @@ def _parse_tls_version(tls_version): ) -def _dot_postgresql_path(filename) -> pathlib.Path: - return (pathlib.Path.home() / '.postgresql' / filename).resolve() +def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]: + try: + homedir = pathlib.Path.home() + except (RuntimeError, KeyError): + return None + + return (homedir / '.postgresql' / filename).resolve() def _parse_connect_dsn_and_args(*, dsn, host, port, user, @@ -504,11 +509,16 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, ssl.load_verify_locations(cafile=sslrootcert) ssl.verify_mode = ssl_module.CERT_REQUIRED else: - sslrootcert = _dot_postgresql_path('root.crt') try: + sslrootcert = _dot_postgresql_path('root.crt') + assert sslrootcert is not None ssl.load_verify_locations(cafile=sslrootcert) - except FileNotFoundError: + except (AssertionError, FileNotFoundError): if sslmode > SSLMode.require: + if sslrootcert is None: + raise RuntimeError( + 'Cannot determine home directory' + ) raise ValueError( f'root certificate file "{sslrootcert}" does ' f'not exist\nEither provide the file or ' @@ -529,18 +539,20 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN else: sslcrl = _dot_postgresql_path('root.crl') - try: - ssl.load_verify_locations(cafile=sslcrl) - except FileNotFoundError: - pass - else: - ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN + if sslcrl is not None: + try: + ssl.load_verify_locations(cafile=sslcrl) + except FileNotFoundError: + pass + else: + ssl.verify_flags |= \ + ssl_module.VERIFY_CRL_CHECK_CHAIN if sslkey is None: sslkey = os.getenv('PGSSLKEY') if not sslkey: sslkey = _dot_postgresql_path('postgresql.key') - if not sslkey.exists(): + if sslkey is not None and not sslkey.exists(): sslkey = None if not sslpassword: sslpassword = '' @@ -552,12 +564,15 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, ) else: sslcert = _dot_postgresql_path('postgresql.crt') - try: - ssl.load_cert_chain( - sslcert, keyfile=sslkey, password=lambda: sslpassword - ) - except FileNotFoundError: - pass + if sslcert is not None: + try: + ssl.load_cert_chain( + sslcert, + keyfile=sslkey, + password=lambda: sslpassword + ) + except FileNotFoundError: + pass # OpenSSL 1.1.1 keylog file, copied from create_default_context() if hasattr(ssl, 'keylog_filename'): diff --git a/tests/test_connect.py b/tests/test_connect.py index 02a6a50b..628d8aba 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -71,6 +71,14 @@ def mock_dot_postgresql(*, ca=True, crl=False, client=False, protected=False): yield +@contextlib.contextmanager +def mock_no_home_dir(): + with unittest.mock.patch( + 'pathlib.Path.home', unittest.mock.Mock(side_effect=RuntimeError) + ): + yield + + class TestSettings(tb.ConnectedTestCase): async def test_get_settings_01(self): @@ -1299,6 +1307,27 @@ async def test_connection_implicit_host(self): user=conn_spec.get('user')) await con.close() + @unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster') + async def test_connection_no_home_dir(self): + with mock_no_home_dir(): + con = await self.connect( + dsn='postgresql://foo/', + user='postgres', + database='postgres', + host='localhost') + await con.fetchval('SELECT 42') + await con.close() + + with self.assertRaisesRegex( + RuntimeError, + 'Cannot determine home directory' + ): + with mock_no_home_dir(): + await self.connect( + host='localhost', + user='ssl_user', + ssl='verify-full') + class BaseTestSSLConnection(tb.ConnectedTestCase): @classmethod