Skip to content

Commit

Permalink
Add and use ClientConnectionResetError (#9137)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamsorcerer authored Sep 18, 2024
1 parent b93ef57 commit f95bcaf
Show file tree
Hide file tree
Showing 11 changed files with 48 additions and 18 deletions.
2 changes: 2 additions & 0 deletions CHANGES/9137.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Added :exc:`aiohttp.ClientConnectionResetError`. Client code that previously threw :exc:`ConnectionResetError`
will now throw this -- by :user:`Dreamsorcerer`.
2 changes: 2 additions & 0 deletions aiohttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .client import (
BaseConnector,
ClientConnectionError,
ClientConnectionResetError,
ClientConnectorCertificateError,
ClientConnectorError,
ClientConnectorSSLError,
Expand Down Expand Up @@ -117,6 +118,7 @@
# client
"BaseConnector",
"ClientConnectionError",
"ClientConnectionResetError",
"ClientConnectorCertificateError",
"ClientConnectorError",
"ClientConnectorSSLError",
Expand Down
3 changes: 2 additions & 1 deletion aiohttp/base_protocol.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from typing import Optional, cast

from .client_exceptions import ClientConnectionResetError
from .helpers import set_exception
from .tcp_helpers import tcp_nodelay

Expand Down Expand Up @@ -85,7 +86,7 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:

async def _drain_helper(self) -> None:
if not self.connected:
raise ConnectionResetError("Connection lost")
raise ClientConnectionResetError("Connection lost")
if not self._paused:
return
waiter = self._drain_waiter
Expand Down
2 changes: 2 additions & 0 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .abc import AbstractCookieJar
from .client_exceptions import (
ClientConnectionError,
ClientConnectionResetError,
ClientConnectorCertificateError,
ClientConnectorError,
ClientConnectorSSLError,
Expand Down Expand Up @@ -107,6 +108,7 @@
__all__ = (
# client_exceptions
"ClientConnectionError",
"ClientConnectionResetError",
"ClientConnectorCertificateError",
"ClientConnectorError",
"ClientConnectorSSLError",
Expand Down
9 changes: 7 additions & 2 deletions aiohttp/client_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from multidict import MultiMapping

from .http_parser import RawResponseMessage
from .typedefs import StrOrURL

try:
Expand All @@ -18,12 +17,14 @@

if TYPE_CHECKING:
from .client_reqrep import ClientResponse, ConnectionKey, Fingerprint, RequestInfo
from .http_parser import RawResponseMessage
else:
RequestInfo = ClientResponse = ConnectionKey = None
RequestInfo = ClientResponse = ConnectionKey = RawResponseMessage = None

__all__ = (
"ClientError",
"ClientConnectionError",
"ClientConnectionResetError",
"ClientOSError",
"ClientConnectorError",
"ClientProxyConnectionError",
Expand Down Expand Up @@ -126,6 +127,10 @@ class ClientConnectionError(ClientError):
"""Base class for client socket errors."""


class ClientConnectionResetError(ClientConnectionError, ConnectionResetError):
"""ConnectionResetError"""


class ClientOSError(ClientConnectionError, OSError):
"""OSError error."""

Expand Down
5 changes: 3 additions & 2 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)

from .base_protocol import BaseProtocol
from .client_exceptions import ClientConnectionResetError
from .compression_utils import ZLibCompressor, ZLibDecompressor
from .helpers import NO_EXTENSIONS, set_exception
from .streams import DataQueue
Expand Down Expand Up @@ -609,7 +610,7 @@ async def _send_frame(
) -> None:
"""Send a frame over the websocket with message as its payload."""
if self._closing and not (opcode & WSMsgType.CLOSE):
raise ConnectionResetError("Cannot write to closing transport")
raise ClientConnectionResetError("Cannot write to closing transport")

# RSV are the reserved bits in the frame header. They are used to
# indicate that the frame is using an extension.
Expand Down Expand Up @@ -704,7 +705,7 @@ def _make_compress_obj(self, compress: int) -> ZLibCompressor:

def _write(self, data: bytes) -> None:
if self.transport.is_closing():
raise ConnectionResetError("Cannot write to closing transport")
raise ClientConnectionResetError("Cannot write to closing transport")
self.transport.write(data)

async def pong(self, message: Union[bytes, str] = b"") -> None:
Expand Down
3 changes: 2 additions & 1 deletion aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .abc import AbstractStreamWriter
from .base_protocol import BaseProtocol
from .client_exceptions import ClientConnectionResetError
from .compression_utils import ZLibCompressor
from .helpers import NO_EXTENSIONS

Expand Down Expand Up @@ -72,7 +73,7 @@ def _write(self, chunk: bytes) -> None:
self.output_size += size
transport = self.transport
if not self._protocol.connected or transport is None or transport.is_closing():
raise ConnectionResetError("Cannot write to closing transport")
raise ClientConnectionResetError("Cannot write to closing transport")
transport.write(chunk)

async def write(
Expand Down
6 changes: 6 additions & 0 deletions docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2193,6 +2193,10 @@ Connection errors

Derived from :exc:`ClientError`

.. class:: ClientConnectionResetError

Derived from :exc:`ClientConnectionError` and :exc:`ConnectionResetError`

.. class:: ClientOSError

Subset of connection errors that are initiated by an :exc:`OSError`
Expand Down Expand Up @@ -2279,6 +2283,8 @@ Hierarchy of exceptions

* :exc:`ClientConnectionError`

* :exc:`ClientConnectionResetError`

* :exc:`ClientOSError`

* :exc:`ClientConnectorError`
Expand Down
20 changes: 14 additions & 6 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@
import base64
import hashlib
import os
from typing import Mapping
from typing import Mapping, Type
from unittest import mock

import pytest

import aiohttp
from aiohttp import client, hdrs
from aiohttp.client_exceptions import ServerDisconnectedError
from aiohttp.client_ws import ClientWSTimeout
from aiohttp import (
ClientConnectionResetError,
ClientWSTimeout,
ServerDisconnectedError,
client,
hdrs,
)
from aiohttp.http import WS_KEY
from aiohttp.streams import EofStream
from aiohttp.test_utils import make_mocked_coro
Expand Down Expand Up @@ -535,8 +539,12 @@ async def test_close_exc2(
await resp.close()


@pytest.mark.parametrize("exc", (ClientConnectionResetError, ConnectionResetError))
async def test_send_data_after_close(
ws_key: bytes, key_data: bytes, loop: asyncio.AbstractEventLoop
exc: Type[Exception],
ws_key: bytes,
key_data: bytes,
loop: asyncio.AbstractEventLoop,
) -> None:
mresp = mock.Mock()
mresp.status = 101
Expand All @@ -562,7 +570,7 @@ async def test_send_data_after_close(
(resp.send_bytes, (b"b",)),
(resp.send_json, ({},)),
):
with pytest.raises(ConnectionResetError):
with pytest.raises(exc): # Verify exc can be caught with both classes
await meth(*args)


Expand Down
4 changes: 2 additions & 2 deletions tests/test_client_ws_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

import aiohttp
from aiohttp import ServerTimeoutError, WSMsgType, hdrs, web
from aiohttp import ClientConnectionResetError, ServerTimeoutError, WSMsgType, hdrs, web
from aiohttp.client_ws import ClientWSTimeout
from aiohttp.http import WSCloseCode
from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer
Expand Down Expand Up @@ -681,7 +681,7 @@ async def handler(request: web.Request) -> NoReturn:
# would cancel the heartbeat task and we wouldn't get a ping
assert resp._conn is not None
with mock.patch.object(
resp._conn.transport, "write", side_effect=ConnectionResetError
resp._conn.transport, "write", side_effect=ClientConnectionResetError
), mock.patch.object(resp._writer, "ping", wraps=resp._writer.ping) as ping:
await resp.receive()
ping_count = ping.call_count
Expand Down
10 changes: 6 additions & 4 deletions tests/test_http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
from multidict import CIMultiDict

from aiohttp import http
from aiohttp import ClientConnectionResetError, http
from aiohttp.base_protocol import BaseProtocol
from aiohttp.test_utils import make_mocked_coro

Expand Down Expand Up @@ -301,7 +301,7 @@ async def test_write_to_closing_transport(
await msg.write(b"Before closing")
transport.is_closing.return_value = True # type: ignore[attr-defined]

with pytest.raises(ConnectionResetError):
with pytest.raises(ClientConnectionResetError):
await msg.write(b"After closing")


Expand All @@ -310,7 +310,7 @@ async def test_write_to_closed_transport(
transport: asyncio.Transport,
loop: asyncio.AbstractEventLoop,
) -> None:
"""Test that writing to a closed transport raises ConnectionResetError.
"""Test that writing to a closed transport raises ClientConnectionResetError.
The StreamWriter checks to see if protocol.transport is None before
writing to the transport. If it is None, it raises ConnectionResetError.
Expand All @@ -320,7 +320,9 @@ async def test_write_to_closed_transport(
await msg.write(b"Before transport close")
protocol.transport = None

with pytest.raises(ConnectionResetError, match="Cannot write to closing transport"):
with pytest.raises(
ClientConnectionResetError, match="Cannot write to closing transport"
):
await msg.write(b"After transport closed")


Expand Down

0 comments on commit f95bcaf

Please sign in to comment.