diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 214dfc4b..6f64c3d9 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -7,7 +7,7 @@ from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol from .._models import Origin, Request, Response -from .._synchronization import AsyncEvent, AsyncShieldCancellation, AsyncThreadLock +from .._synchronization import AsyncEvent, AsyncThreadLock, async_cancel_shield from .connection import AsyncHTTPConnection from .interfaces import AsyncConnectionInterface, AsyncRequestInterface @@ -299,11 +299,16 @@ def _assign_requests_to_connections(self) -> List[AsyncConnectionInterface]: return closing_connections async def _close_connections(self, closing: List[AsyncConnectionInterface]) -> None: + if not closing: + return + # Close connections which have been removed from the pool. - with AsyncShieldCancellation(): + async def close() -> None: for connection in closing: await connection.aclose() + await async_cancel_shield(close) + async def aclose(self) -> None: # Explicitly close the connection pool. # Clears all existing requests and connections. @@ -369,9 +374,9 @@ async def __aiter__(self) -> AsyncIterator[bytes]: async def aclose(self) -> None: if not self._closed: self._closed = True - with AsyncShieldCancellation(): - if hasattr(self._stream, "aclose"): - await self._stream.aclose() + + if hasattr(self._stream, "aclose"): + await async_cancel_shield(self._stream.aclose) with self._pool._optional_thread_lock: self._pool._requests.remove(self._pool_request) diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index 0493a923..569e3cd3 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -25,7 +25,7 @@ map_exceptions, ) from .._models import Origin, Request, Response -from .._synchronization import AsyncLock, AsyncShieldCancellation +from .._synchronization import AsyncLock, async_cancel_shield from .._trace import Trace from .interfaces import AsyncConnectionInterface @@ -137,9 +137,8 @@ async def handle_async_request(self, request: Request) -> Response: }, ) except BaseException as exc: - with AsyncShieldCancellation(): - async with Trace("response_closed", logger, request) as trace: - await self._response_closed() + async with Trace("response_closed", logger, request) as trace: + await async_cancel_shield(self._response_closed) raise exc # Sending the request... @@ -344,8 +343,7 @@ async def __aiter__(self) -> AsyncIterator[bytes]: # If we get an exception while streaming the response, # we want to close the response (and possibly the connection) # before raising that exception. - with AsyncShieldCancellation(): - await self.aclose() + await async_cancel_shield(self.aclose) raise exc async def aclose(self) -> None: diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index c201ee4c..da47fbb2 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -17,7 +17,7 @@ RemoteProtocolError, ) from .._models import Origin, Request, Response -from .._synchronization import AsyncLock, AsyncSemaphore, AsyncShieldCancellation +from .._synchronization import AsyncLock, AsyncSemaphore, async_cancel_shield from .._trace import Trace from .interfaces import AsyncConnectionInterface @@ -108,8 +108,7 @@ async def handle_async_request(self, request: Request) -> Response: async with Trace("send_connection_init", logger, request, kwargs): await self._send_connection_init(**kwargs) except BaseException as exc: - with AsyncShieldCancellation(): - await self.aclose() + await async_cancel_shield(self.aclose) raise exc self._sent_connection_init = True @@ -160,11 +159,12 @@ async def handle_async_request(self, request: Request) -> Response: "stream_id": stream_id, }, ) - except BaseException as exc: # noqa: PIE786 - with AsyncShieldCancellation(): - kwargs = {"stream_id": stream_id} - async with Trace("response_closed", logger, request, kwargs): - await self._response_closed(stream_id=stream_id) + except BaseException as exc: + kwargs = {"stream_id": stream_id} + async with Trace("response_closed", logger, request, kwargs): + await async_cancel_shield( + lambda: self._response_closed(stream_id=stream_id) + ) if isinstance(exc, h2.exceptions.ProtocolError): # One case where h2 can raise a protocol error is when a @@ -577,8 +577,7 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]: # If we get an exception while streaming the response, # we want to close the response (and possibly the connection) # before raising that exception. - with AsyncShieldCancellation(): - await self.aclose() + await async_cancel_shield(self.aclose) raise exc async def aclose(self) -> None: diff --git a/httpcore/_backends/auto.py b/httpcore/_backends/auto.py index 3ac05f4d..9b362764 100644 --- a/httpcore/_backends/auto.py +++ b/httpcore/_backends/auto.py @@ -1,14 +1,14 @@ import typing from typing import Optional -from .._synchronization import current_async_library +from .._synchronization import current_async_backend from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream class AutoBackend(AsyncNetworkBackend): async def _init_backend(self) -> None: if not (hasattr(self, "_backend")): - backend = current_async_library() + backend = current_async_backend() if backend == "trio": from .trio import TrioBackend diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 01bec59e..fc8c4d7b 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -7,7 +7,7 @@ from .._backends.base import SOCKET_OPTION, NetworkBackend from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol from .._models import Origin, Request, Response -from .._synchronization import Event, ShieldCancellation, ThreadLock +from .._synchronization import Event, ThreadLock, sync_cancel_shield from .connection import HTTPConnection from .interfaces import ConnectionInterface, RequestInterface @@ -299,11 +299,16 @@ def _assign_requests_to_connections(self) -> List[ConnectionInterface]: return closing_connections def _close_connections(self, closing: List[ConnectionInterface]) -> None: + if not closing: + return + # Close connections which have been removed from the pool. - with ShieldCancellation(): + def close() -> None: for connection in closing: connection.close() + sync_cancel_shield(close) + def close(self) -> None: # Explicitly close the connection pool. # Clears all existing requests and connections. @@ -369,9 +374,9 @@ def __iter__(self) -> Iterator[bytes]: def close(self) -> None: if not self._closed: self._closed = True - with ShieldCancellation(): - if hasattr(self._stream, "close"): - self._stream.close() + + if hasattr(self._stream, "close"): + sync_cancel_shield(self._stream.close) with self._pool._optional_thread_lock: self._pool._requests.remove(self._pool_request) diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py index a74ff8e8..12b0eca4 100644 --- a/httpcore/_sync/http11.py +++ b/httpcore/_sync/http11.py @@ -25,7 +25,7 @@ map_exceptions, ) from .._models import Origin, Request, Response -from .._synchronization import Lock, ShieldCancellation +from .._synchronization import Lock, sync_cancel_shield from .._trace import Trace from .interfaces import ConnectionInterface @@ -137,9 +137,8 @@ def handle_request(self, request: Request) -> Response: }, ) except BaseException as exc: - with ShieldCancellation(): - with Trace("response_closed", logger, request) as trace: - self._response_closed() + with Trace("response_closed", logger, request) as trace: + sync_cancel_shield(self._response_closed) raise exc # Sending the request... @@ -344,8 +343,7 @@ def __iter__(self) -> Iterator[bytes]: # If we get an exception while streaming the response, # we want to close the response (and possibly the connection) # before raising that exception. - with ShieldCancellation(): - self.close() + sync_cancel_shield(self.close) raise exc def close(self) -> None: diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index 1ee4bbb3..ea0b02b7 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -17,7 +17,7 @@ RemoteProtocolError, ) from .._models import Origin, Request, Response -from .._synchronization import Lock, Semaphore, ShieldCancellation +from .._synchronization import Lock, Semaphore, sync_cancel_shield from .._trace import Trace from .interfaces import ConnectionInterface @@ -108,8 +108,7 @@ def handle_request(self, request: Request) -> Response: with Trace("send_connection_init", logger, request, kwargs): self._send_connection_init(**kwargs) except BaseException as exc: - with ShieldCancellation(): - self.close() + sync_cancel_shield(self.close) raise exc self._sent_connection_init = True @@ -160,11 +159,12 @@ def handle_request(self, request: Request) -> Response: "stream_id": stream_id, }, ) - except BaseException as exc: # noqa: PIE786 - with ShieldCancellation(): - kwargs = {"stream_id": stream_id} - with Trace("response_closed", logger, request, kwargs): - self._response_closed(stream_id=stream_id) + except BaseException as exc: + kwargs = {"stream_id": stream_id} + with Trace("response_closed", logger, request, kwargs): + sync_cancel_shield( + lambda: self._response_closed(stream_id=stream_id) + ) if isinstance(exc, h2.exceptions.ProtocolError): # One case where h2 can raise a protocol error is when a @@ -577,8 +577,7 @@ def __iter__(self) -> typing.Iterator[bytes]: # If we get an exception while streaming the response, # we want to close the response (and possibly the connection) # before raising that exception. - with ShieldCancellation(): - self.close() + sync_cancel_shield(self.close) raise exc def close(self) -> None: diff --git a/httpcore/_synchronization.py b/httpcore/_synchronization.py index 9619a398..ccf22faf 100644 --- a/httpcore/_synchronization.py +++ b/httpcore/_synchronization.py @@ -1,10 +1,12 @@ +import asyncio +import sys import threading from types import TracebackType -from typing import Optional, Type +from typing import Any, Callable, Coroutine, Literal, Optional, Protocol, Type from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions -# Our async synchronization primatives use either 'anyio' or 'trio' depending +# Our async synchronization primitives use either 'asyncio' or 'trio' depending # on if they're running under asyncio or trio. try: @@ -18,30 +20,49 @@ anyio = None # type: ignore -def current_async_library() -> str: +if sys.version_info >= (3, 11): # pragma: nocover + import asyncio as asyncio_timeout +else: # pragma: nocover + import async_timeout as asyncio_timeout + + +AsyncBackend = Literal["asyncio", "trio"] + + +def current_async_backend() -> AsyncBackend: # Determine if we're running under trio or asyncio. # See https://sniffio.readthedocs.io/en/latest/ try: import sniffio except ImportError: # pragma: nocover - environment = "asyncio" + backend: AsyncBackend = "asyncio" else: - environment = sniffio.current_async_library() + backend = sniffio.current_async_library() # type: ignore[assignment] - if environment not in ("asyncio", "trio"): # pragma: nocover - raise RuntimeError("Running under an unsupported async environment.") + if backend not in ("asyncio", "trio"): # pragma: nocover + raise RuntimeError("Running under an unsupported async backend.") - if environment == "asyncio" and anyio is None: # pragma: nocover + if backend == "asyncio" and anyio is None: # pragma: nocover raise RuntimeError( "Running with asyncio requires installation of 'httpcore[asyncio]'." ) - if environment == "trio" and trio is None: # pragma: nocover + if backend == "trio" and trio is None: # pragma: nocover raise RuntimeError( "Running with trio requires installation of 'httpcore[trio]'." ) - return environment + return backend + + +class _LockProto(Protocol): + async def acquire(self) -> Any: ... + def release(self) -> None: ... + + +class _EventProto(Protocol): + def set(self) -> None: ... + async def wait(self) -> Any: ... class AsyncLock: @@ -53,28 +74,25 @@ class AsyncLock: """ def __init__(self) -> None: - self._backend = "" + self._lock: Optional[_LockProto] = None def setup(self) -> None: """ Detect if we're running under 'asyncio' or 'trio' and create a lock with the correct implementation. """ - self._backend = current_async_library() - if self._backend == "trio": - self._trio_lock = trio.Lock() - elif self._backend == "asyncio": - self._anyio_lock = anyio.Lock() + if current_async_backend() == "trio": + self._lock = trio.Lock() + else: + # Note: asyncio.Lock has better performance characteristics than anyio.Lock + # https://github.com/encode/httpx/issues/3215 + self._lock = asyncio.Lock() async def __aenter__(self) -> "AsyncLock": - if not self._backend: + if self._lock is None: self.setup() - - if self._backend == "trio": - await self._trio_lock.acquire() - elif self._backend == "asyncio": - await self._anyio_lock.acquire() - + lock: _LockProto = self._lock # type: ignore[assignment] + await lock.acquire() return self async def __aexit__( @@ -83,10 +101,8 @@ async def __aexit__( exc_value: Optional[BaseException] = None, traceback: Optional[TracebackType] = None, ) -> None: - if self._backend == "trio": - self._trio_lock.release() - elif self._backend == "asyncio": - self._anyio_lock.release() + lock: _LockProto = self._lock # type: ignore[assignment] + lock.release() class AsyncThreadLock: @@ -112,117 +128,95 @@ def __exit__( class AsyncEvent: def __init__(self) -> None: self._backend = "" + self._event: Optional[_EventProto] = None def setup(self) -> None: """ Detect if we're running under 'asyncio' or 'trio' and create a lock with the correct implementation. """ - self._backend = current_async_library() + self._backend = current_async_backend() if self._backend == "trio": - self._trio_event = trio.Event() - elif self._backend == "asyncio": - self._anyio_event = anyio.Event() + self._event = trio.Event() + else: + # Note: asyncio.Event has better performance characteristics than anyio.Event + self._event = asyncio.Event() def set(self) -> None: - if not self._backend: + if self._event is None: self.setup() - - if self._backend == "trio": - self._trio_event.set() - elif self._backend == "asyncio": - self._anyio_event.set() + event: _EventProto = self._event # type: ignore[assignment] + event.set() async def wait(self, timeout: Optional[float] = None) -> None: - if not self._backend: + if self._event is None: self.setup() + event: _EventProto = self._event # type: ignore[assignment] if self._backend == "trio": trio_exc_map: ExceptionMapping = {trio.TooSlowError: PoolTimeout} timeout_or_inf = float("inf") if timeout is None else timeout with map_exceptions(trio_exc_map): with trio.fail_after(timeout_or_inf): - await self._trio_event.wait() - elif self._backend == "asyncio": - anyio_exc_map: ExceptionMapping = {TimeoutError: PoolTimeout} - with map_exceptions(anyio_exc_map): - with anyio.fail_after(timeout): - await self._anyio_event.wait() + await event.wait() + else: + asyncio_exc_map: ExceptionMapping = { + asyncio.exceptions.TimeoutError: PoolTimeout + } + with map_exceptions(asyncio_exc_map): + async with asyncio_timeout.timeout(timeout): + await event.wait() class AsyncSemaphore: def __init__(self, bound: int) -> None: self._bound = bound - self._backend = "" + self._semaphore: Optional[_LockProto] = None def setup(self) -> None: """ Detect if we're running under 'asyncio' or 'trio' and create a semaphore with the correct implementation. """ - self._backend = current_async_library() - if self._backend == "trio": - self._trio_semaphore = trio.Semaphore( - initial_value=self._bound, max_value=self._bound - ) - elif self._backend == "asyncio": - self._anyio_semaphore = anyio.Semaphore( + if current_async_backend() == "trio": + self._semaphore = trio.Semaphore( initial_value=self._bound, max_value=self._bound ) + else: + # Note: asyncio.BoundedSemaphore has better performance characteristics than anyio.Semaphore + self._semaphore = asyncio.BoundedSemaphore(self._bound) async def acquire(self) -> None: - if not self._backend: + if self._semaphore is None: self.setup() - - if self._backend == "trio": - await self._trio_semaphore.acquire() - elif self._backend == "asyncio": - await self._anyio_semaphore.acquire() + semaphore: _LockProto = self._semaphore # type: ignore[assignment] + await semaphore.acquire() async def release(self) -> None: - if self._backend == "trio": - self._trio_semaphore.release() - elif self._backend == "asyncio": - self._anyio_semaphore.release() + semaphore: _LockProto = self._semaphore # type: ignore[assignment] + semaphore.release() -class AsyncShieldCancellation: - # For certain portions of our codebase where we're dealing with - # closing connections during exception handling we want to shield - # the operation from being cancelled. - # - # with AsyncShieldCancellation(): - # ... # clean-up operations, shielded from cancellation. - - def __init__(self) -> None: - """ - Detect if we're running under 'asyncio' or 'trio' and create - a shielded scope with the correct implementation. - """ - self._backend = current_async_library() - - if self._backend == "trio": - self._trio_shield = trio.CancelScope(shield=True) - elif self._backend == "asyncio": - self._anyio_shield = anyio.CancelScope(shield=True) - - def __enter__(self) -> "AsyncShieldCancellation": - if self._backend == "trio": - self._trio_shield.__enter__() - elif self._backend == "asyncio": - self._anyio_shield.__enter__() - return self - - def __exit__( - self, - exc_type: Optional[Type[BaseException]] = None, - exc_value: Optional[BaseException] = None, - traceback: Optional[TracebackType] = None, - ) -> None: - if self._backend == "trio": - self._trio_shield.__exit__(exc_type, exc_value, traceback) - elif self._backend == "asyncio": - self._anyio_shield.__exit__(exc_type, exc_value, traceback) +async def async_cancel_shield( + shielded: Callable[[], Coroutine[Any, Any, None]], +) -> None: + if current_async_backend() == "trio": + with trio.CancelScope(shield=True): + await shielded() + else: + inner_task = asyncio.create_task(shielded()) + retry = False + while True: + try: + await asyncio.shield(inner_task) + break + except asyncio.CancelledError: + if inner_task.done() or retry: + break + # We may get multiple cancellations. + # Retry once to get inner_task finished here by best effort. + retry = True + continue # Our thread-based synchronization primitives... @@ -301,17 +295,8 @@ def release(self) -> None: self._semaphore.release() -class ShieldCancellation: - # Thread-synchronous codebases don't support cancellation semantics. - # We have this class because we need to mirror the async and sync - # cases within our package, but it's just a no-op. - def __enter__(self) -> "ShieldCancellation": - return self - - def __exit__( - self, - exc_type: Optional[Type[BaseException]] = None, - exc_value: Optional[BaseException] = None, - traceback: Optional[TracebackType] = None, - ) -> None: - pass +# Thread-synchronous codebases don't support cancellation semantics. +# We have this class because we need to mirror the async and sync +# cases within our package, but it's just a no-op. +def sync_cancel_shield(fn: Callable[[], None]) -> None: + fn() diff --git a/pyproject.toml b/pyproject.toml index 85c78740..7a9fea16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ classifiers = [ dependencies = [ "certifi", "h11>=0.13,<0.15", + "async-timeout==4.*; python_version < '3.11'", ] [project.optional-dependencies] diff --git a/scripts/unasync.py b/scripts/unasync.py index 5a5627d7..b81bc638 100644 --- a/scripts/unasync.py +++ b/scripts/unasync.py @@ -24,6 +24,7 @@ ('@pytest.mark.anyio', ''), ('@pytest.mark.trio', ''), ('AutoBackend', 'SyncBackend'), + ('async_cancel_shield', 'sync_cancel_shield'), ] COMPILED_SUBS = [ (re.compile(r'(^|\b)' + regex + r'($|\b)'), repl)