diff --git a/tests/protocols/test_http.py b/tests/protocols/test_http.py index 79946ad2c..8db279504 100644 --- a/tests/protocols/test_http.py +++ b/tests/protocols/test_http.py @@ -966,3 +966,17 @@ async def test_return_close_header(protocol_cls, close_header: bytes): assert b"content-type: text/plain" in protocol.transport.buffer assert b"content-length: 12" in protocol.transport.buffer assert close_header in protocol.transport.buffer + + +@pytest.mark.anyio +@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) +async def test_iterator_headers(protocol_cls): + async def app(scope, receive, send): + headers = iter([(b"x-test-header", b"test value")]) + await send({"type": "http.response.start", "status": 200, "headers": headers}) + await send({"type": "http.response.body", "body": b""}) + + protocol = get_connected_protocol(app, protocol_cls) + protocol.data_received(SIMPLE_GET_REQUEST) + await protocol.loop.run_one() + assert b"x-test-header: test value" in protocol.transport.buffer diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index 9f6931ec5..c2764b028 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -468,10 +468,7 @@ async def send(self, message: "ASGISendEvent") -> None: self.waiting_for_100_continue = False status_code = message["status"] - message_headers = cast( - List[Tuple[bytes, bytes]], message.get("headers", []) - ) - headers = self.default_headers + message_headers + headers = self.default_headers + list(message.get("headers", [])) if CLOSE_HEADER in self.scope["headers"] and CLOSE_HEADER not in headers: headers = headers + [CLOSE_HEADER]