From 7d274ed3894b8e16fcbdc69a6456f8ca12b33711 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 9 Dec 2023 08:23:25 +0300 Subject: [PATCH] Create `http_protocol_cls` fixture (#2174) --- tests/conftest.py | 16 + tests/middleware/test_logging.py | 13 +- tests/middleware/test_proxy_headers.py | 6 +- tests/protocols/test_http.py | 382 ++++++++++-------- tests/protocols/test_websocket.py | 232 +++++------ uvicorn/_types.py | 42 +- .../protocols/websockets/websockets_impl.py | 11 +- uvicorn/protocols/websockets/wsproto_impl.py | 13 +- 8 files changed, 367 insertions(+), 348 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 400351d55..a405c3175 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -259,3 +259,19 @@ def unused_tcp_port() -> int: ) def ws_protocol_cls(request: pytest.FixtureRequest): return import_from_string(request.param) + + +@pytest.fixture( + params=[ + pytest.param( + "uvicorn.protocols.http.httptools_impl:HttpToolsProtocol", + marks=pytest.mark.skipif( + not importlib.util.find_spec("httptools"), + reason="httptools not installed.", + ), + ), + "uvicorn.protocols.http.h11_impl:H11Protocol", + ] +) +def http_protocol_cls(request: pytest.FixtureRequest): + return import_from_string(request.param) diff --git a/tests/middleware/test_logging.py b/tests/middleware/test_logging.py index 84e7c8985..bc49f3463 100644 --- a/tests/middleware/test_logging.py +++ b/tests/middleware/test_logging.py @@ -11,14 +11,6 @@ from tests.utils import run_server from uvicorn import Config -from uvicorn.protocols.http.h11_impl import H11Protocol - -try: - from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol - - HTTP_PROTOCOLS = [H11Protocol, HttpToolsProtocol] -except ImportError: # pragma: nocover - HTTP_PROTOCOLS = [H11Protocol] if typing.TYPE_CHECKING: from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol @@ -69,14 +61,13 @@ async def test_trace_logging(caplog, logging_config, unused_tcp_port: int): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol", HTTP_PROTOCOLS) async def test_trace_logging_on_http_protocol( - http_protocol, caplog, logging_config, unused_tcp_port: int + http_protocol_cls, caplog, logging_config, unused_tcp_port: int ): config = Config( app=app, log_level="trace", - http=http_protocol, + http=http_protocol_cls, log_config=logging_config, port=unused_tcp_port, ) diff --git a/tests/middleware/test_proxy_headers.py b/tests/middleware/test_proxy_headers.py index b363f189f..53a4e70db 100644 --- a/tests/middleware/test_proxy_headers.py +++ b/tests/middleware/test_proxy_headers.py @@ -4,7 +4,6 @@ import pytest import websockets.client -from tests.protocols.test_http import HTTP_PROTOCOLS from tests.response import Response from tests.utils import run_server from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope @@ -12,6 +11,8 @@ from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware if TYPE_CHECKING: + from uvicorn.protocols.http.h11_impl import H11Protocol + from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol from uvicorn.protocols.websockets.wsproto_impl import WSProtocol @@ -114,10 +115,9 @@ async def test_proxy_headers_invalid_x_forwarded_for() -> None: @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_proxy_headers_websocket_x_forwarded_proto( ws_protocol_cls: "Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ) -> None: async def websocket_app(scope, receive, send): diff --git a/tests/protocols/test_http.py b/tests/protocols/test_http.py index b93cc276d..fde4cc70b 100644 --- a/tests/protocols/test_http.py +++ b/tests/protocols/test_http.py @@ -2,12 +2,13 @@ import socket import threading import time -from typing import TYPE_CHECKING, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Type import pytest from tests.response import Response from uvicorn import Server +from uvicorn._types import ASGIApplication, ASGIReceiveCallable, ASGISendCallable, Scope from uvicorn.config import WS_PROTOCOLS, Config from uvicorn.lifespan.off import LifespanOff from uvicorn.lifespan.on import LifespanOn @@ -16,15 +17,16 @@ try: from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol -except ImportError: # pragma: nocover - HttpToolsProtocol = None # type: ignore[misc,assignment] + + skip_if_no_httptools = pytest.mark.skipif(False, reason="httptools is installed") +except ModuleNotFoundError: + skip_if_no_httptools = pytest.mark.skipif(True, reason="httptools is not installed") if TYPE_CHECKING: from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol from uvicorn.protocols.websockets.wsproto_impl import WSProtocol -HTTP_PROTOCOLS = [p for p in [H11Protocol, HttpToolsProtocol] if p is not None] WEBSOCKET_PROTOCOLS = WS_PROTOCOLS.keys() SIMPLE_GET_REQUEST = b"\r\n".join([b"GET / HTTP/1.1", b"Host: example.org", b"", b""]) @@ -192,32 +194,31 @@ def add_done_callback(self, callback): def get_connected_protocol( - app, - protocol_cls, - lifespan: Optional[Union[LifespanOff, LifespanOn]] = None, - **kwargs, + app: ASGIApplication, + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", + lifespan: "LifespanOff | LifespanOn | None" = None, + **kwargs: Any, ): loop = MockLoop() transport = MockTransport() config = Config(app=app, **kwargs) lifespan = lifespan or LifespanOff(config) server_state = ServerState() - protocol = protocol_cls( + protocol = http_protocol_cls( config=config, server_state=server_state, app_state=lifespan.state, - _loop=loop, + _loop=loop, # type: ignore ) - protocol.connection_made(transport) + protocol.connection_made(transport) # type: ignore return protocol @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_get_request(protocol_cls): +async def test_get_request(http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]"): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 200 OK" in protocol.transport.buffer @@ -226,8 +227,9 @@ async def test_get_request(protocol_cls): @pytest.mark.anyio @pytest.mark.parametrize("path", ["/", "/?foo", "/?foo=bar", "/?foo=bar&baz=1"]) -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_request_logging(path, protocol_cls, caplog): +async def test_request_logging( + path, http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", caplog +): get_request_with_query_string = b"\r\n".join( ["GET {} HTTP/1.1".format(path).encode("ascii"), b"Host: example.org", b"", b""] ) @@ -236,18 +238,17 @@ async def test_request_logging(path, protocol_cls, caplog): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls, log_config=None) + protocol = get_connected_protocol(app, http_protocol_cls, log_config=None) protocol.data_received(get_request_with_query_string) await protocol.loop.run_one() assert '"GET {} HTTP/1.1" 200'.format(path) in caplog.records[0].message @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_head_request(protocol_cls): +async def test_head_request(http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]"): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_HEAD_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 200 OK" in protocol.transport.buffer @@ -255,19 +256,19 @@ async def test_head_request(protocol_cls): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_post_request(protocol_cls): - async def app(scope, receive, send): +async def test_post_request(http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]"): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): body = b"" more_body = True while more_body: message = await receive() + assert message["type"] == "http.request" body += message.get("body", b"") more_body = message.get("more_body", False) response = Response(b"Body: " + body, media_type="text/plain") await response(scope, receive, send) - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_POST_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 200 OK" in protocol.transport.buffer @@ -275,11 +276,10 @@ async def app(scope, receive, send): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_keepalive(protocol_cls): +async def test_keepalive(http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]"): app = Response(b"", status_code=204) - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() @@ -288,11 +288,12 @@ async def test_keepalive(protocol_cls): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_keepalive_timeout(protocol_cls): +async def test_keepalive_timeout( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): app = Response(b"", status_code=204) - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer @@ -304,11 +305,10 @@ async def test_keepalive_timeout(protocol_cls): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_close(protocol_cls): +async def test_close(http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]"): app = Response(b"", status_code=204, headers={"connection": "close"}) - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 204 No Content" in protocol.transport.buffer @@ -316,13 +316,14 @@ async def test_close(protocol_cls): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_chunked_encoding(protocol_cls): +async def test_chunked_encoding( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): app = Response( b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"} ) - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 200 OK" in protocol.transport.buffer @@ -331,13 +332,14 @@ async def test_chunked_encoding(protocol_cls): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_chunked_encoding_empty_body(protocol_cls): +async def test_chunked_encoding_empty_body( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): app = Response( b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"} ) - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 200 OK" in protocol.transport.buffer @@ -346,13 +348,14 @@ async def test_chunked_encoding_empty_body(protocol_cls): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_chunked_encoding_head_request(protocol_cls): +async def test_chunked_encoding_head_request( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): app = Response( b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"} ) - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_HEAD_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 200 OK" in protocol.transport.buffer @@ -360,11 +363,12 @@ async def test_chunked_encoding_head_request(protocol_cls): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_pipelined_requests(protocol_cls): +async def test_pipelined_requests( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) protocol.data_received(SIMPLE_GET_REQUEST) protocol.data_received(SIMPLE_GET_REQUEST) @@ -383,33 +387,36 @@ async def test_pipelined_requests(protocol_cls): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_undersized_request(protocol_cls): +async def test_undersized_request( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): app = Response(b"xxx", headers={"content-length": "10"}) - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert protocol.transport.is_closing() @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_oversized_request(protocol_cls): +async def test_oversized_request( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): app = Response(b"xxx" * 20, headers={"content-length": "10"}) - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert protocol.transport.is_closing() @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_large_post_request(protocol_cls): +async def test_large_post_request( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(LARGE_POST_REQUEST) assert protocol.transport.read_paused await protocol.loop.run_one() @@ -417,22 +424,22 @@ async def test_large_post_request(protocol_cls): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_invalid_http(protocol_cls): +async def test_invalid_http(http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]"): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(b"x" * 100000) assert protocol.transport.is_closing() @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_app_exception(protocol_cls): - async def app(scope, receive, send): +async def test_app_exception( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): raise Exception() - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 500 Internal Server Error" in protocol.transport.buffer @@ -440,14 +447,15 @@ async def app(scope, receive, send): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_exception_during_response(protocol_cls): - async def app(scope, receive, send): +async def test_exception_during_response( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) await send({"type": "http.response.body", "body": b"1", "more_body": True}) raise Exception() - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 500 Internal Server Error" not in protocol.transport.buffer @@ -455,12 +463,13 @@ async def app(scope, receive, send): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_no_response_returned(protocol_cls): - async def app(scope, receive, send): - pass +async def test_no_response_returned( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): + ... - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 500 Internal Server Error" in protocol.transport.buffer @@ -468,12 +477,13 @@ async def app(scope, receive, send): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_partial_response_returned(protocol_cls): - async def app(scope, receive, send): +async def test_partial_response_returned( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 500 Internal Server Error" not in protocol.transport.buffer @@ -481,13 +491,14 @@ async def app(scope, receive, send): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_duplicate_start_message(protocol_cls): - async def app(scope, receive, send): +async def test_duplicate_start_message( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) await send({"type": "http.response.start", "status": 200}) - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 500 Internal Server Error" not in protocol.transport.buffer @@ -495,12 +506,13 @@ async def app(scope, receive, send): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_missing_start_message(protocol_cls): - async def app(scope, receive, send): +async def test_missing_start_message( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.body", "body": b""}) - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 500 Internal Server Error" in protocol.transport.buffer @@ -508,14 +520,15 @@ async def app(scope, receive, send): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_message_after_body_complete(protocol_cls): - async def app(scope, receive, send): +async def test_message_after_body_complete( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) await send({"type": "http.response.body", "body": b""}) await send({"type": "http.response.body", "body": b""}) - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 200 OK" in protocol.transport.buffer @@ -523,14 +536,15 @@ async def app(scope, receive, send): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_value_returned(protocol_cls): - async def app(scope, receive, send): +async def test_value_returned( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) await send({"type": "http.response.body", "body": b""}) return 123 - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 200 OK" in protocol.transport.buffer @@ -538,11 +552,12 @@ async def app(scope, receive, send): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_early_disconnect(protocol_cls): +async def test_early_disconnect( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): got_disconnect_event = False - async def app(scope, receive, send): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): nonlocal got_disconnect_event while True: @@ -552,7 +567,7 @@ async def app(scope, receive, send): got_disconnect_event = True - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_POST_REQUEST) protocol.eof_received() protocol.connection_lost(None) @@ -561,11 +576,12 @@ async def app(scope, receive, send): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_early_response(protocol_cls): +async def test_early_response( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(START_POST_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 200 OK" in protocol.transport.buffer @@ -574,18 +590,19 @@ async def test_early_response(protocol_cls): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_read_after_response(protocol_cls): +async def test_read_after_response( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): message_after_response = None - async def app(scope, receive, send): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): nonlocal message_after_response response = Response("Hello, world", media_type="text/plain") await response(scope, receive, send) message_after_response = await receive() - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_POST_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 200 OK" in protocol.transport.buffer @@ -593,14 +610,16 @@ async def app(scope, receive, send): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_http10_request(protocol_cls): - async def app(scope, receive, send): +async def test_http10_request( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): + assert scope["type"] == "http" content = "Version: %s" % scope["http_version"] response = Response(content, media_type="text/plain") await response(scope, receive, send) - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(HTTP10_GET_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 200 OK" in protocol.transport.buffer @@ -608,14 +627,14 @@ async def app(scope, receive, send): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_root_path(protocol_cls): - async def app(scope, receive, send): +async def test_root_path(http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]"): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): + assert scope["type"] == "http" path = scope.get("root_path", "") + scope["path"] response = Response("Path: " + path, media_type="text/plain") await response(scope, receive, send) - protocol = get_connected_protocol(app, protocol_cls, root_path="/app") + protocol = get_connected_protocol(app, http_protocol_cls, root_path="/app") protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 200 OK" in protocol.transport.buffer @@ -623,9 +642,9 @@ async def app(scope, receive, send): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_raw_path(protocol_cls): - async def app(scope, receive, send): +async def test_raw_path(http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]"): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): + assert scope["type"] == "http" path = scope["path"] raw_path = scope.get("raw_path", None) assert "/one/two" == path @@ -634,29 +653,31 @@ async def app(scope, receive, send): response = Response("Done", media_type="text/plain") await response(scope, receive, send) - protocol = get_connected_protocol(app, protocol_cls, root_path="/app") + protocol = get_connected_protocol(app, http_protocol_cls, root_path="/app") protocol.data_received(GET_REQUEST_WITH_RAW_PATH) await protocol.loop.run_one() assert b"Done" in protocol.transport.buffer @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_max_concurrency(protocol_cls): +async def test_max_concurrency( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls, limit_concurrency=1) + protocol = get_connected_protocol(app, http_protocol_cls, limit_concurrency=1) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 503 Service Unavailable" in protocol.transport.buffer @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_shutdown_during_request(protocol_cls): +async def test_shutdown_during_request( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): app = Response(b"", status_code=204) - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) protocol.shutdown() await protocol.loop.run_one() @@ -665,30 +686,33 @@ async def test_shutdown_during_request(protocol_cls): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_shutdown_during_idle(protocol_cls): +async def test_shutdown_during_idle( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.shutdown() assert protocol.transport.buffer == b"" assert protocol.transport.is_closing() @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_100_continue_sent_when_body_consumed(protocol_cls): - async def app(scope, receive, send): +async def test_100_continue_sent_when_body_consumed( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): body = b"" more_body = True while more_body: message = await receive() + assert message["type"] == "http.request" body += message.get("body", b"") more_body = message.get("more_body", False) response = Response(b"Body: " + body, media_type="text/plain") await response(scope, receive, send) - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) EXPECT_100_REQUEST = b"\r\n".join( [ b"POST / HTTP/1.1", @@ -708,11 +732,12 @@ async def app(scope, receive, send): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_100_continue_not_sent_when_body_not_consumed(protocol_cls): +async def test_100_continue_not_sent_when_body_not_consumed( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): app = Response(b"", status_code=204) - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) EXPECT_100_REQUEST = b"\r\n".join( [ b"POST / HTTP/1.1", @@ -731,23 +756,25 @@ async def test_100_continue_not_sent_when_body_not_consumed(protocol_cls): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_supported_upgrade_request(protocol_cls): +async def test_supported_upgrade_request( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): pytest.importorskip("wsproto") app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls, ws="wsproto") + protocol = get_connected_protocol(app, http_protocol_cls, ws="wsproto") protocol.data_received(UPGRADE_REQUEST) assert b"HTTP/1.1 426 " in protocol.transport.buffer @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_unsupported_ws_upgrade_request(protocol_cls): +async def test_unsupported_ws_upgrade_request( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls, ws="none") + protocol = get_connected_protocol(app, http_protocol_cls, ws="none") protocol.data_received(UPGRADE_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 200 OK" in protocol.transport.buffer @@ -755,13 +782,13 @@ async def test_unsupported_ws_upgrade_request(protocol_cls): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) async def test_unsupported_ws_upgrade_request_warn_on_auto( - caplog: pytest.LogCaptureFixture, protocol_cls + caplog: pytest.LogCaptureFixture, + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", ): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls, ws="auto") + protocol = get_connected_protocol(app, http_protocol_cls, ws="auto") protocol.ws_protocol_class = None protocol.data_received(UPGRADE_REQUEST) await protocol.loop.run_one() @@ -779,41 +806,44 @@ async def test_unsupported_ws_upgrade_request_warn_on_auto( @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) async def test_http2_upgrade_request( - protocol_cls, ws_protocol_cls: "Type[WSProtocol | WebSocketProtocol]" + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", + ws_protocol_cls: "Type[WebSocketProtocol | WSProtocol]", ): app = Response("Hello, world", media_type="text/plain") - protocol = get_connected_protocol(app, protocol_cls, ws=ws_protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls, ws=ws_protocol_cls) protocol.data_received(UPGRADE_HTTP2_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 200 OK" in protocol.transport.buffer assert b"Hello, world" in protocol.transport.buffer -async def asgi3app(scope, receive, send): +async def asgi3app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): pass -def asgi2app(scope): - async def asgi(receive, send): +def asgi2app(scope: Scope): + async def asgi(receive: ASGIReceiveCallable, send: ASGISendCallable): pass return asgi -asgi_scope_data = [ - (asgi3app, {"version": "3.0", "spec_version": "2.3"}), - (asgi2app, {"version": "2.0", "spec_version": "2.3"}), -] - - @pytest.mark.anyio -@pytest.mark.parametrize("asgi2or3_app, expected_scopes", asgi_scope_data) -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_scopes(asgi2or3_app, expected_scopes, protocol_cls): - protocol = get_connected_protocol(asgi2or3_app, protocol_cls) +@pytest.mark.parametrize( + "asgi2or3_app, expected_scopes", + [ + (asgi3app, {"version": "3.0", "spec_version": "2.3"}), + (asgi2app, {"version": "2.0", "spec_version": "2.3"}), + ], +) +async def test_scopes( + asgi2or3_app: ASGIApplication, + expected_scopes: Dict[str, str], + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", +): + protocol = get_connected_protocol(asgi2or3_app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert expected_scopes == protocol.scope.get("asgi") @@ -828,24 +858,25 @@ async def test_scopes(asgi2or3_app, expected_scopes, protocol_cls): pytest.param(b"GET / HTTP1.1", id="invalid-http-version"), ], ) -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_invalid_http_request(request_line, protocol_cls, caplog): +async def test_invalid_http_request( + request_line, http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]", caplog +): app = Response("Hello, world", media_type="text/plain") request = INVALID_REQUEST_TEMPLATE % request_line caplog.set_level(logging.INFO, logger="uvicorn.error") logging.getLogger("uvicorn.error").propagate = True - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(request) assert b"HTTP/1.1 400 Bad Request" in protocol.transport.buffer assert b"Invalid HTTP request received." in protocol.transport.buffer -@pytest.mark.skipif(HttpToolsProtocol is None, reason="httptools is not installed") +@skip_if_no_httptools def test_fragmentation(unused_tcp_port: int): - def receive_all(sock): - chunks = [] + def receive_all(sock: socket.socket): + chunks: List[bytes] = [] while True: chunk = sock.recv(1024) if not chunk: @@ -855,7 +886,7 @@ def receive_all(sock): app = Response("Hello, world", media_type="text/plain") - def send_fragmented_req(path): + def send_fragmented_req(path: str): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(("127.0.0.1", unused_tcp_port)) d = ( @@ -904,7 +935,7 @@ async def test_huge_headers_h11protocol_failure(): @pytest.mark.anyio -@pytest.mark.skipif(HttpToolsProtocol is None, reason="httptools is not installed") +@skip_if_no_httptools async def test_huge_headers_httptools_will_pass(): app = Response("Hello, world", media_type="text/plain") @@ -934,7 +965,7 @@ async def test_huge_headers_h11protocol_failure_with_setting(): @pytest.mark.anyio -@pytest.mark.skipif(HttpToolsProtocol is None, reason="httptools is not installed") +@skip_if_no_httptools async def test_huge_headers_httptools(): app = Response("Hello, world", media_type="text/plain") @@ -967,11 +998,7 @@ async def test_huge_headers_h11_max_incomplete(): "protocol_cls,close_header", ( pytest.param( - HttpToolsProtocol, - b"connection: close", - marks=pytest.mark.skipif( - HttpToolsProtocol is None, reason="httptools is not installed" - ), + HttpToolsProtocol, b"connection: close", marks=skip_if_no_httptools ), (H11Protocol, b"Connection: close"), ), @@ -989,25 +1016,28 @@ async def test_return_close_header(protocol_cls, close_header: bytes): @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_iterator_headers(protocol_cls): - async def app(scope, receive, send): +async def test_iterator_headers( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]" +): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): headers = iter([(b"x-test-header", b"test value")]) await send({"type": "http.response.start", "status": 200, "headers": headers}) await send({"type": "http.response.body", "body": b""}) - protocol = get_connected_protocol(app, protocol_cls) + protocol = get_connected_protocol(app, http_protocol_cls) protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert b"x-test-header: test value" in protocol.transport.buffer @pytest.mark.anyio -@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) -async def test_lifespan_state(protocol_cls): +async def test_lifespan_state( + http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]" +): expected_states = [{"a": 123, "b": [1]}, {"a": 123, "b": [1, 2]}] - async def app(scope, receive, send): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): + assert "state" in scope expected_state = expected_states.pop(0) assert scope["state"] == expected_state # modifications to keys are not preserved @@ -1021,7 +1051,7 @@ async def app(scope, receive, send): # in the lifespan tests lifespan.state.update({"a": 123, "b": [1]}) - protocol = get_connected_protocol(app, protocol_cls, lifespan=lifespan) + protocol = get_connected_protocol(app, http_protocol_cls, lifespan=lifespan) for _ in range(2): protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 64bda3450..1ba20fc7c 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -10,8 +10,14 @@ from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory from websockets.typing import Subprotocol -from tests.protocols.test_http import HTTP_PROTOCOLS from tests.utils import run_server +from uvicorn._types import ( + ASGIReceiveCallable, + ASGISendCallable, + Scope, + WebSocketCloseEvent, + WebSocketDisconnectEvent, +) from uvicorn.config import Config from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol @@ -22,9 +28,15 @@ except ModuleNotFoundError: skip_if_no_wsproto = pytest.mark.skipif(True, reason="wsproto is not installed.") +if typing.TYPE_CHECKING: + from uvicorn.protocols.http.h11_impl import H11Protocol + from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol + class WebSocketResponse: - def __init__(self, scope, receive, send): + def __init__( + self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable + ): self.scope = scope self.receive = receive self.send = send @@ -44,13 +56,12 @@ async def asgi(self): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_invalid_upgrade( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): - def app(scope): + def app(scope: Scope): return None config = Config( @@ -86,10 +97,9 @@ def app(scope): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_accept_connection( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): class App(WebSocketResponse): @@ -113,10 +123,9 @@ async def open_connection(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_shutdown( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): class App(WebSocketResponse): @@ -137,10 +146,9 @@ async def websocket_connect(self, message): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_supports_permessage_deflate_extension( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): class App(WebSocketResponse): @@ -167,17 +175,16 @@ async def open_connection(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_can_disable_permessage_deflate_extension( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): class App(WebSocketResponse): async def websocket_connect(self, message): await self.send({"type": "websocket.accept"}) - async def open_connection(url): + async def open_connection(url: str): # enable per-message deflate on the client, so that we can check the server # won't support it when it's disabled. extension_factories = [ClientPerMessageDeflateFactory()] @@ -200,17 +207,16 @@ async def open_connection(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_close_connection( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): class App(WebSocketResponse): async def websocket_connect(self, message): await self.send({"type": "websocket.close"}) - async def open_connection(url): + async def open_connection(url: str): try: await websockets.client.connect(url) except websockets.exceptions.InvalidHandshake: @@ -230,21 +236,20 @@ async def open_connection(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_headers( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): class App(WebSocketResponse): async def websocket_connect(self, message): headers = self.scope.get("headers") - headers = dict(headers) - assert headers[b"host"].startswith(b"127.0.0.1") - assert headers[b"username"] == bytes("abraão", "utf-8") + headers = dict(headers) # type: ignore + assert headers[b"host"].startswith(b"127.0.0.1") # type: ignore + assert headers[b"username"] == bytes("abraão", "utf-8") # type: ignore await self.send({"type": "websocket.accept"}) - async def open_connection(url): + async def open_connection(url: str): async with websockets.client.connect( url, extra_headers=[("username", "abraão")] ) as websocket: @@ -263,10 +268,9 @@ async def open_connection(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_extra_headers( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): class App(WebSocketResponse): @@ -275,7 +279,7 @@ async def websocket_connect(self, message): {"type": "websocket.accept", "headers": [(b"extra", b"header")]} ) - async def open_connection(url): + async def open_connection(url: str): async with websockets.client.connect(url) as websocket: return websocket.response_headers @@ -292,10 +296,9 @@ async def open_connection(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_path_and_raw_path( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): class App(WebSocketResponse): @@ -306,7 +309,7 @@ async def websocket_connect(self, message): assert raw_path == b"/one%2Ftwo" await self.send({"type": "websocket.accept"}) - async def open_connection(url): + async def open_connection(url: str): async with websockets.client.connect(url) as websocket: return websocket.open @@ -323,10 +326,9 @@ async def open_connection(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_text_data_to_client( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): class App(WebSocketResponse): @@ -334,7 +336,7 @@ async def websocket_connect(self, message): await self.send({"type": "websocket.accept"}) await self.send({"type": "websocket.send", "text": "123"}) - async def get_data(url): + async def get_data(url: str): async with websockets.client.connect(url) as websocket: return await websocket.recv() @@ -351,10 +353,9 @@ async def get_data(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_binary_data_to_client( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): class App(WebSocketResponse): @@ -362,7 +363,7 @@ async def websocket_connect(self, message): await self.send({"type": "websocket.accept"}) await self.send({"type": "websocket.send", "bytes": b"123"}) - async def get_data(url): + async def get_data(url: str): async with websockets.client.connect(url) as websocket: return await websocket.recv() @@ -379,10 +380,9 @@ async def get_data(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_and_close_connection( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): class App(WebSocketResponse): @@ -391,7 +391,7 @@ async def websocket_connect(self, message): await self.send({"type": "websocket.send", "text": "123"}) await self.send({"type": "websocket.close"}) - async def get_data(url): + async def get_data(url: str): async with websockets.client.connect(url) as websocket: data = await websocket.recv() is_open = True @@ -415,10 +415,9 @@ async def get_data(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_text_data_to_server( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): class App(WebSocketResponse): @@ -429,7 +428,7 @@ async def websocket_receive(self, message): _text = message.get("text") await self.send({"type": "websocket.send", "text": _text}) - async def send_text(url): + async def send_text(url: str): async with websockets.client.connect(url) as websocket: await websocket.send("abc") return await websocket.recv() @@ -447,10 +446,9 @@ async def send_text(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_binary_data_to_server( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): class App(WebSocketResponse): @@ -461,7 +459,7 @@ async def websocket_receive(self, message): _bytes = message.get("bytes") await self.send({"type": "websocket.send", "bytes": _bytes}) - async def send_text(url): + async def send_text(url: str): async with websockets.client.connect(url) as websocket: await websocket.send(b"abc") return await websocket.recv() @@ -479,10 +477,9 @@ async def send_text(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_after_protocol_close( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): class App(WebSocketResponse): @@ -493,7 +490,7 @@ async def websocket_connect(self, message): with pytest.raises(Exception): await self.send({"type": "websocket.send", "text": "123"}) - async def get_data(url): + async def get_data(url: str): async with websockets.client.connect(url) as websocket: data = await websocket.recv() is_open = True @@ -517,16 +514,15 @@ async def get_data(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_missing_handshake( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): - async def app(app, receive, send): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): pass - async def connect(url): + async def connect(url: str): await websockets.client.connect(url) config = Config( @@ -543,16 +539,15 @@ async def connect(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_before_handshake( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): - async def app(scope, receive, send): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "websocket.send", "text": "123"}) - async def connect(url): + async def connect(url: str): await websockets.client.connect(url) config = Config( @@ -569,17 +564,16 @@ async def connect(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_duplicate_handshake( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): - async def app(scope, receive, send): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "websocket.accept"}) await send({"type": "websocket.accept"}) - async def connect(url): + async def connect(url: str): async with websockets.client.connect(url) as websocket: _ = await websocket.recv() @@ -597,10 +591,9 @@ async def connect(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_asgi_return_value( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): """ @@ -608,11 +601,11 @@ async def test_asgi_return_value( the connection is closed with an error condition. """ - async def app(scope, receive, send): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "websocket.accept"}) return 123 - async def connect(url): + async def connect(url: str): async with websockets.client.connect(url) as websocket: _ = await websocket.recv() @@ -630,7 +623,6 @@ async def connect(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) @pytest.mark.parametrize("code", [None, 1000, 1001]) @pytest.mark.parametrize( "reason", @@ -639,18 +631,18 @@ async def connect(url): ) async def test_app_close( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, - code, - reason, + code: typing.Optional[int], + reason: typing.Optional[str], ): - async def app(scope, receive, send): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): while True: message = await receive() if message["type"] == "websocket.connect": await send({"type": "websocket.accept"}) elif message["type"] == "websocket.receive": - reply = {"type": "websocket.close"} + reply: WebSocketCloseEvent = {"type": "websocket.close"} if code is not None: reply["code"] = code @@ -662,7 +654,7 @@ async def app(scope, receive, send): elif message["type"] == "websocket.disconnect": break - async def websocket_session(url): + async def websocket_session(url: str): async with websockets.client.connect(url) as websocket: await websocket.ping() await websocket.send("abc") @@ -683,13 +675,12 @@ async def websocket_session(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_client_close( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): - async def app(scope, receive, send): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): while True: message = await receive() if message["type"] == "websocket.connect": @@ -699,7 +690,7 @@ async def app(scope, receive, send): elif message["type"] == "websocket.disconnect": break - async def websocket_session(url): + async def websocket_session(url: str): async with websockets.client.connect(url) as websocket: await websocket.ping() await websocket.send("abc") @@ -716,15 +707,14 @@ async def websocket_session(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_client_connection_lost( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): got_disconnect_event = False - async def app(scope, receive, send): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): nonlocal got_disconnect_event while True: message = await receive() @@ -755,25 +745,24 @@ async def app(scope, receive, send): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_connection_lost_before_handshake_complete( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): send_accept_task = asyncio.Event() - disconnect_message = {} + disconnect_message: WebSocketDisconnectEvent = {} # type: ignore - async def app(scope, receive, send): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): nonlocal disconnect_message message = await receive() if message["type"] == "websocket.connect": await send_accept_task.wait() - disconnect_message = await receive() + disconnect_message = await receive() # type: ignore response: typing.Optional[httpx.Response] = None - async def websocket_session(uri): + async def websocket_session(uri: str): nonlocal response async with httpx.AsyncClient() as client: response = await client.get( @@ -809,16 +798,15 @@ async def websocket_session(uri): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_close_on_server_shutdown( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): - disconnect_message = {} + disconnect_message: WebSocketDisconnectEvent = {} # type: ignore server_shutdown_event = asyncio.Event() - async def app(scope, receive, send): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): nonlocal disconnect_message while True: message = await receive() @@ -830,7 +818,7 @@ async def app(scope, receive, send): websocket: typing.Optional[websockets.client.WebSocketClientProtocol] = None - async def websocket_session(uri): + async def websocket_session(uri: str): nonlocal websocket async with websockets.client.connect(uri) as ws_connection: websocket = ws_connection @@ -859,12 +847,11 @@ async def websocket_session(uri): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) @pytest.mark.parametrize("subprotocol", ["proto1", "proto2"]) async def test_subprotocols( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, - subprotocol, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + subprotocol: str, unused_tcp_port: int, ): class App(WebSocketResponse): @@ -896,7 +883,6 @@ async def get_subprotocol(url: str): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) @pytest.mark.parametrize( "client_size_sent, server_size_max, expected_result", [ @@ -913,7 +899,7 @@ async def get_subprotocol(url: str): ], ) async def test_send_binary_data_to_server_bigger_than_default_on_websockets( - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", client_size_sent: int, server_size_max: int, expected_result: int, @@ -927,12 +913,10 @@ async def websocket_receive(self, message): _bytes = message.get("bytes") await self.send({"type": "websocket.send", "bytes": _bytes}) - async def send_text(url): - async with websockets.client.connect( - url, max_size=client_size_sent - ) as websocket: - await websocket.send(b"\x01" * client_size_sent) - return await websocket.recv() + async def send_text(url: str): + async with websockets.client.connect(url, max_size=client_size_sent) as ws: + await ws.send(b"\x01" * client_size_sent) + return await ws.recv() config = Config( app=App, @@ -953,13 +937,12 @@ async def send_text(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_server_reject_connection( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): - async def app(scope, receive, send): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): assert scope["type"] == "websocket" # Pull up first recv message. @@ -975,7 +958,7 @@ async def app(scope, receive, send): message = await receive() assert message["type"] == "websocket.disconnect" - async def websocket_session(url): + async def websocket_session(url: str): try: async with websockets.client.connect(url): pass # pragma: no cover @@ -994,10 +977,9 @@ async def websocket_session(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_server_can_read_messages_in_buffer_after_close( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): frames = [] @@ -1018,7 +1000,7 @@ async def websocket_disconnect(self, message): async def websocket_receive(self, message): frames.append(message.get("bytes")) - async def send_text(url): + async def send_text(url: str): async with websockets.client.connect(url) as websocket: await websocket.send(b"abc") await websocket.send(b"abc") @@ -1039,17 +1021,16 @@ async def send_text(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_default_server_headers( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): class App(WebSocketResponse): async def websocket_connect(self, message): await self.send({"type": "websocket.accept"}) - async def open_connection(url): + async def open_connection(url: str): async with websockets.client.connect(url) as websocket: return websocket.response_headers @@ -1066,17 +1047,16 @@ async def open_connection(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_no_server_headers( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): class App(WebSocketResponse): async def websocket_connect(self, message): await self.send({"type": "websocket.accept"}) - async def open_connection(url): + async def open_connection(url: str): async with websockets.client.connect(url) as websocket: return websocket.response_headers @@ -1094,17 +1074,16 @@ async def open_connection(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) @skip_if_no_wsproto async def test_no_date_header_on_wsproto( - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): class App(WebSocketResponse): async def websocket_connect(self, message): await self.send({"type": "websocket.accept"}) - async def open_connection(url): + async def open_connection(url: str): async with websockets.client.connect(url) as websocket: return websocket.response_headers @@ -1122,10 +1101,9 @@ async def open_connection(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_multiple_server_header( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): class App(WebSocketResponse): @@ -1140,7 +1118,7 @@ async def websocket_connect(self, message): } ) - async def open_connection(url): + async def open_connection(url: str): async with websockets.client.connect(url) as websocket: return websocket.response_headers @@ -1157,10 +1135,9 @@ async def open_connection(url): @pytest.mark.anyio -@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_lifespan_state( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls, + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", unused_tcp_port: int, ): expected_states = [ @@ -1170,9 +1147,11 @@ async def test_lifespan_state( actual_states = [] - async def lifespan_app(scope, receive, send): + async def lifespan_app( + scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable + ): message = await receive() - assert message["type"] == "lifespan.startup" + assert message["type"] == "lifespan.startup" and "state" in scope scope["state"]["a"] = 123 scope["state"]["b"] = [1] await send({"type": "lifespan.startup.complete"}) @@ -1187,15 +1166,16 @@ async def websocket_connect(self, message): self.scope["state"]["b"].append(2) await self.send({"type": "websocket.accept"}) - async def open_connection(url): + async def open_connection(url: str): async with websockets.client.connect(url) as websocket: return websocket.open - async def app_wrapper(scope, receive, send): + async def app_wrapper( + scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable + ): if scope["type"] == "lifespan": return await lifespan_app(scope, receive, send) - else: - return await App(scope, receive, send) + return await App(scope, receive, send) config = Config( app=app_wrapper, diff --git a/uvicorn/_types.py b/uvicorn/_types.py index ecc3bd5c9..be96d940b 100644 --- a/uvicorn/_types.py +++ b/uvicorn/_types.py @@ -124,14 +124,14 @@ class HTTPResponseDebugEvent(TypedDict): class HTTPResponseStartEvent(TypedDict): type: Literal["http.response.start"] status: int - headers: Iterable[Tuple[bytes, bytes]] + headers: NotRequired[Iterable[Tuple[bytes, bytes]]] trailers: NotRequired[bool] class HTTPResponseBodyEvent(TypedDict): type: Literal["http.response.body"] body: bytes - more_body: bool + more_body: NotRequired[bool] class HTTPResponseTrailersEvent(TypedDict): @@ -156,20 +156,38 @@ class WebSocketConnectEvent(TypedDict): class WebSocketAcceptEvent(TypedDict): type: Literal["websocket.accept"] - subprotocol: Optional[str] - headers: Iterable[Tuple[bytes, bytes]] + subprotocol: NotRequired[Optional[str]] + headers: NotRequired[Iterable[Tuple[bytes, bytes]]] -class WebSocketReceiveEvent(TypedDict): +class _WebSocketReceiveEventBytes(TypedDict): type: Literal["websocket.receive"] - bytes: Optional[bytes] - text: Optional[str] + bytes: bytes + text: NotRequired[None] + + +class _WebSocketReceiveEventText(TypedDict): + type: Literal["websocket.receive"] + bytes: NotRequired[None] + text: str + +WebSocketReceiveEvent = Union[_WebSocketReceiveEventBytes, _WebSocketReceiveEventText] -class WebSocketSendEvent(TypedDict): + +class _WebSocketSendEventBytes(TypedDict): + type: Literal["websocket.send"] + bytes: bytes + text: NotRequired[None] + + +class _WebSocketSendEventText(TypedDict): type: Literal["websocket.send"] - bytes: Optional[bytes] - text: Optional[str] + bytes: NotRequired[None] + text: str + + +WebSocketSendEvent = Union[_WebSocketSendEventBytes, _WebSocketSendEventText] class WebSocketResponseStartEvent(TypedDict): @@ -191,8 +209,8 @@ class WebSocketDisconnectEvent(TypedDict): class WebSocketCloseEvent(TypedDict): type: Literal["websocket.close"] - code: int - reason: Optional[str] + code: NotRequired[int] + reason: NotRequired[Optional[str]] class LifespanStartupEvent(TypedDict): diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 089eeb536..94f40f233 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -364,13 +364,6 @@ async def asgi_receive( return {"type": "websocket.disconnect", "code": 1012} return {"type": "websocket.disconnect", "code": exc.code} - msg: WebSocketReceiveEvent = { # type: ignore[typeddict-item] - "type": "websocket.receive" - } - if isinstance(data, str): - msg["text"] = data - else: - msg["bytes"] = data - - return msg + return {"type": "websocket.receive", "text": data} + return {"type": "websocket.receive", "bytes": data} diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index aa4bec8f2..d682eb9f9 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -15,7 +15,6 @@ WebSocketAcceptEvent, WebSocketCloseEvent, WebSocketEvent, - WebSocketReceiveEvent, WebSocketScope, WebSocketSendEvent, ) @@ -181,11 +180,7 @@ def handle_connect(self, event: events.Request) -> None: def handle_text(self, event: events.TextMessage) -> None: self.text += event.data if event.message_finished: - msg: "WebSocketReceiveEvent" = { # type: ignore[typeddict-item] - "type": "websocket.receive", - "text": self.text, - } - self.queue.put_nowait(msg) + self.queue.put_nowait({"type": "websocket.receive", "text": self.text}) self.text = "" if not self.read_paused: self.read_paused = True @@ -195,11 +190,7 @@ def handle_bytes(self, event: events.BytesMessage) -> None: self.bytes += event.data # todo: we may want to guard the size of self.bytes and self.text if event.message_finished: - msg: "WebSocketReceiveEvent" = { # type: ignore[typeddict-item] - "type": "websocket.receive", - "bytes": self.bytes, - } - self.queue.put_nowait(msg) + self.queue.put_nowait({"type": "websocket.receive", "bytes": self.bytes}) self.bytes = b"" if not self.read_paused: self.read_paused = True