Skip to content

Commit

Permalink
allow using Request.form() as a context manager (#1903)
Browse files Browse the repository at this point in the history
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
adriangb and Kludex authored Feb 6, 2023
1 parent 0a63a6e commit c568b55
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 10 deletions.
10 changes: 5 additions & 5 deletions docs/requests.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ There are a few different interfaces for returning the body of the request:

The request body as bytes: `await request.body()`

The request body, parsed as form data or multipart: `await request.form()`
The request body, parsed as form data or multipart: `async with request.form() as form:`

The request body, parsed as JSON: `await request.json()`

Expand Down Expand Up @@ -114,7 +114,7 @@ state with `disconnected = await request.is_disconnected()`.

Request files are normally sent as multipart form data (`multipart/form-data`).

When you call `await request.form()` you receive a `starlette.datastructures.FormData` which is an immutable
When you call `async with request.form() as form` you receive a `starlette.datastructures.FormData` which is an immutable
multidict, containing both file uploads and text input. File upload items are represented as instances of `starlette.datastructures.UploadFile`.

`UploadFile` has the following attributes:
Expand All @@ -137,9 +137,9 @@ As all these methods are `async` methods, you need to "await" them.
For example, you can get the file name and the contents with:

```python
form = await request.form()
filename = form["upload_file"].filename
contents = await form["upload_file"].read()
async with request.form() as form:
filename = form["upload_file"].filename
contents = await form["upload_file"].read()
```

!!! info
Expand Down
62 changes: 62 additions & 0 deletions starlette/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import asyncio
import functools
import sys
import typing
from types import TracebackType

if sys.version_info < (3, 8): # pragma: no cover
from typing_extensions import Protocol
else: # pragma: no cover
from typing import Protocol


def is_async_callable(obj: typing.Any) -> bool:
Expand All @@ -10,3 +17,58 @@ def is_async_callable(obj: typing.Any) -> bool:
return asyncio.iscoroutinefunction(obj) or (
callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
)


T_co = typing.TypeVar("T_co", covariant=True)


# TODO: once 3.8 is the minimum supported version (27 Jun 2023)
# this can just become
# class AwaitableOrContextManager(
# typing.Awaitable[T_co],
# typing.AsyncContextManager[T_co],
# typing.Protocol[T_co],
# ):
# pass
class AwaitableOrContextManager(Protocol[T_co]):
def __await__(self) -> typing.Generator[typing.Any, None, T_co]:
... # pragma: no cover

async def __aenter__(self) -> T_co:
... # pragma: no cover

async def __aexit__(
self,
__exc_type: typing.Optional[typing.Type[BaseException]],
__exc_value: typing.Optional[BaseException],
__traceback: typing.Optional[TracebackType],
) -> typing.Union[bool, None]:
... # pragma: no cover


class SupportsAsyncClose(Protocol):
async def close(self) -> None:
... # pragma: no cover


SupportsAsyncCloseType = typing.TypeVar(
"SupportsAsyncCloseType", bound=SupportsAsyncClose, covariant=False
)


class AwaitableOrContextManagerWrapper(typing.Generic[SupportsAsyncCloseType]):
__slots__ = ("aw", "entered")

def __init__(self, aw: typing.Awaitable[SupportsAsyncCloseType]) -> None:
self.aw = aw

def __await__(self) -> typing.Generator[typing.Any, None, SupportsAsyncCloseType]:
return self.aw.__await__()

async def __aenter__(self) -> SupportsAsyncCloseType:
self.entered = await self.aw
return self.entered

async def __aexit__(self, *args: typing.Any) -> typing.Union[None, bool]:
await self.entered.close()
return None
15 changes: 10 additions & 5 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import anyio

from starlette._utils import AwaitableOrContextManager, AwaitableOrContextManagerWrapper
from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
from starlette.exceptions import HTTPException
from starlette.formparsers import FormParser, MultiPartException, MultiPartParser
Expand Down Expand Up @@ -187,6 +188,8 @@ async def empty_send(message: Message) -> typing.NoReturn:


class Request(HTTPConnection):
_form: typing.Optional[FormData]

def __init__(
self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
):
Expand All @@ -196,6 +199,7 @@ def __init__(
self._send = send
self._stream_consumed = False
self._is_disconnected = False
self._form = None

@property
def method(self) -> str:
Expand All @@ -210,10 +214,8 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]:
yield self._body
yield b""
return

if self._stream_consumed:
raise RuntimeError("Stream consumed")

self._stream_consumed = True
while True:
message = await self._receive()
Expand Down Expand Up @@ -242,8 +244,8 @@ async def json(self) -> typing.Any:
self._json = json.loads(body)
return self._json

async def form(self) -> FormData:
if not hasattr(self, "_form"):
async def _get_form(self) -> FormData:
if self._form is None:
assert (
parse_options_header is not None
), "The `python-multipart` library must be installed to use form parsing."
Expand All @@ -265,8 +267,11 @@ async def form(self) -> FormData:
self._form = FormData()
return self._form

def form(self) -> AwaitableOrContextManager[FormData]:
return AwaitableOrContextManagerWrapper(self._get_form())

async def close(self) -> None:
if hasattr(self, "_form"):
if self._form is not None:
await self._form.close()

async def is_disconnected(self) -> bool:
Expand Down
13 changes: 13 additions & 0 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,19 @@ async def app(scope, receive, send):
assert response.json() == {"form": {"abc": "123 @"}}


def test_request_form_context_manager(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
async with request.form() as form:
response = JSONResponse({"form": dict(form)})
await response(scope, receive, send)

client = test_client_factory(app)

response = client.post("/", data={"abc": "123 @"})
assert response.json() == {"form": {"abc": "123 @"}}


def test_request_body_then_stream(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
Expand Down

0 comments on commit c568b55

Please sign in to comment.