Skip to content

Commit

Permalink
Fix ConnectionResetError not being raised when the transport is close… (
Browse files Browse the repository at this point in the history
#7199)

Co-authored-by: J. Nick Koston <nick@koston.org>
  • Loading branch information
Dreamsorcerer and bdraco authored Feb 11, 2023
1 parent 565cc21 commit 28854a4
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGES/7180.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
``ConnectionResetError`` will always be raised when ``StreamWriter.write`` is called after ``connection_lost`` has been called on the ``BaseProtocol``
9 changes: 6 additions & 3 deletions aiohttp/base_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop: asyncio.AbstractEventLoop = loop
self._paused = False
self._drain_waiter: Optional[asyncio.Future[None]] = None
self._connection_lost = False
self._reading_paused = False

self.transport: Optional[asyncio.Transport] = None

@property
def connected(self) -> bool:
"""Return True if the connection is open."""
return self.transport is not None

def pause_writing(self) -> None:
assert not self._paused
self._paused = True
Expand Down Expand Up @@ -59,7 +63,6 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None:
self.transport = tr

def connection_lost(self, exc: Optional[BaseException]) -> None:
self._connection_lost = True
# Wake up the writer if currently paused.
self.transport = None
if not self._paused:
Expand All @@ -76,7 +79,7 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
waiter.set_exception(exc)

async def _drain_helper(self) -> None:
if self._connection_lost:
if not self.connected:
raise ConnectionResetError("Connection lost")
if not self._paused:
return
Expand Down
10 changes: 4 additions & 6 deletions aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def __init__(
on_headers_sent: _T_OnHeadersSent = None,
) -> None:
self._protocol = protocol
self._transport = protocol.transport

self.loop = loop
self.length = None
Expand All @@ -52,7 +51,7 @@ def __init__(

@property
def transport(self) -> Optional[asyncio.Transport]:
return self._transport
return self._protocol.transport

@property
def protocol(self) -> BaseProtocol:
Expand All @@ -71,10 +70,10 @@ def _write(self, chunk: bytes) -> None:
size = len(chunk)
self.buffer_size += size
self.output_size += size

if self._transport is None or self._transport.is_closing():
transport = self.transport
if not self._protocol.connected or transport is None or transport.is_closing():
raise ConnectionResetError("Cannot write to closing transport")
self._transport.write(chunk)
transport.write(chunk)

async def write(
self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000
Expand Down Expand Up @@ -159,7 +158,6 @@ async def write_eof(self, chunk: bytes = b"") -> None:
await self.drain()

self._eof = True
self._transport = None

async def drain(self) -> None:
"""Flush the write buffer.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_base_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,22 @@ async def test_connection_lost_not_paused() -> None:
pr = BaseProtocol(loop=loop)
tr = mock.Mock()
pr.connection_made(tr)
assert not pr._connection_lost
assert pr.connected
pr.connection_lost(None)
assert pr.transport is None
assert pr._connection_lost
assert not pr.connected


async def test_connection_lost_paused_without_waiter() -> None:
loop = asyncio.get_event_loop()
pr = BaseProtocol(loop=loop)
tr = mock.Mock()
pr.connection_made(tr)
assert not pr._connection_lost
assert pr.connected
pr.pause_writing()
pr.connection_lost(None)
assert pr.transport is None
assert pr._connection_lost
assert not pr.connected


async def test_drain_lost() -> None:
Expand Down
14 changes: 14 additions & 0 deletions tests/test_client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,17 @@ async def test_eof_received(loop) -> None:
assert proto._read_timeout_handle is not None
proto.eof_received()
assert proto._read_timeout_handle is None


async def test_connection_lost_sets_transport_to_none(loop, mocker) -> None:
"""Ensure that the transport is set to None when the connection is lost.
This ensures the writer knows that the connection is closed.
"""
proto = ResponseHandler(loop=loop)
proto.connection_made(mocker.Mock())
assert proto.transport is not None

proto.connection_lost(OSError())

assert proto.transport is None
15 changes: 15 additions & 0 deletions tests/test_http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,21 @@ async def test_write_to_closing_transport(protocol, transport, loop) -> None:
await msg.write(b"After closing")


async def test_write_to_closed_transport(protocol, transport, loop) -> None:
"""Test that writing to a closed transport raises ConnectionResetError.
The StreamWriter checks to see if protocol.transport is None before
writing to the transport. If it is None, it raises ConnectionResetError.
"""
msg = http.StreamWriter(protocol, loop)

await msg.write(b"Before transport close")
protocol.transport = None

with pytest.raises(ConnectionResetError, match="Cannot write to closing transport"):
await msg.write(b"After transport closed")


async def test_drain(protocol, transport, loop) -> None:
msg = http.StreamWriter(protocol, loop)
await msg.drain()
Expand Down

0 comments on commit 28854a4

Please sign in to comment.