Skip to content

Commit

Permalink
[PR #9340/8a97e03 backport][3.10] Use dunder writer internally in Cli…
Browse files Browse the repository at this point in the history
…entResponse (#9341)
  • Loading branch information
bdraco authored Sep 29, 2024
1 parent 873fad9 commit 523c4ea
Showing 1 changed file with 25 additions and 21 deletions.
46 changes: 25 additions & 21 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ async def send(self, conn: "Connection") -> "ClientResponse":
self.response = response_class(
self.method,
self.original_url,
writer=self._writer,
writer=task,
continue100=self._continue,
timer=self._timer,
request_info=self.request_info,
Expand All @@ -773,9 +773,9 @@ async def send(self, conn: "Connection") -> "ClientResponse":
return self.response

async def close(self) -> None:
if self._writer is not None:
if self.__writer is not None:
try:
await self._writer
await self.__writer
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
Expand All @@ -785,11 +785,11 @@ async def close(self) -> None:
raise

def terminate(self) -> None:
if self._writer is not None:
if self.__writer is not None:
if not self.loop.is_closed():
self._writer.cancel()
self._writer.remove_done_callback(self.__reset_writer)
self._writer = None
self.__writer.cancel()
self.__writer.remove_done_callback(self.__reset_writer)
self.__writer = None

async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None:
for trace in self._traces:
Expand Down Expand Up @@ -845,8 +845,8 @@ def __init__(

self._real_url = url
self._url = url.with_fragment(None)
self._body: Any = None
self._writer: Optional[asyncio.Task[None]] = writer
self._body: Optional[bytes] = None
self._writer = writer
self._continue = continue100 # None by default
self._closed = True
self._history: Tuple[ClientResponse, ...] = ()
Expand Down Expand Up @@ -874,10 +874,16 @@ def __reset_writer(self, _: object = None) -> None:

@property
def _writer(self) -> Optional["asyncio.Task[None]"]:
"""The writer task for streaming data.
_writer is only provided for backwards compatibility
for subclasses that may need to access it.
"""
return self.__writer

@_writer.setter
def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None:
"""Set the writer task for streaming data."""
if self.__writer is not None:
self.__writer.remove_done_callback(self.__reset_writer)
self.__writer = writer
Expand Down Expand Up @@ -1128,16 +1134,16 @@ def raise_for_status(self) -> None:

def _release_connection(self) -> None:
if self._connection is not None:
if self._writer is None:
if self.__writer is None:
self._connection.release()
self._connection = None
else:
self._writer.add_done_callback(lambda f: self._release_connection())
self.__writer.add_done_callback(lambda f: self._release_connection())

async def _wait_released(self) -> None:
if self._writer is not None:
if self.__writer is not None:
try:
await self._writer
await self.__writer
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
Expand All @@ -1148,8 +1154,8 @@ async def _wait_released(self) -> None:
self._release_connection()

def _cleanup_writer(self) -> None:
if self._writer is not None:
self._writer.cancel()
if self.__writer is not None:
self.__writer.cancel()
self._session = None

def _notify_content(self) -> None:
Expand All @@ -1159,9 +1165,9 @@ def _notify_content(self) -> None:
self._released = True

async def wait_for_close(self) -> None:
if self._writer is not None:
if self.__writer is not None:
try:
await self._writer
await self.__writer
except asyncio.CancelledError:
if (
sys.version_info >= (3, 11)
Expand Down Expand Up @@ -1189,7 +1195,7 @@ async def read(self) -> bytes:
protocol = self._connection and self._connection.protocol
if protocol is None or not protocol.upgraded:
await self._wait_released() # Underlying connection released
return self._body # type: ignore[no-any-return]
return self._body

def get_encoding(self) -> str:
ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower()
Expand Down Expand Up @@ -1222,9 +1228,7 @@ async def text(self, encoding: Optional[str] = None, errors: str = "strict") ->
if encoding is None:
encoding = self.get_encoding()

return self._body.decode( # type: ignore[no-any-return,union-attr]
encoding, errors=errors
)
return self._body.decode(encoding, errors=errors) # type: ignore[union-attr]

async def json(
self,
Expand Down

0 comments on commit 523c4ea

Please sign in to comment.