Skip to content

Commit

Permalink
Ensure writer is always reset on completion (#7815)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamsorcerer authored Nov 12, 2023
1 parent 366ba40 commit 8f2f048
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 34 deletions.
1 change: 1 addition & 0 deletions CHANGES/7815.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed an issue where the client could go into an infinite loop. -- by :user:`Dreamsorcerer`
72 changes: 48 additions & 24 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@
reify,
set_result,
)
from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11, StreamWriter
from .http import (
SERVER_SOFTWARE,
HttpVersion,
HttpVersion10,
HttpVersion11,
StreamWriter,
)
from .log import client_logger
from .streams import StreamReader
from .typedefs import (
Expand Down Expand Up @@ -178,7 +184,7 @@ class ClientRequest:
auth = None
response = None

_writer = None # async task for streaming data
__writer = None # async task for streaming data
_continue = None # waiter future for '100 Continue' response

# N.B.
Expand Down Expand Up @@ -265,6 +271,21 @@ def __init__(
traces = []
self._traces = traces

def __reset_writer(self, _: object = None) -> None:
self.__writer = None

@property
def _writer(self) -> Optional["asyncio.Task[None]"]:
return self.__writer

@_writer.setter
def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
if self.__writer is not None:
self.__writer.remove_done_callback(self.__reset_writer)
self.__writer = writer
if writer is not None:
writer.add_done_callback(self.__reset_writer)

def is_ssl(self) -> bool:
return self.url.scheme in ("https", "wss")

Expand Down Expand Up @@ -563,8 +584,6 @@ async def write_bytes(
else:
await writer.write_eof()
protocol.start_timeout()
finally:
self._writer = None

async def send(self, conn: "Connection") -> "ClientResponse":
# Specify request target:
Expand Down Expand Up @@ -649,16 +668,14 @@ async def send(self, conn: "Connection") -> "ClientResponse":

async def close(self) -> None:
if self._writer is not None:
try:
with contextlib.suppress(asyncio.CancelledError):
await self._writer
finally:
self._writer = None
with contextlib.suppress(asyncio.CancelledError):
await self._writer

def terminate(self) -> None:
if self._writer is not None:
if not self.loop.is_closed():
self._writer.cancel()
self._writer.remove_done_callback(self.__reset_writer)
self._writer = None

async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None:
Expand All @@ -677,9 +694,9 @@ class ClientResponse(HeadersMixin):
# but will be set by the start() method.
# As the end user will likely never see the None values, we cheat the types below.
# from the Status-Line of the response
version = None # HTTP-Version
version: Optional[HttpVersion] = None # HTTP-Version
status: int = None # type: ignore[assignment] # Status-Code
reason = None # Reason-Phrase
reason: Optional[str] = None # Reason-Phrase

content: StreamReader = None # type: ignore[assignment] # Payload stream
_headers: CIMultiDictProxy[str] = None # type: ignore[assignment]
Expand All @@ -691,6 +708,7 @@ class ClientResponse(HeadersMixin):
# post-init stage allows to not change ctor signature
_closed = True # to allow __del__ for non-initialized properly response
_released = False
__writer = None

def __init__(
self,
Expand Down Expand Up @@ -737,6 +755,21 @@ def __init__(
if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))

def __reset_writer(self, _: object = None) -> None:
self.__writer = None

@property
def _writer(self) -> Optional["asyncio.Task[None]"]:
return self.__writer

@_writer.setter
def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
if self.__writer is not None:
self.__writer.remove_done_callback(self.__reset_writer)
self.__writer = writer
if writer is not None:
writer.add_done_callback(self.__reset_writer)

@reify
def url(self) -> URL:
return self._url
Expand Down Expand Up @@ -797,7 +830,7 @@ def __repr__(self) -> str:
"ascii", "backslashreplace"
).decode("ascii")
else:
ascii_encodable_reason = self.reason
ascii_encodable_reason = "None"
print(
"<ClientResponse({}) [{} {}]>".format(
ascii_encodable_url, self.status, ascii_encodable_reason
Expand Down Expand Up @@ -978,18 +1011,12 @@ def _release_connection(self) -> None:

async def _wait_released(self) -> None:
if self._writer is not None:
try:
await self._writer
finally:
self._writer = None
await self._writer
self._release_connection()

def _cleanup_writer(self) -> None:
if self._writer is not None:
if self._writer.done():
self._writer = None
else:
self._writer.cancel()
self._writer.cancel()
self._session = None

def _notify_content(self) -> None:
Expand All @@ -1001,10 +1028,7 @@ def _notify_content(self) -> None:

async def wait_for_close(self) -> None:
if self._writer is not None:
try:
await self._writer
finally:
self._writer = None
await self._writer
self.release()

async def read(self) -> bytes:
Expand Down
5 changes: 4 additions & 1 deletion tests/test_client_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import gc
import sys
from json import JSONDecodeError
from typing import Any
from typing import Any, Callable
from unittest import mock

import pytest
Expand All @@ -22,6 +22,9 @@ class WriterMock(mock.AsyncMock):
def __await__(self) -> None:
return self().__await__()

def add_done_callback(self, cb: Callable[[], None]) -> None:
cb()

def done(self) -> bool:
return True

Expand Down
18 changes: 9 additions & 9 deletions tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def test_proxy_server_hostname_default(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -261,7 +261,7 @@ def test_proxy_server_hostname_override(self, ClientRequestMock) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -323,7 +323,7 @@ def test_https_connect(self, ClientRequestMock: Any) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -383,7 +383,7 @@ def test_https_connect_certificate_error(self, ClientRequestMock: Any) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -437,7 +437,7 @@ def test_https_connect_ssl_error(self, ClientRequestMock: Any) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -493,7 +493,7 @@ def test_https_connect_http_proxy_error(self, ClientRequestMock: Any) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -552,7 +552,7 @@ def test_https_connect_resp_start_error(self, ClientRequestMock: Any) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -663,7 +663,7 @@ def test_https_connect_pass_ssl_context(self, ClientRequestMock: Any) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down Expand Up @@ -734,7 +734,7 @@ def test_https_auth(self, ClientRequestMock: Any) -> None:
"get",
URL("http://proxy.example.com"),
request_info=mock.Mock(),
writer=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
Expand Down

0 comments on commit 8f2f048

Please sign in to comment.