Skip to content

Commit

Permalink
Handle environments without home dir (MagicStack#1011)
Browse files Browse the repository at this point in the history
  • Loading branch information
LeonardBesson authored and lezram committed Apr 14, 2023
1 parent 6155213 commit 439fb96
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 20 deletions.
10 changes: 7 additions & 3 deletions asyncpg/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import asyncio
import pathlib
import platform
import typing


SYSTEM = platform.uname().system
Expand All @@ -18,7 +19,7 @@

CSIDL_APPDATA = 0x001a

def get_pg_home_directory() -> pathlib.Path:
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
# We cannot simply use expanduser() as that returns the user's
# home directory, whereas Postgres stores its config in
# %AppData% on Windows.
Expand All @@ -30,8 +31,11 @@ def get_pg_home_directory() -> pathlib.Path:
return pathlib.Path(buf.value) / 'postgresql'

else:
def get_pg_home_directory() -> pathlib.Path:
return pathlib.Path.home()
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
try:
return pathlib.Path.home()
except (RuntimeError, KeyError):
return None


async def wait_closed(stream):
Expand Down
49 changes: 32 additions & 17 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,13 @@ def _parse_tls_version(tls_version):
)


def _dot_postgresql_path(filename) -> pathlib.Path:
return (pathlib.Path.home() / '.postgresql' / filename).resolve()
def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
try:
homedir = pathlib.Path.home()
except (RuntimeError, KeyError):
return None

return (homedir / '.postgresql' / filename).resolve()


def _parse_connect_dsn_and_args(*, dsn, host, port, user,
Expand Down Expand Up @@ -504,11 +509,16 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
ssl.load_verify_locations(cafile=sslrootcert)
ssl.verify_mode = ssl_module.CERT_REQUIRED
else:
sslrootcert = _dot_postgresql_path('root.crt')
try:
sslrootcert = _dot_postgresql_path('root.crt')
assert sslrootcert is not None
ssl.load_verify_locations(cafile=sslrootcert)
except FileNotFoundError:
except (AssertionError, FileNotFoundError):
if sslmode > SSLMode.require:
if sslrootcert is None:
raise RuntimeError(
'Cannot determine home directory'
)
raise ValueError(
f'root certificate file "{sslrootcert}" does '
f'not exist\nEither provide the file or '
Expand All @@ -529,18 +539,20 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
else:
sslcrl = _dot_postgresql_path('root.crl')
try:
ssl.load_verify_locations(cafile=sslcrl)
except FileNotFoundError:
pass
else:
ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
if sslcrl is not None:
try:
ssl.load_verify_locations(cafile=sslcrl)
except FileNotFoundError:
pass
else:
ssl.verify_flags |= \
ssl_module.VERIFY_CRL_CHECK_CHAIN

if sslkey is None:
sslkey = os.getenv('PGSSLKEY')
if not sslkey:
sslkey = _dot_postgresql_path('postgresql.key')
if not sslkey.exists():
if sslkey is not None and not sslkey.exists():
sslkey = None
if not sslpassword:
sslpassword = ''
Expand All @@ -552,12 +564,15 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
)
else:
sslcert = _dot_postgresql_path('postgresql.crt')
try:
ssl.load_cert_chain(
sslcert, keyfile=sslkey, password=lambda: sslpassword
)
except FileNotFoundError:
pass
if sslcert is not None:
try:
ssl.load_cert_chain(
sslcert,
keyfile=sslkey,
password=lambda: sslpassword
)
except FileNotFoundError:
pass

# OpenSSL 1.1.1 keylog file, copied from create_default_context()
if hasattr(ssl, 'keylog_filename'):
Expand Down
29 changes: 29 additions & 0 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ def mock_dot_postgresql(*, ca=True, crl=False, client=False, protected=False):
yield


@contextlib.contextmanager
def mock_no_home_dir():
with unittest.mock.patch(
'pathlib.Path.home', unittest.mock.Mock(side_effect=RuntimeError)
):
yield


class TestSettings(tb.ConnectedTestCase):

async def test_get_settings_01(self):
Expand Down Expand Up @@ -1299,6 +1307,27 @@ async def test_connection_implicit_host(self):
user=conn_spec.get('user'))
await con.close()

@unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster')
async def test_connection_no_home_dir(self):
with mock_no_home_dir():
con = await self.connect(
dsn='postgresql://foo/',
user='postgres',
database='postgres',
host='localhost')
await con.fetchval('SELECT 42')
await con.close()

with self.assertRaisesRegex(
RuntimeError,
'Cannot determine home directory'
):
with mock_no_home_dir():
await self.connect(
host='localhost',
user='ssl_user',
ssl='verify-full')


class BaseTestSSLConnection(tb.ConnectedTestCase):
@classmethod
Expand Down

0 comments on commit 439fb96

Please sign in to comment.