Skip to content

Commit

Permalink
format tests/_(a)sync to pass unasync check
Browse files Browse the repository at this point in the history
  • Loading branch information
Tester authored and Tester committed Feb 13, 2024
1 parent 5c176c7 commit 779653c
Show file tree
Hide file tree
Showing 12 changed files with 155 additions and 31 deletions.
29 changes: 21 additions & 8 deletions tests/_async/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ async def test_concurrent_requests_not_available_on_http11_connections():
await conn.request("GET", "https://example.com/")


@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
@pytest.mark.anyio
@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
async def test_write_error_with_response_sent():
"""
If a server half-closes the connection while the client is sending
Expand All @@ -103,7 +103,9 @@ def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
self.count = 0

async def write(
self, buffer: bytes, timeout: typing.Optional[float] = None
self,
buffer: bytes,
timeout: typing.Optional[float] = None,
) -> None:
self.count += len(buffer)

Expand Down Expand Up @@ -157,7 +159,9 @@ def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
self.count = 0

async def write(
self, buffer: bytes, timeout: typing.Optional[float] = None
self,
buffer: bytes,
timeout: typing.Optional[float] = None,
) -> None:
self.count += len(buffer)

Expand Down Expand Up @@ -212,7 +216,9 @@ async def test_http2_connection():
)

async with AsyncHTTPConnection(
origin=origin, network_backend=network_backend, http2=True
origin=origin,
network_backend=network_backend,
http2=True,
) as conn:
response = await conn.request("GET", "https://example.com/")

Expand All @@ -229,7 +235,8 @@ async def test_request_to_incorrect_origin():
origin = Origin(b"https", b"example.com", 443)
network_backend = AsyncMockBackend([])
async with AsyncHTTPConnection(
origin=origin, network_backend=network_backend
origin=origin,
network_backend=network_backend,
) as conn:
with pytest.raises(RuntimeError):
await conn.request("GET", "https://other.com/")
Expand Down Expand Up @@ -266,18 +273,24 @@ async def connect_tcp(

class _NeedsRetryAsyncNetworkStream(AsyncNetworkStream):
def __init__(
self, backend: "NeedsRetryBackend", stream: AsyncNetworkStream
self,
backend: "NeedsRetryBackend",
stream: AsyncNetworkStream,
) -> None:
self._backend = backend
self._stream = stream

async def read(
self, max_bytes: int, timeout: typing.Optional[float] = None
self,
max_bytes: int,
timeout: typing.Optional[float] = None,
) -> bytes:
return await self._stream.read(max_bytes, timeout)

async def write(
self, buffer: bytes, timeout: typing.Optional[float] = None
self,
buffer: bytes,
timeout: typing.Optional[float] = None,
) -> None:
await self._stream.write(buffer, timeout)

Expand Down
16 changes: 12 additions & 4 deletions tests/_async/test_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,9 @@ async def trace(name, kwargs):
# Sending an initial request, which once complete will not return to the pool.
with pytest.raises(Exception):
await pool.request(
"GET", "https://example.com/", extensions={"trace": trace}
"GET",
"https://example.com/",
extensions={"trace": trace},
)

info = [repr(c) for c in pool.connections]
Expand Down Expand Up @@ -452,7 +454,9 @@ async def trace(name, kwargs):
# Sending an initial request, which once complete will not return to the pool.
with pytest.raises(Exception):
await pool.request(
"GET", "https://example.com/", extensions={"trace": trace}
"GET",
"https://example.com/",
extensions={"trace": trace},
)

info = [repr(c) for c in pool.connections]
Expand Down Expand Up @@ -775,13 +779,17 @@ async def test_connection_pool_timeout_zero():
# Two consecutive requests with a pool timeout of zero.
# Both succeed without raising a timeout.
response = await pool.request(
"GET", "https://example.com/", extensions=extensions
"GET",
"https://example.com/",
extensions=extensions,
)
assert response.status == 200
assert response.content == b"Hello, world!"

response = await pool.request(
"GET", "https://example.com/", extensions=extensions
"GET",
"https://example.com/",
extensions=extensions,
)
assert response.status == 200
assert response.content == b"Hello, world!"
Expand Down
2 changes: 1 addition & 1 deletion tests/_async/test_http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def test_http11_connection():
assert repr(conn) == (
"<AsyncHTTP11Connection ['https://example.com:443', IDLE,"
" Request Count: 1]>"
)
) # fmt: skip


@pytest.mark.anyio
Expand Down
3 changes: 1 addition & 2 deletions tests/_async/test_http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ async def test_http2_connection():
conn.info() == "'https://example.com:443', HTTP/2, IDLE, Request Count: 1"
)
assert repr(conn) == (
"<AsyncHTTP2Connection ['https://example.com:443', IDLE,"
" Request Count: 1]>"
"<AsyncHTTP2Connection ['https://example.com:443', IDLE, Request Count: 1]>"
)


Expand Down
48 changes: 41 additions & 7 deletions tests/_sync/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)


# unasync anyio
def test_http_connection():
origin = Origin(b"https", b"example.com", 443)
network_backend = MockBackend(
Expand Down Expand Up @@ -60,6 +61,7 @@ def test_http_connection():
)


# unasync anyio
def test_concurrent_requests_not_available_on_http11_connections():
"""
Attempting to issue a request against an already active HTTP/1.1 connection
Expand All @@ -84,6 +86,7 @@ def test_concurrent_requests_not_available_on_http11_connections():
conn.request("GET", "https://example.com/")


# unasync anyio
@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
def test_write_error_with_response_sent():
"""
Expand All @@ -99,7 +102,11 @@ def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
super().__init__(buffer, http2)
self.count = 0

def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
def write(
self,
buffer: bytes,
timeout: typing.Optional[float] = None,
) -> None:
self.count += len(buffer)

if self.count > 1_000_000:
Expand Down Expand Up @@ -136,6 +143,7 @@ def connect_tcp(
assert response.content == b"Request body exceeded 1,000,000 bytes"


# unasync anyio
@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
def test_write_error_without_response_sent():
"""
Expand All @@ -150,7 +158,11 @@ def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
super().__init__(buffer, http2)
self.count = 0

def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
def write(
self,
buffer: bytes,
timeout: typing.Optional[float] = None,
) -> None:
self.count += len(buffer)

if self.count > 1_000_000:
Expand Down Expand Up @@ -179,6 +191,7 @@ def connect_tcp(
assert str(exc_info.value) == "Server disconnected without sending a response."


# unasync anyio
@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")
def test_http2_connection():
origin = Origin(b"https", b"example.com", 443)
Expand All @@ -203,7 +216,9 @@ def test_http2_connection():
)

with HTTPConnection(
origin=origin, network_backend=network_backend, http2=True
origin=origin,
network_backend=network_backend,
http2=True,
) as conn:
response = conn.request("GET", "https://example.com/")

Expand All @@ -212,13 +227,17 @@ def test_http2_connection():
assert response.extensions["http_version"] == b"HTTP/2"


# unasync anyio
def test_request_to_incorrect_origin():
"""
A connection can only send requests whichever origin it is connected to.
"""
origin = Origin(b"https", b"example.com", 443)
network_backend = MockBackend([])
with HTTPConnection(origin=origin, network_backend=network_backend) as conn:
with HTTPConnection(
origin=origin,
network_backend=network_backend,
) as conn:
with pytest.raises(RuntimeError):
conn.request("GET", "https://other.com/")

Expand Down Expand Up @@ -253,14 +272,26 @@ def connect_tcp(
return self._NeedsRetryAsyncNetworkStream(self, stream)

class _NeedsRetryAsyncNetworkStream(NetworkStream):
def __init__(self, backend: "NeedsRetryBackend", stream: NetworkStream) -> None:
def __init__(
self,
backend: "NeedsRetryBackend",
stream: NetworkStream,
) -> None:
self._backend = backend
self._stream = stream

def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
def read(
self,
max_bytes: int,
timeout: typing.Optional[float] = None,
) -> bytes:
return self._stream.read(max_bytes, timeout)

def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
def write(
self,
buffer: bytes,
timeout: typing.Optional[float] = None,
) -> None:
self._stream.write(buffer, timeout)

def close(self) -> None:
Expand All @@ -283,6 +314,7 @@ def get_extra_info(self, info: str) -> typing.Any:
return self._stream.get_extra_info(info)


# unasync anyio
def test_connection_retries():
origin = Origin(b"https", b"example.com", 443)
content = [
Expand All @@ -309,6 +341,7 @@ def test_connection_retries():
conn.request("GET", "https://example.com/")


# unasync anyio
def test_connection_retries_tls():
origin = Origin(b"https", b"example.com", 443)
content = [
Expand Down Expand Up @@ -339,6 +372,7 @@ def test_connection_retries_tls():
conn.request("GET", "https://example.com/")


# unasync anyio
def test_uds_connections():
# We're not actually testing Unix Domain Sockets here, because we're just
# using a mock backend, but at least we're covering the UDS codepath
Expand Down
Loading

0 comments on commit 779653c

Please sign in to comment.