Skip to content

Commit

Permalink
Merge pull request from GHSA-74m5-2c7w-9w3x
Browse files Browse the repository at this point in the history
* ♻️ Refactor multipart parser logic to support limiting max fields and files

* ✨ Add support for new request.form() parameters max_files and max_fields

* ✅ Add tests for limiting max fields and files in form data

* 📝 Add docs about request.form() with new parameters max_files and max_fields

* 📝 Update `docs/requests.md`

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>

* 📝 Tweak docs for request.form()

* ✏ Fix typo in `starlette/formparsers.py`

Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>

---------

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 14, 2023
1 parent 5771a78 commit 8c74c2c
Show file tree
Hide file tree
Showing 4 changed files with 356 additions and 91 deletions.
12 changes: 12 additions & 0 deletions docs/requests.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,18 @@ state with `disconnected = await request.is_disconnected()`.

Request files are normally sent as multipart form data (`multipart/form-data`).

Signature: `request.form(max_files=1000, max_fields=1000)`

You can configure the number of maximum fields or files with the parameters `max_files` and `max_fields`:

```python
async with request.form(max_files=1000, max_fields=1000):
...
```

!!! info
These limits are for security reasons, allowing an unlimited number of fields or files could lead to a denial of service attack by consuming a lot of CPU and memory parsing too many empty fields.

When you call `async with request.form() as form` you receive a `starlette.datastructures.FormData` which is an immutable
multidict, containing both file uploads and text input. File upload items are represented as instances of `starlette.datastructures.UploadFile`.

Expand Down
188 changes: 101 additions & 87 deletions starlette/formparsers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import typing
from dataclasses import dataclass, field
from enum import Enum
from tempfile import SpooledTemporaryFile
from urllib.parse import unquote_plus
Expand All @@ -21,15 +22,13 @@ class FormMessage(Enum):
END = 5


class MultiPartMessage(Enum):
PART_BEGIN = 1
PART_DATA = 2
PART_END = 3
HEADER_FIELD = 4
HEADER_VALUE = 5
HEADER_END = 6
HEADERS_FINISHED = 7
END = 8
@dataclass
class MultipartPart:
content_disposition: typing.Optional[bytes] = None
field_name: str = ""
data: bytes = b""
file: typing.Optional[UploadFile] = None
item_headers: typing.List[typing.Tuple[bytes, bytes]] = field(default_factory=list)


def _user_safe_decode(src: bytes, codec: str) -> str:
Expand Down Expand Up @@ -120,53 +119,115 @@ class MultiPartParser:
max_file_size = 1024 * 1024

def __init__(
self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
self,
headers: Headers,
stream: typing.AsyncGenerator[bytes, None],
*,
max_files: typing.Union[int, float] = 1000,
max_fields: typing.Union[int, float] = 1000,
) -> None:
assert (
multipart is not None
), "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
self.messages: typing.List[typing.Tuple[MultiPartMessage, bytes]] = []
self.max_files = max_files
self.max_fields = max_fields
self.items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []
self._current_files = 0
self._current_fields = 0
self._current_partial_header_name: bytes = b""
self._current_partial_header_value: bytes = b""
self._current_part = MultipartPart()
self._charset = ""
self._file_parts_to_write: typing.List[typing.Tuple[MultipartPart, bytes]] = []
self._file_parts_to_finish: typing.List[MultipartPart] = []

def on_part_begin(self) -> None:
message = (MultiPartMessage.PART_BEGIN, b"")
self.messages.append(message)
self._current_part = MultipartPart()

def on_part_data(self, data: bytes, start: int, end: int) -> None:
message = (MultiPartMessage.PART_DATA, data[start:end])
self.messages.append(message)
message_bytes = data[start:end]
if self._current_part.file is None:
self._current_part.data += message_bytes
else:
self._file_parts_to_write.append((self._current_part, message_bytes))

def on_part_end(self) -> None:
message = (MultiPartMessage.PART_END, b"")
self.messages.append(message)
if self._current_part.file is None:
self.items.append(
(
self._current_part.field_name,
_user_safe_decode(self._current_part.data, self._charset),
)
)
else:
self._file_parts_to_finish.append(self._current_part)
# The file can be added to the items right now even though it's not
# finished yet, because it will be finished in the `parse()` method, before
# self.items is used in the return value.
self.items.append((self._current_part.field_name, self._current_part.file))

def on_header_field(self, data: bytes, start: int, end: int) -> None:
message = (MultiPartMessage.HEADER_FIELD, data[start:end])
self.messages.append(message)
self._current_partial_header_name += data[start:end]

def on_header_value(self, data: bytes, start: int, end: int) -> None:
message = (MultiPartMessage.HEADER_VALUE, data[start:end])
self.messages.append(message)
self._current_partial_header_value += data[start:end]

def on_header_end(self) -> None:
message = (MultiPartMessage.HEADER_END, b"")
self.messages.append(message)
field = self._current_partial_header_name.lower()
if field == b"content-disposition":
self._current_part.content_disposition = self._current_partial_header_value
self._current_part.item_headers.append(
(field, self._current_partial_header_value)
)
self._current_partial_header_name = b""
self._current_partial_header_value = b""

def on_headers_finished(self) -> None:
message = (MultiPartMessage.HEADERS_FINISHED, b"")
self.messages.append(message)
disposition, options = parse_options_header(
self._current_part.content_disposition
)
try:
self._current_part.field_name = _user_safe_decode(
options[b"name"], self._charset
)
except KeyError:
raise MultiPartException(
'The Content-Disposition header field "name" must be ' "provided."
)
if b"filename" in options:
self._current_files += 1
if self._current_files > self.max_files:
raise MultiPartException(
f"Too many files. Maximum number of files is {self.max_files}."
)
filename = _user_safe_decode(options[b"filename"], self._charset)
tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
self._current_part.file = UploadFile(
file=tempfile, # type: ignore[arg-type]
size=0,
filename=filename,
headers=Headers(raw=self._current_part.item_headers),
)
else:
self._current_fields += 1
if self._current_fields > self.max_fields:
raise MultiPartException(
f"Too many fields. Maximum number of fields is {self.max_fields}."
)
self._current_part.file = None

def on_end(self) -> None:
message = (MultiPartMessage.END, b"")
self.messages.append(message)
pass

async def parse(self) -> FormData:
# Parse the Content-Type header to get the multipart boundary.
_, params = parse_options_header(self.headers["Content-Type"])
charset = params.get(b"charset", "utf-8")
if type(charset) == bytes:
charset = charset.decode("latin-1")
self._charset = charset
try:
boundary = params[b"boundary"]
except KeyError:
Expand All @@ -186,68 +247,21 @@ async def parse(self) -> FormData:

# Create the parser.
parser = multipart.MultipartParser(boundary, callbacks)
header_field = b""
header_value = b""
content_disposition = None
field_name = ""
data = b""
file: typing.Optional[UploadFile] = None

items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []
item_headers: typing.List[typing.Tuple[bytes, bytes]] = []

# Feed the parser with data from the request.
async for chunk in self.stream:
parser.write(chunk)
messages = list(self.messages)
self.messages.clear()
for message_type, message_bytes in messages:
if message_type == MultiPartMessage.PART_BEGIN:
content_disposition = None
data = b""
item_headers = []
elif message_type == MultiPartMessage.HEADER_FIELD:
header_field += message_bytes
elif message_type == MultiPartMessage.HEADER_VALUE:
header_value += message_bytes
elif message_type == MultiPartMessage.HEADER_END:
field = header_field.lower()
if field == b"content-disposition":
content_disposition = header_value
item_headers.append((field, header_value))
header_field = b""
header_value = b""
elif message_type == MultiPartMessage.HEADERS_FINISHED:
disposition, options = parse_options_header(content_disposition)
try:
field_name = _user_safe_decode(options[b"name"], charset)
except KeyError:
raise MultiPartException(
'The Content-Disposition header field "name" must be '
"provided."
)
if b"filename" in options:
filename = _user_safe_decode(options[b"filename"], charset)
tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
file = UploadFile(
file=tempfile, # type: ignore[arg-type]
size=0,
filename=filename,
headers=Headers(raw=item_headers),
)
else:
file = None
elif message_type == MultiPartMessage.PART_DATA:
if file is None:
data += message_bytes
else:
await file.write(message_bytes)
elif message_type == MultiPartMessage.PART_END:
if file is None:
items.append((field_name, _user_safe_decode(data, charset)))
else:
await file.seek(0)
items.append((field_name, file))
# Write file data, it needs to use await with the UploadFile methods that
# call the corresponding file methods *in a threadpool*, otherwise, if
# they were called directly in the callback methods above (regular,
# non-async functions), that would block the event loop in the main thread.
for part, data in self._file_parts_to_write:
assert part.file # for type checkers
await part.file.write(data)
for part in self._file_parts_to_finish:
assert part.file # for type checkers
await part.file.seek(0)
self._file_parts_to_write.clear()
self._file_parts_to_finish.clear()

parser.finalize()
return FormData(items)
return FormData(self.items)
25 changes: 21 additions & 4 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,12 @@ async def json(self) -> typing.Any:
self._json = json.loads(body)
return self._json

async def _get_form(self) -> FormData:
async def _get_form(
self,
*,
max_files: typing.Union[int, float] = 1000,
max_fields: typing.Union[int, float] = 1000,
) -> FormData:
if self._form is None:
assert (
parse_options_header is not None
Expand All @@ -254,7 +259,12 @@ async def _get_form(self) -> FormData:
content_type, _ = parse_options_header(content_type_header)
if content_type == b"multipart/form-data":
try:
multipart_parser = MultiPartParser(self.headers, self.stream())
multipart_parser = MultiPartParser(
self.headers,
self.stream(),
max_files=max_files,
max_fields=max_fields,
)
self._form = await multipart_parser.parse()
except MultiPartException as exc:
if "app" in self.scope:
Expand All @@ -267,8 +277,15 @@ async def _get_form(self) -> FormData:
self._form = FormData()
return self._form

def form(self) -> AwaitableOrContextManager[FormData]:
return AwaitableOrContextManagerWrapper(self._get_form())
def form(
self,
*,
max_files: typing.Union[int, float] = 1000,
max_fields: typing.Union[int, float] = 1000,
) -> AwaitableOrContextManager[FormData]:
return AwaitableOrContextManagerWrapper(
self._get_form(max_files=max_files, max_fields=max_fields)
)

async def close(self) -> None:
if self._form is not None:
Expand Down
Loading

0 comments on commit 8c74c2c

Please sign in to comment.