Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix status code for multipart subscriptions #3610

Merged
merged 4 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Release type: patch

This release fixes an issue with the http multipart subscription where the
status code would be returned as `None`, instead of 200.

We also took the opportunity to update the internals to better support
additional protocols in future.
6 changes: 3 additions & 3 deletions strawberry/aiohttp/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,18 +189,18 @@ def create_response(

return sub_response

async def create_multipart_response(
async def create_streaming_response(
self,
request: web.Request,
stream: Callable[[], AsyncGenerator[str, None]],
sub_response: web.Response,
headers: Dict[str, str],
) -> web.StreamResponse:
response = web.StreamResponse(
status=sub_response.status,
headers={
**sub_response.headers,
"Transfer-Encoding": "chunked",
"Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json",
**headers,
},
)

Expand Down
9 changes: 5 additions & 4 deletions strawberry/asgi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Any,
AsyncIterator,
Callable,
Dict,
Mapping,
Optional,
Sequence,
Expand Down Expand Up @@ -221,18 +222,18 @@ def create_response(

return response

async def create_multipart_response(
async def create_streaming_response(
self,
request: Request | WebSocket,
stream: Callable[[], AsyncIterator[str]],
sub_response: Response,
headers: Dict[str, str],
) -> Response:
return StreamingResponse(
stream(),
status_code=sub_response.status_code,
status_code=sub_response.status_code or status.HTTP_200_OK,
headers={
**sub_response.headers,
"Transfer-Encoding": "chunked",
"Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json",
**headers,
},
)
14 changes: 11 additions & 3 deletions strawberry/channels/handlers/http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,15 +273,23 @@ async def get_context(
async def get_sub_response(self, request: ChannelsRequest) -> TemporalResponse:
return TemporalResponse()

async def create_multipart_response(
async def create_streaming_response(
self,
request: ChannelsRequest,
stream: Callable[[], AsyncGenerator[str, None]],
sub_response: TemporalResponse,
headers: Dict[str, str],
) -> MultipartChannelsResponse:
status = sub_response.status_code or 200
headers = {k.encode(): v.encode() for k, v in sub_response.headers.items()}
return MultipartChannelsResponse(stream=stream, status=status, headers=headers)

response_headers = {
k.encode(): v.encode() for k, v in sub_response.headers.items()
}
response_headers.update({k.encode(): v.encode() for k, v in headers.items()})

return MultipartChannelsResponse(
stream=stream, status=status, headers=response_headers
)

async def render_graphql_ide(self, request: ChannelsRequest) -> ChannelsResponse:
return ChannelsResponse(
Expand Down
7 changes: 4 additions & 3 deletions strawberry/django/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Any,
AsyncIterator,
Callable,
Dict,
Mapping,
Optional,
Union,
Expand Down Expand Up @@ -185,19 +186,19 @@ def create_response(

return response

async def create_multipart_response(
async def create_streaming_response(
self,
request: HttpRequest,
stream: Callable[[], AsyncIterator[Any]],
sub_response: TemporalHttpResponse,
headers: Dict[str, str],
) -> HttpResponseBase:
return StreamingHttpResponse(
streaming_content=stream(),
status=sub_response.status_code,
headers={
**sub_response.headers,
"Transfer-Encoding": "chunked",
"Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json",
**headers,
},
)

Expand Down
8 changes: 4 additions & 4 deletions strawberry/fastapi/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,19 +332,19 @@ def create_response(

return response

async def create_multipart_response(
async def create_streaming_response(
self,
request: Request,
stream: Callable[[], AsyncIterator[str]],
sub_response: Response,
headers: Dict[str, str],
) -> Response:
return StreamingResponse(
stream(),
status_code=sub_response.status_code,
status_code=sub_response.status_code or status.HTTP_200_OK,
headers={
**sub_response.headers,
"Transfer-Encoding": "chunked",
"Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json",
**headers,
},
)

Expand Down
13 changes: 11 additions & 2 deletions strawberry/http/async_base_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,12 @@ def create_response(
@abc.abstractmethod
async def render_graphql_ide(self, request: Request) -> Response: ...

async def create_multipart_response(
async def create_streaming_response(
self,
request: Request,
stream: Callable[[], AsyncGenerator[str, None]],
sub_response: SubResponse,
headers: Dict[str, str],
) -> Response:
raise ValueError("Multipart responses are not supported")

Expand Down Expand Up @@ -199,7 +200,15 @@ async def run(
if isinstance(result, SubscriptionExecutionResult):
stream = self._get_stream(request, result)

return await self.create_multipart_response(request, stream, sub_response)
return await self.create_streaming_response(
request,
stream,
sub_response,
headers={
"Transfer-Encoding": "chunked",
"Content-Type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json",
},
)

response_data = await self.process_result(request=request, result=result)

Expand Down
6 changes: 3 additions & 3 deletions strawberry/litestar/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,19 +280,19 @@ def create_response(

return response

async def create_multipart_response(
async def create_streaming_response(
self,
request: Request,
stream: Callable[[], AsyncIterator[str]],
sub_response: Response,
headers: Dict[str, str],
) -> Response:
return Stream(
stream(),
status_code=sub_response.status_code,
headers={
**sub_response.headers,
"Transfer-Encoding": "chunked",
"Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json",
**headers,
},
)

Expand Down
8 changes: 4 additions & 4 deletions strawberry/quart/views.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings
from collections.abc import Mapping
from typing import TYPE_CHECKING, AsyncGenerator, Callable, Optional, cast
from typing import TYPE_CHECKING, AsyncGenerator, Callable, Dict, Optional, cast

from quart import Request, Response, request
from quart.views import View
Expand Down Expand Up @@ -103,19 +103,19 @@ async def dispatch_request(self) -> "ResponseReturnValue": # type: ignore
status=e.status_code,
)

async def create_multipart_response(
async def create_streaming_response(
self,
request: Request,
stream: Callable[[], AsyncGenerator[str, None]],
sub_response: Response,
headers: Dict[str, str],
) -> Response:
return (
stream(),
sub_response.status_code,
{ # type: ignore
**sub_response.headers,
"Transfer-Encoding": "chunked",
"Content-type": "multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json",
**headers,
},
)

Expand Down
6 changes: 3 additions & 3 deletions strawberry/sanic/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,18 +178,18 @@ async def get(self, request: Request) -> HTTPResponse: # type: ignore[override]
except HTTPException as e:
return HTTPResponse(e.reason, status=e.status_code)

async def create_multipart_response(
async def create_streaming_response(
self,
request: Request,
stream: Callable[[], AsyncGenerator[str, None]],
sub_response: TemporalResponse,
headers: Dict[str, str],
) -> HTTPResponse:
response = await self.request.respond(
content_type="multipart/mixed;boundary=graphql;subscriptionSpec=1.0,application/json",
status=sub_response.status_code,
headers={
**sub_response.headers,
"Transfer-Encoding": "chunked",
**headers,
},
)

Expand Down
2 changes: 2 additions & 0 deletions tests/http/test_multipart_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,5 @@ async def test_multipart_subscription(
data = [d async for d in response.streaming_json()]

assert data == [{"payload": {"data": {"echo": "Hello world"}}}]

assert response.status_code == 200
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Consider using HTTP status constants for better readability and maintainability

Instead of using the hard-coded value 200, it's recommended to use the status.HTTP_200_OK constant from the fastapi or starlette library. This improves code readability and ensures consistency with the rest of the codebase.

Suggested change
assert response.status_code == 200
from starlette import status
assert response.status_code == status.HTTP_200_OK

Loading