diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 018b0ba2..5cbd3d22 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -7,7 +7,12 @@ from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol from .._models import Origin, Request, Response -from .._synchronization import AsyncEvent, AsyncShieldCancellation, AsyncThreadLock +from .._synchronization import ( + EXCEPTION_OR_CANCELLED, + AsyncEvent, + AsyncShieldCancellation, + AsyncThreadLock, +) from .connection import AsyncHTTPConnection from .interfaces import AsyncConnectionInterface, AsyncRequestInterface @@ -205,7 +210,7 @@ async def handle_async_request(self, request: Request) -> Response: else: break # pragma: nocover - except BaseException as exc: + except EXCEPTION_OR_CANCELLED as exc: with self._optional_thread_lock: # For any exception or cancellation we remove the request from # the queue, and then re-assign requests to connections. diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index 0493a923..cfffd27e 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -25,7 +25,11 @@ map_exceptions, ) from .._models import Origin, Request, Response -from .._synchronization import AsyncLock, AsyncShieldCancellation +from .._synchronization import ( + EXCEPTION_OR_CANCELLED, + AsyncLock, + AsyncShieldCancellation, +) from .._trace import Trace from .interfaces import AsyncConnectionInterface @@ -136,7 +140,7 @@ async def handle_async_request(self, request: Request) -> Response: "network_stream": network_stream, }, ) - except BaseException as exc: + except EXCEPTION_OR_CANCELLED as exc: with AsyncShieldCancellation(): async with Trace("response_closed", logger, request) as trace: await self._response_closed() @@ -340,7 +344,7 @@ async def __aiter__(self) -> AsyncIterator[bytes]: async with Trace("receive_response_body", logger, self._request, kwargs): async for chunk in self._connection._receive_response_body(**kwargs): yield chunk - except BaseException as exc: + except EXCEPTION_OR_CANCELLED as exc: # If we get an exception while streaming the response, # we want to close the response (and possibly the connection) # before raising that exception. diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index c201ee4c..85d2997a 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -17,7 +17,12 @@ RemoteProtocolError, ) from .._models import Origin, Request, Response -from .._synchronization import AsyncLock, AsyncSemaphore, AsyncShieldCancellation +from .._synchronization import ( + EXCEPTION_OR_CANCELLED, + AsyncLock, + AsyncSemaphore, + AsyncShieldCancellation, +) from .._trace import Trace from .interfaces import AsyncConnectionInterface @@ -107,7 +112,7 @@ async def handle_async_request(self, request: Request) -> Response: kwargs = {"request": request} async with Trace("send_connection_init", logger, request, kwargs): await self._send_connection_init(**kwargs) - except BaseException as exc: + except EXCEPTION_OR_CANCELLED as exc: with AsyncShieldCancellation(): await self.aclose() raise exc @@ -160,7 +165,7 @@ async def handle_async_request(self, request: Request) -> Response: "stream_id": stream_id, }, ) - except BaseException as exc: # noqa: PIE786 + except EXCEPTION_OR_CANCELLED as exc: # noqa: PIE786 with AsyncShieldCancellation(): kwargs = {"stream_id": stream_id} async with Trace("response_closed", logger, request, kwargs): @@ -573,7 +578,7 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]: request=self._request, stream_id=self._stream_id ): yield chunk - except BaseException as exc: + except EXCEPTION_OR_CANCELLED as exc: # If we get an exception while streaming the response, # we want to close the response (and possibly the connection) # before raising that exception. diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 8dcf348c..4cff69aa 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -7,7 +7,12 @@ from .._backends.base import SOCKET_OPTION, NetworkBackend from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol from .._models import Origin, Request, Response -from .._synchronization import Event, ShieldCancellation, ThreadLock +from .._synchronization import ( + EXCEPTION_OR_CANCELLED, + Event, + ShieldCancellation, + ThreadLock, +) from .connection import HTTPConnection from .interfaces import ConnectionInterface, RequestInterface @@ -205,7 +210,7 @@ def handle_request(self, request: Request) -> Response: else: break # pragma: nocover - except BaseException as exc: + except EXCEPTION_OR_CANCELLED as exc: with self._optional_thread_lock: # For any exception or cancellation we remove the request from # the queue, and then re-assign requests to connections. diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py index a74ff8e8..0edfa5e9 100644 --- a/httpcore/_sync/http11.py +++ b/httpcore/_sync/http11.py @@ -25,7 +25,11 @@ map_exceptions, ) from .._models import Origin, Request, Response -from .._synchronization import Lock, ShieldCancellation +from .._synchronization import ( + EXCEPTION_OR_CANCELLED, + Lock, + ShieldCancellation, +) from .._trace import Trace from .interfaces import ConnectionInterface @@ -136,7 +140,7 @@ def handle_request(self, request: Request) -> Response: "network_stream": network_stream, }, ) - except BaseException as exc: + except EXCEPTION_OR_CANCELLED as exc: with ShieldCancellation(): with Trace("response_closed", logger, request) as trace: self._response_closed() @@ -340,7 +344,7 @@ def __iter__(self) -> Iterator[bytes]: with Trace("receive_response_body", logger, self._request, kwargs): for chunk in self._connection._receive_response_body(**kwargs): yield chunk - except BaseException as exc: + except EXCEPTION_OR_CANCELLED as exc: # If we get an exception while streaming the response, # we want to close the response (and possibly the connection) # before raising that exception. diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index 1ee4bbb3..84a12865 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -17,7 +17,12 @@ RemoteProtocolError, ) from .._models import Origin, Request, Response -from .._synchronization import Lock, Semaphore, ShieldCancellation +from .._synchronization import ( + EXCEPTION_OR_CANCELLED, + Lock, + Semaphore, + ShieldCancellation, +) from .._trace import Trace from .interfaces import ConnectionInterface @@ -107,7 +112,7 @@ def handle_request(self, request: Request) -> Response: kwargs = {"request": request} with Trace("send_connection_init", logger, request, kwargs): self._send_connection_init(**kwargs) - except BaseException as exc: + except EXCEPTION_OR_CANCELLED as exc: with ShieldCancellation(): self.close() raise exc @@ -160,7 +165,7 @@ def handle_request(self, request: Request) -> Response: "stream_id": stream_id, }, ) - except BaseException as exc: # noqa: PIE786 + except EXCEPTION_OR_CANCELLED as exc: # noqa: PIE786 with ShieldCancellation(): kwargs = {"stream_id": stream_id} with Trace("response_closed", logger, request, kwargs): @@ -573,7 +578,7 @@ def __iter__(self) -> typing.Iterator[bytes]: request=self._request, stream_id=self._stream_id ): yield chunk - except BaseException as exc: + except EXCEPTION_OR_CANCELLED as exc: # If we get an exception while streaming the response, # we want to close the response (and possibly the connection) # before raising that exception. diff --git a/httpcore/_synchronization.py b/httpcore/_synchronization.py index 9619a398..375a9d8e 100644 --- a/httpcore/_synchronization.py +++ b/httpcore/_synchronization.py @@ -1,19 +1,30 @@ import threading from types import TracebackType -from typing import Optional, Type +from typing import Optional, Tuple, Type from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions +EXCEPTION_OR_CANCELLED: Tuple[Type[BaseException], ...] = (Exception,) + # Our async synchronization primatives use either 'anyio' or 'trio' depending # on if they're running under asyncio or trio. try: import trio + + EXCEPTION_OR_CANCELLED += (trio.Cancelled,) except ImportError: # pragma: nocover trio = None # type: ignore try: import anyio + + try: + import asyncio + + EXCEPTION_OR_CANCELLED += (asyncio.CancelledError,) + except ImportError: # pragma: nocover + pass except ImportError: # pragma: nocover anyio = None # type: ignore