Skip to content

Commit

Permalink
Fix status code for multipart subscriptions (#3610)
Browse files Browse the repository at this point in the history
* Fix status code

* Don't couple multipart with view

* Add release file

* Fix headers in channels integration
  • Loading branch information
patrick91 authored Sep 2, 2024
1 parent 98b8563 commit a51d09c
Show file tree
Hide file tree
Showing 11 changed files with 57 additions and 29 deletions.
7 changes: 7 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Release type: patch

This release fixes an issue with the http multipart subscription where the
status code would be returned as `None`, instead of 200.

We also took the opportunity to update the internals to better support
additional protocols in future.
6 changes: 3 additions & 3 deletions strawberry/aiohttp/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,18 +189,18 @@ def create_response(

return sub_response

async def create_multipart_response(
async def create_streaming_response(
self,
request: web.Request,
stream: Callable[[], AsyncGenerator[str, None]],
sub_response: web.Response,
headers: Dict[str, str],
) -> web.StreamResponse:
response = web.StreamResponse(
status=sub_response.status,
headers={
**sub_response.headers,
"Transfer-Encoding": "chunked",
"Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json",
**headers,
},
)

Expand Down
9 changes: 5 additions & 4 deletions strawberry/asgi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Any,
AsyncIterator,
Callable,
Dict,
Mapping,
Optional,
Sequence,
Expand Down Expand Up @@ -221,18 +222,18 @@ def create_response(

return response

async def create_multipart_response(
async def create_streaming_response(
self,
request: Request | WebSocket,
stream: Callable[[], AsyncIterator[str]],
sub_response: Response,
headers: Dict[str, str],
) -> Response:
return StreamingResponse(
stream(),
status_code=sub_response.status_code,
status_code=sub_response.status_code or status.HTTP_200_OK,
headers={
**sub_response.headers,
"Transfer-Encoding": "chunked",
"Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json",
**headers,
},
)
14 changes: 11 additions & 3 deletions strawberry/channels/handlers/http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,15 +273,23 @@ async def get_context(
async def get_sub_response(self, request: ChannelsRequest) -> TemporalResponse:
return TemporalResponse()

async def create_multipart_response(
async def create_streaming_response(
self,
request: ChannelsRequest,
stream: Callable[[], AsyncGenerator[str, None]],
sub_response: TemporalResponse,
headers: Dict[str, str],
) -> MultipartChannelsResponse:
status = sub_response.status_code or 200
headers = {k.encode(): v.encode() for k, v in sub_response.headers.items()}
return MultipartChannelsResponse(stream=stream, status=status, headers=headers)

response_headers = {
k.encode(): v.encode() for k, v in sub_response.headers.items()
}
response_headers.update({k.encode(): v.encode() for k, v in headers.items()})

return MultipartChannelsResponse(
stream=stream, status=status, headers=response_headers
)

async def render_graphql_ide(self, request: ChannelsRequest) -> ChannelsResponse:
return ChannelsResponse(
Expand Down
7 changes: 4 additions & 3 deletions strawberry/django/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Any,
AsyncIterator,
Callable,
Dict,
Mapping,
Optional,
Union,
Expand Down Expand Up @@ -185,19 +186,19 @@ def create_response(

return response

async def create_multipart_response(
async def create_streaming_response(
self,
request: HttpRequest,
stream: Callable[[], AsyncIterator[Any]],
sub_response: TemporalHttpResponse,
headers: Dict[str, str],
) -> HttpResponseBase:
return StreamingHttpResponse(
streaming_content=stream(),
status=sub_response.status_code,
headers={
**sub_response.headers,
"Transfer-Encoding": "chunked",
"Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json",
**headers,
},
)

Expand Down
8 changes: 4 additions & 4 deletions strawberry/fastapi/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,19 +332,19 @@ def create_response(

return response

async def create_multipart_response(
async def create_streaming_response(
self,
request: Request,
stream: Callable[[], AsyncIterator[str]],
sub_response: Response,
headers: Dict[str, str],
) -> Response:
return StreamingResponse(
stream(),
status_code=sub_response.status_code,
status_code=sub_response.status_code or status.HTTP_200_OK,
headers={
**sub_response.headers,
"Transfer-Encoding": "chunked",
"Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json",
**headers,
},
)

Expand Down
13 changes: 11 additions & 2 deletions strawberry/http/async_base_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,12 @@ def create_response(
@abc.abstractmethod
async def render_graphql_ide(self, request: Request) -> Response: ...

async def create_multipart_response(
async def create_streaming_response(
self,
request: Request,
stream: Callable[[], AsyncGenerator[str, None]],
sub_response: SubResponse,
headers: Dict[str, str],
) -> Response:
raise ValueError("Multipart responses are not supported")

Expand Down Expand Up @@ -199,7 +200,15 @@ async def run(
if isinstance(result, SubscriptionExecutionResult):
stream = self._get_stream(request, result)

return await self.create_multipart_response(request, stream, sub_response)
return await self.create_streaming_response(
request,
stream,
sub_response,
headers={
"Transfer-Encoding": "chunked",
"Content-Type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json",
},
)

response_data = await self.process_result(request=request, result=result)

Expand Down
6 changes: 3 additions & 3 deletions strawberry/litestar/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,19 +280,19 @@ def create_response(

return response

async def create_multipart_response(
async def create_streaming_response(
self,
request: Request,
stream: Callable[[], AsyncIterator[str]],
sub_response: Response,
headers: Dict[str, str],
) -> Response:
return Stream(
stream(),
status_code=sub_response.status_code,
headers={
**sub_response.headers,
"Transfer-Encoding": "chunked",
"Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json",
**headers,
},
)

Expand Down
8 changes: 4 additions & 4 deletions strawberry/quart/views.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings
from collections.abc import Mapping
from typing import TYPE_CHECKING, AsyncGenerator, Callable, Optional, cast
from typing import TYPE_CHECKING, AsyncGenerator, Callable, Dict, Optional, cast

from quart import Request, Response, request
from quart.views import View
Expand Down Expand Up @@ -103,19 +103,19 @@ async def dispatch_request(self) -> "ResponseReturnValue": # type: ignore
status=e.status_code,
)

async def create_multipart_response(
async def create_streaming_response(
self,
request: Request,
stream: Callable[[], AsyncGenerator[str, None]],
sub_response: Response,
headers: Dict[str, str],
) -> Response:
return (
stream(),
sub_response.status_code,
{ # type: ignore
**sub_response.headers,
"Transfer-Encoding": "chunked",
"Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json",
**headers,
},
)

Expand Down
6 changes: 3 additions & 3 deletions strawberry/sanic/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,18 +178,18 @@ async def get(self, request: Request) -> HTTPResponse: # type: ignore[override]
except HTTPException as e:
return HTTPResponse(e.reason, status=e.status_code)

async def create_multipart_response(
async def create_streaming_response(
self,
request: Request,
stream: Callable[[], AsyncGenerator[str, None]],
sub_response: TemporalResponse,
headers: Dict[str, str],
) -> HTTPResponse:
response = await self.request.respond(
content_type="multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json",
status=sub_response.status_code,
headers={
**sub_response.headers,
"Transfer-Encoding": "chunked",
**headers,
},
)

Expand Down
2 changes: 2 additions & 0 deletions tests/http/test_multipart_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,5 @@ async def test_multipart_subscription(
data = [d async for d in response.streaming_json()]

assert data == [{"payload": {"data": {"echo": "Hello world"}}}]

assert response.status_code == 200

0 comments on commit a51d09c

Please sign in to comment.