diff --git a/CHANGES/3892.feature b/CHANGES/3892.feature new file mode 100644 index 00000000000..9707d253046 --- /dev/null +++ b/CHANGES/3892.feature @@ -0,0 +1 @@ +allow ``raise_for_status`` to be a coroutine diff --git a/aiohttp/client.py b/aiohttp/client.py index 6dcff5a291b..0765cd8d4cf 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -11,6 +11,8 @@ from types import SimpleNamespace, TracebackType from typing import ( # noqa Any, + Awaitable, + Callable, Coroutine, Generator, Generic, @@ -194,7 +196,7 @@ def __init__(self, *, connector: Optional[BaseConnector]=None, version: HttpVersion=http.HttpVersion11, cookie_jar: Optional[AbstractCookieJar]=None, connector_owner: bool=True, - raise_for_status: bool=False, + raise_for_status: Union[bool, Callable[[ClientResponse], Awaitable[None]]]=False, # noqa read_timeout: Union[float, object]=sentinel, conn_timeout: Optional[float]=None, timeout: Union[object, ClientTimeout]=sentinel, @@ -336,7 +338,7 @@ async def _request( compress: Optional[str]=None, chunked: Optional[bool]=None, expect100: bool=False, - raise_for_status: Optional[bool]=None, + raise_for_status: Union[None, bool, Callable[[ClientResponse], Awaitable[None]]]=None, # noqa read_until_eof: bool=True, proxy: Optional[StrOrURL]=None, proxy_auth: Optional[BasicAuth]=None, @@ -584,7 +586,12 @@ async def _request( # check response status if raise_for_status is None: raise_for_status = self._raise_for_status - if raise_for_status: + + if raise_for_status is None: + pass + elif callable(raise_for_status): + await raise_for_status(resp) + elif raise_for_status: resp.raise_for_status() # register connection diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 4378540d130..2a2d28aa35a 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -128,6 +128,22 @@ The client session supports the context manager protocol for self closing. requests where you need to handle responses with status 400 or higher. + You can also provide a coroutine which takes the response as an + argument and can raise an exception based on custom logic, e.g.:: + + async def custom_check(response): + if response.status not in {201, 202}: + raise RuntimeError('expected either 201 or 202') + text = await response.text() + if 'apple pie' not in text: + raise RuntimeError('I wanted to see "apple pie" in response') + + client_session = aiohttp.ClientSession(raise_for_status=custom_check) + ... + + As with boolean values, you're free to set this on the session and/or + overwrite it on a per-request basis. + :param timeout: a :class:`ClientTimeout` settings structure, 5min total timeout by default. diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index d860663a58d..a611809ca75 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -2333,6 +2333,56 @@ async def handler(request): assert False, "never executed" # pragma: no cover +async def test_session_raise_for_status_coro(aiohttp_client) -> None: + + async def handle(request): + return web.Response(text='ok') + + app = web.Application() + app.router.add_route('GET', '/', handle) + + raise_for_status_called = 0 + + async def custom_r4s(response): + nonlocal raise_for_status_called + raise_for_status_called += 1 + assert response.status == 200 + assert response.request_info.method == 'GET' + + client = await aiohttp_client(app, raise_for_status=custom_r4s) + await client.get('/') + assert raise_for_status_called == 1 + await client.get('/', raise_for_status=True) + assert raise_for_status_called == 1 # custom_r4s not called again + await client.get('/', raise_for_status=False) + assert raise_for_status_called == 1 # custom_r4s not called again + + +async def test_request_raise_for_status_coro(aiohttp_client) -> None: + + async def handle(request): + return web.Response(text='ok') + + app = web.Application() + app.router.add_route('GET', '/', handle) + + raise_for_status_called = 0 + + async def custom_r4s(response): + nonlocal raise_for_status_called + raise_for_status_called += 1 + assert response.status == 200 + assert response.request_info.method == 'GET' + + client = await aiohttp_client(app) + await client.get('/', raise_for_status=custom_r4s) + assert raise_for_status_called == 1 + await client.get('/', raise_for_status=True) + assert raise_for_status_called == 1 # custom_r4s not called again + await client.get('/', raise_for_status=False) + assert raise_for_status_called == 1 # custom_r4s not called again + + async def test_invalid_idna() -> None: session = aiohttp.ClientSession() try: