Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cancellations being swallowed (#9030) #9257

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/9030.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed (on Python 3.11+) some edge cases where a task cancellation may get incorrectly suppressed -- by :user:`Dreamsorcerer`.
37 changes: 29 additions & 8 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,11 +625,8 @@
"""Support coroutines that yields bytes objects."""
# 100 response
if self._continue is not None:
try:
await writer.drain()
await self._continue
except asyncio.CancelledError:
return
await writer.drain()
await self._continue

protocol = conn.protocol
assert protocol is not None
Expand Down Expand Up @@ -658,6 +655,7 @@
except asyncio.CancelledError:
# Body hasn't been fully sent, so connection can't be reused.
conn.close()
raise
except Exception as underlying_exc:
set_exception(
protocol,
Expand Down Expand Up @@ -764,8 +762,15 @@

async def close(self) -> None:
if self._writer is not None:
with contextlib.suppress(asyncio.CancelledError):
try:
await self._writer
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise

def terminate(self) -> None:
if self._writer is not None:
Expand Down Expand Up @@ -1113,7 +1118,15 @@

async def _wait_released(self) -> None:
if self._writer is not None:
await self._writer
try:
await self._writer
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise

Check warning on line 1129 in aiohttp/client_reqrep.py

View check run for this annotation

Codecov / codecov/patch

aiohttp/client_reqrep.py#L1129

Added line #L1129 was not covered by tests
self._release_connection()

def _cleanup_writer(self) -> None:
Expand All @@ -1129,7 +1142,15 @@

async def wait_for_close(self) -> None:
if self._writer is not None:
await self._writer
try:
await self._writer
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise

Check warning on line 1153 in aiohttp/client_reqrep.py

View check run for this annotation

Codecov / codecov/patch

aiohttp/client_reqrep.py#L1153

Added line #L1153 was not covered by tests
self.release()

async def read(self) -> bytes:
Expand Down
38 changes: 29 additions & 9 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,17 +271,32 @@ async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
# down while the handler is still processing a request
# to avoid creating a future for every request.
self._handler_waiter = self._loop.create_future()
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
try:
async with ceil_timeout(timeout):
await self._handler_waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
self._handler_waiter = None
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise
# Then cancel handler and wait
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
try:
async with ceil_timeout(timeout):
if self._current_request is not None:
self._current_request._cancel(asyncio.CancelledError())

if self._task_handler is not None and not self._task_handler.done():
await self._task_handler
await asyncio.shield(self._task_handler)
except (asyncio.CancelledError, asyncio.TimeoutError):
if (
sys.version_info >= (3, 11)
and (task := asyncio.current_task())
and task.cancelling()
):
raise

# force-close non-idle handler
if self._task_handler is not None:
Expand Down Expand Up @@ -517,8 +532,6 @@ async def start(self) -> None:
# wait for next request
self._waiter = loop.create_future()
await self._waiter
except asyncio.CancelledError:
break
finally:
self._waiter = None

Expand All @@ -545,7 +558,7 @@ async def start(self) -> None:
task = loop.create_task(coro)
try:
resp, reset = await task
except (asyncio.CancelledError, ConnectionError):
except ConnectionError:
self.log_debug("Ignored premature client disconnection")
break

Expand All @@ -569,12 +582,19 @@ async def start(self) -> None:
now = loop.time()
end_t = now + lingering_time

with suppress(asyncio.TimeoutError, asyncio.CancelledError):
try:
while not payload.is_eof() and now < end_t:
async with ceil_timeout(end_t - now):
# read and ignore
await payload.readany()
now = loop.time()
except (asyncio.CancelledError, asyncio.TimeoutError):
if (
sys.version_info >= (3, 11)
and (t := asyncio.current_task())
and t.cancelling()
):
raise

# if payload still uncompleted
if not payload.is_eof() and not self._force_close:
Expand All @@ -584,8 +604,8 @@ async def start(self) -> None:
payload.set_exception(_PAYLOAD_ACCESS_ERROR)

except asyncio.CancelledError:
self.log_debug("Ignored premature client disconnection ")
break
self.log_debug("Ignored premature client disconnection")
raise
except Exception as exc:
self.log_exception("Unhandled exception", exc_info=exc)
self.force_close()
Expand Down
19 changes: 18 additions & 1 deletion tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import hashlib
import io
import pathlib
import sys
import urllib.parse
import zlib
from http.cookies import BaseCookie, Morsel, SimpleCookie
Expand Down Expand Up @@ -1213,7 +1214,23 @@ async def test_oserror_on_write_bytes(loop, conn) -> None:
await req.close()


async def test_terminate(loop, conn) -> None:
@pytest.mark.skipif(sys.version_info < (3, 11), reason="Needs Task.cancelling()")
async def test_cancel_close(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None:
req = ClientRequest("get", URL("http://python.org"), loop=loop)
req._writer = asyncio.Future() # type: ignore[assignment]

t = asyncio.create_task(req.close())

# Start waiting on _writer
await asyncio.sleep(0)

t.cancel()
# Cancellation should not be suppressed.
with pytest.raises(asyncio.CancelledError):
await t


async def test_terminate(loop: asyncio.AbstractEventLoop, conn: mock.Mock) -> None:
req = ClientRequest("get", URL("http://python.org"), loop=loop)

async def _mock_write_bytes(*args, **kwargs):
Expand Down
40 changes: 38 additions & 2 deletions tests/test_web_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import pathlib
import socket
import sys
import zlib
from typing import Any, NoReturn, Optional
from unittest import mock
Expand All @@ -22,6 +23,7 @@
web,
)
from aiohttp.hdrs import CONTENT_LENGTH, CONTENT_TYPE, TRANSFER_ENCODING
from aiohttp.pytest_plugin import AiohttpClient
from aiohttp.test_utils import make_mocked_coro
from aiohttp.typedefs import Handler
from aiohttp.web_protocol import RequestHandler
Expand Down Expand Up @@ -187,8 +189,42 @@ async def handler(request):
await resp.release()


async def test_post_form(aiohttp_client) -> None:
async def handler(request):
@pytest.mark.skipif(sys.version_info < (3, 11), reason="Needs Task.cancelling()")
async def test_cancel_shutdown(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.Response:
t = asyncio.create_task(request.protocol.shutdown())
# Ensure it's started waiting
await asyncio.sleep(0)

t.cancel()
# Cancellation should not be suppressed
with pytest.raises(asyncio.CancelledError):
await t

# Repeat for second waiter in shutdown()
with mock.patch.object(request.protocol, "_request_in_progress", False):
with mock.patch.object(request.protocol, "_current_request", None):
t = asyncio.create_task(request.protocol.shutdown())
await asyncio.sleep(0)

t.cancel()
with pytest.raises(asyncio.CancelledError):
await t

return web.Response(body=b"OK")

app = web.Application()
app.router.add_get("/", handler)
client = await aiohttp_client(app)

async with client.get("/") as resp:
assert resp.status == 200
txt = await resp.text()
assert txt == "OK"


async def test_post_form(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.Response:
data = await request.post()
assert {"a": "1", "b": "2", "c": ""} == data
return web.Response(body=b"OK")
Expand Down
Loading