diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..9824d7ece0 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,7 @@ +Release type: patch + +This release fixes an issue with the http multipart subscription where the +status code would be returned as `None`, instead of 200. + +We also took the opportunity to update the internals to better support +additional protocols in future. diff --git a/strawberry/aiohttp/views.py b/strawberry/aiohttp/views.py index 56a755b2c9..0a8143657f 100644 --- a/strawberry/aiohttp/views.py +++ b/strawberry/aiohttp/views.py @@ -189,18 +189,18 @@ def create_response( return sub_response - async def create_multipart_response( + async def create_streaming_response( self, request: web.Request, stream: Callable[[], AsyncGenerator[str, None]], sub_response: web.Response, + headers: Dict[str, str], ) -> web.StreamResponse: response = web.StreamResponse( status=sub_response.status, headers={ **sub_response.headers, - "Transfer-Encoding": "chunked", - "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + **headers, }, ) diff --git a/strawberry/asgi/__init__.py b/strawberry/asgi/__init__.py index 26ead659e6..d5aae404f6 100644 --- a/strawberry/asgi/__init__.py +++ b/strawberry/asgi/__init__.py @@ -7,6 +7,7 @@ Any, AsyncIterator, Callable, + Dict, Mapping, Optional, Sequence, @@ -221,18 +222,18 @@ def create_response( return response - async def create_multipart_response( + async def create_streaming_response( self, request: Request | WebSocket, stream: Callable[[], AsyncIterator[str]], sub_response: Response, + headers: Dict[str, str], ) -> Response: return StreamingResponse( stream(), - status_code=sub_response.status_code, + status_code=sub_response.status_code or status.HTTP_200_OK, headers={ **sub_response.headers, - "Transfer-Encoding": "chunked", - "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + **headers, }, ) diff --git a/strawberry/channels/handlers/http_handler.py b/strawberry/channels/handlers/http_handler.py index e7a96d1d7b..9169265cbb 100644 --- a/strawberry/channels/handlers/http_handler.py +++ b/strawberry/channels/handlers/http_handler.py @@ -273,15 +273,23 @@ async def get_context( async def get_sub_response(self, request: ChannelsRequest) -> TemporalResponse: return TemporalResponse() - async def create_multipart_response( + async def create_streaming_response( self, request: ChannelsRequest, stream: Callable[[], AsyncGenerator[str, None]], sub_response: TemporalResponse, + headers: Dict[str, str], ) -> MultipartChannelsResponse: status = sub_response.status_code or 200 - headers = {k.encode(): v.encode() for k, v in sub_response.headers.items()} - return MultipartChannelsResponse(stream=stream, status=status, headers=headers) + + response_headers = { + k.encode(): v.encode() for k, v in sub_response.headers.items() + } + response_headers.update({k.encode(): v.encode() for k, v in headers.items()}) + + return MultipartChannelsResponse( + stream=stream, status=status, headers=response_headers + ) async def render_graphql_ide(self, request: ChannelsRequest) -> ChannelsResponse: return ChannelsResponse( diff --git a/strawberry/django/views.py b/strawberry/django/views.py index ee831eabcb..0ce5bf920a 100644 --- a/strawberry/django/views.py +++ b/strawberry/django/views.py @@ -7,6 +7,7 @@ Any, AsyncIterator, Callable, + Dict, Mapping, Optional, Union, @@ -185,19 +186,19 @@ def create_response( return response - async def create_multipart_response( + async def create_streaming_response( self, request: HttpRequest, stream: Callable[[], AsyncIterator[Any]], sub_response: TemporalHttpResponse, + headers: Dict[str, str], ) -> HttpResponseBase: return StreamingHttpResponse( streaming_content=stream(), status=sub_response.status_code, headers={ **sub_response.headers, - "Transfer-Encoding": "chunked", - "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + **headers, }, ) diff --git a/strawberry/fastapi/router.py b/strawberry/fastapi/router.py index e25dfcd820..833b656383 100644 --- a/strawberry/fastapi/router.py +++ b/strawberry/fastapi/router.py @@ -332,19 +332,19 @@ def create_response( return response - async def create_multipart_response( + async def create_streaming_response( self, request: Request, stream: Callable[[], AsyncIterator[str]], sub_response: Response, + headers: Dict[str, str], ) -> Response: return StreamingResponse( stream(), - status_code=sub_response.status_code, + status_code=sub_response.status_code or status.HTTP_200_OK, headers={ **sub_response.headers, - "Transfer-Encoding": "chunked", - "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + **headers, }, ) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index e210861a55..9e3ace71d1 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -92,11 +92,12 @@ def create_response( @abc.abstractmethod async def render_graphql_ide(self, request: Request) -> Response: ... - async def create_multipart_response( + async def create_streaming_response( self, request: Request, stream: Callable[[], AsyncGenerator[str, None]], sub_response: SubResponse, + headers: Dict[str, str], ) -> Response: raise ValueError("Multipart responses are not supported") @@ -199,7 +200,15 @@ async def run( if isinstance(result, SubscriptionExecutionResult): stream = self._get_stream(request, result) - return await self.create_multipart_response(request, stream, sub_response) + return await self.create_streaming_response( + request, + stream, + sub_response, + headers={ + "Transfer-Encoding": "chunked", + "Content-Type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + }, + ) response_data = await self.process_result(request=request, result=result) diff --git a/strawberry/litestar/controller.py b/strawberry/litestar/controller.py index dc4e37a0af..7ff68c69ad 100644 --- a/strawberry/litestar/controller.py +++ b/strawberry/litestar/controller.py @@ -280,19 +280,19 @@ def create_response( return response - async def create_multipart_response( + async def create_streaming_response( self, request: Request, stream: Callable[[], AsyncIterator[str]], sub_response: Response, + headers: Dict[str, str], ) -> Response: return Stream( stream(), status_code=sub_response.status_code, headers={ **sub_response.headers, - "Transfer-Encoding": "chunked", - "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + **headers, }, ) diff --git a/strawberry/quart/views.py b/strawberry/quart/views.py index e6938a6034..f9db21a01d 100644 --- a/strawberry/quart/views.py +++ b/strawberry/quart/views.py @@ -1,6 +1,6 @@ import warnings from collections.abc import Mapping -from typing import TYPE_CHECKING, AsyncGenerator, Callable, Optional, cast +from typing import TYPE_CHECKING, AsyncGenerator, Callable, Dict, Optional, cast from quart import Request, Response, request from quart.views import View @@ -103,19 +103,19 @@ async def dispatch_request(self) -> "ResponseReturnValue": # type: ignore status=e.status_code, ) - async def create_multipart_response( + async def create_streaming_response( self, request: Request, stream: Callable[[], AsyncGenerator[str, None]], sub_response: Response, + headers: Dict[str, str], ) -> Response: return ( stream(), sub_response.status_code, { # type: ignore **sub_response.headers, - "Transfer-Encoding": "chunked", - "Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", + **headers, }, ) diff --git a/strawberry/sanic/views.py b/strawberry/sanic/views.py index 83b7d3ca5c..edb30075f6 100644 --- a/strawberry/sanic/views.py +++ b/strawberry/sanic/views.py @@ -178,18 +178,18 @@ async def get(self, request: Request) -> HTTPResponse: # type: ignore[override] except HTTPException as e: return HTTPResponse(e.reason, status=e.status_code) - async def create_multipart_response( + async def create_streaming_response( self, request: Request, stream: Callable[[], AsyncGenerator[str, None]], sub_response: TemporalResponse, + headers: Dict[str, str], ) -> HTTPResponse: response = await self.request.respond( - content_type="multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json", status=sub_response.status_code, headers={ **sub_response.headers, - "Transfer-Encoding": "chunked", + **headers, }, ) diff --git a/tests/http/test_multipart_subscription.py b/tests/http/test_multipart_subscription.py index d9cb289d8c..52e08c6f2d 100644 --- a/tests/http/test_multipart_subscription.py +++ b/tests/http/test_multipart_subscription.py @@ -71,3 +71,5 @@ async def test_multipart_subscription( data = [d async for d in response.streaming_json()] assert data == [{"payload": {"data": {"echo": "Hello world"}}}] + + assert response.status_code == 200