Skip to content

Commit

Permalink
Use the timeout context manager in the connection path
Browse files Browse the repository at this point in the history
Drop timeout management gymnastics from the `connect()` path and use the
`timeout` context manager instead.
  • Loading branch information
elprans committed Oct 9, 2023
1 parent 4bdd8a7 commit 395d364
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 55 deletions.
6 changes: 6 additions & 0 deletions asyncpg/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,9 @@ async def wait_closed(stream):
from ._asyncio_compat import wait_for as wait_for # noqa: F401
else:
from asyncio import wait_for as wait_for # noqa: F401


if sys.version_info < (3, 11):
from ._asyncio_compat import timeout_ctx as timeout # noqa: F401
else:
from asyncio import timeout as timeout # noqa: F401
45 changes: 12 additions & 33 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import stat
import struct
import sys
import time
import typing
import urllib.parse
import warnings
Expand Down Expand Up @@ -55,7 +54,6 @@ def parse(cls, sslmode):
'ssl',
'sslmode',
'direct_tls',
'connect_timeout',
'server_settings',
'target_session_attrs',
])
Expand Down Expand Up @@ -262,7 +260,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:

def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
direct_tls, connect_timeout, server_settings,
direct_tls, server_settings,
target_session_attrs):
# `auth_hosts` is the version of host information for the purposes
# of reading the pgpass file.
Expand Down Expand Up @@ -655,14 +653,14 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
params = _ConnectionParameters(
user=user, password=password, database=database, ssl=ssl,
sslmode=sslmode, direct_tls=direct_tls,
connect_timeout=connect_timeout, server_settings=server_settings,
server_settings=server_settings,
target_session_attrs=target_session_attrs)

return addrs, params


def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
database, timeout, command_timeout,
database, command_timeout,
statement_cache_size,
max_cached_statement_lifetime,
max_cacheable_statement_size,
Expand Down Expand Up @@ -695,7 +693,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
dsn=dsn, host=host, port=port, user=user,
password=password, passfile=passfile, ssl=ssl,
direct_tls=direct_tls, database=database,
connect_timeout=timeout, server_settings=server_settings,
server_settings=server_settings,
target_session_attrs=target_session_attrs)

config = _ClientConfiguration(
Expand Down Expand Up @@ -799,17 +797,13 @@ async def _connect_addr(
*,
addr,
loop,
timeout,
params,
config,
connection_class,
record_class
):
assert loop is not None

if timeout <= 0:
raise asyncio.TimeoutError

params_input = params
if callable(params.password):
password = params.password()
Expand All @@ -827,21 +821,16 @@ async def _connect_addr(
params_retry = params._replace(ssl=None)
else:
# skip retry if we don't have to
return await __connect_addr(params, timeout, False, *args)
return await __connect_addr(params, False, *args)

# first attempt
before = time.monotonic()
try:
return await __connect_addr(params, timeout, True, *args)
return await __connect_addr(params, True, *args)
except _RetryConnectSignal:
pass

# second attempt
timeout -= time.monotonic() - before
if timeout <= 0:
raise asyncio.TimeoutError
else:
return await __connect_addr(params_retry, timeout, False, *args)
return await __connect_addr(params_retry, False, *args)


class _RetryConnectSignal(Exception):
Expand All @@ -850,7 +839,6 @@ class _RetryConnectSignal(Exception):

async def __connect_addr(
params,
timeout,
retry,
addr,
loop,
Expand Down Expand Up @@ -882,15 +870,10 @@ async def __connect_addr(
else:
connector = loop.create_connection(proto_factory, *addr)

connector = asyncio.ensure_future(connector)
before = time.monotonic()
tr, pr = await compat.wait_for(connector, timeout=timeout)
timeout -= time.monotonic() - before
tr, pr = await connector

try:
if timeout <= 0:
raise asyncio.TimeoutError
await compat.wait_for(connected, timeout=timeout)
await connected
except (
exceptions.InvalidAuthorizationSpecificationError,
exceptions.ConnectionDoesNotExistError, # seen on Windows
Expand Down Expand Up @@ -993,23 +976,21 @@ async def _can_use_connection(connection, attr: SessionAttribute):
return await can_use(connection)


async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
async def _connect(*, loop, connection_class, record_class, **kwargs):
if loop is None:
loop = asyncio.get_event_loop()

addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs)
addrs, params, config = _parse_connect_arguments(**kwargs)
target_attr = params.target_session_attrs

candidates = []
chosen_connection = None
last_error = None
for addr in addrs:
before = time.monotonic()
try:
conn = await _connect_addr(
addr=addr,
loop=loop,
timeout=timeout,
params=params,
config=config,
connection_class=connection_class,
Expand All @@ -1019,10 +1000,8 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
if await _can_use_connection(conn, target_attr):
chosen_connection = conn
break
except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
except OSError as ex:
last_error = ex
finally:
timeout -= time.monotonic() - before
else:
if target_attr == SessionAttribute.prefer_standby and candidates:
chosen_connection = random.choice(candidates)
Expand Down
43 changes: 22 additions & 21 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import warnings
import weakref

from . import compat
from . import connect_utils
from . import cursor
from . import exceptions
Expand Down Expand Up @@ -2184,27 +2185,27 @@ async def connect(dsn=None, *,
if loop is None:
loop = asyncio.get_event_loop()

return await connect_utils._connect(
loop=loop,
timeout=timeout,
connection_class=connection_class,
record_class=record_class,
dsn=dsn,
host=host,
port=port,
user=user,
password=password,
passfile=passfile,
ssl=ssl,
direct_tls=direct_tls,
database=database,
server_settings=server_settings,
command_timeout=command_timeout,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
max_cacheable_statement_size=max_cacheable_statement_size,
target_session_attrs=target_session_attrs
)
async with compat.timeout(timeout):
return await connect_utils._connect(
loop=loop,
connection_class=connection_class,
record_class=record_class,
dsn=dsn,
host=host,
port=port,
user=user,
password=password,
passfile=passfile,
ssl=ssl,
direct_tls=direct_tls,
database=database,
server_settings=server_settings,
command_timeout=command_timeout,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
max_cacheable_statement_size=max_cacheable_statement_size,
target_session_attrs=target_session_attrs
)


class _StatementCacheEntry:
Expand Down
17 changes: 17 additions & 0 deletions tests/test_adversity.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,23 @@ async def test_connection_close_timeout(self):
with self.assertRaises(asyncio.TimeoutError):
await con.close(timeout=0.5)

@tb.with_timeout(30.0)
async def test_pool_acquire_timeout(self):
pool = await self.create_pool(
database='postgres', min_size=2, max_size=2)
try:
self.proxy.trigger_connectivity_loss()
for _ in range(2):
with self.assertRaises(asyncio.TimeoutError):
async with pool.acquire(timeout=0.5):
pass
self.proxy.restore_connectivity()
async with pool.acquire(timeout=0.5):
pass
finally:
self.proxy.restore_connectivity()
pool.terminate()

@tb.with_timeout(30.0)
async def test_pool_release_timeout(self):
pool = await self.create_pool(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ def run_testcase(self, testcase):
addrs, params = connect_utils._parse_connect_dsn_and_args(
dsn=dsn, host=host, port=port, user=user, password=password,
passfile=passfile, database=database, ssl=sslmode,
direct_tls=False, connect_timeout=None,
direct_tls=False,
server_settings=server_settings,
target_session_attrs=target_session_attrs)

Expand Down

0 comments on commit 395d364

Please sign in to comment.