From 8c74c2c8dba7030154f8af18e016136bea1938fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Tue, 14 Feb 2023 09:01:32 +0100 Subject: [PATCH] Merge pull request from GHSA-74m5-2c7w-9w3x MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ♻️ Refactor multipart parser logic to support limiting max fields and files * ✨ Add support for new request.form() parameters max_files and max_fields * ✅ Add tests for limiting max fields and files in form data * 📝 Add docs about request.form() with new parameters max_files and max_fields * 📝 Update `docs/requests.md` Co-authored-by: Marcelo Trylesinski * 📝 Tweak docs for request.form() * ✏ Fix typo in `starlette/formparsers.py` Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> --------- Co-authored-by: Marcelo Trylesinski Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> --- docs/requests.md | 12 +++ starlette/formparsers.py | 188 +++++++++++++++++--------------- starlette/requests.py | 25 ++++- tests/test_formparsers.py | 222 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 356 insertions(+), 91 deletions(-) diff --git a/docs/requests.md b/docs/requests.md index 727d45bc8..f8d19bb9b 100644 --- a/docs/requests.md +++ b/docs/requests.md @@ -114,6 +114,18 @@ state with `disconnected = await request.is_disconnected()`. Request files are normally sent as multipart form data (`multipart/form-data`). +Signature: `request.form(max_files=1000, max_fields=1000)` + +You can configure the maximum number of fields or files with the parameters `max_files` and `max_fields`: + +```python +async with request.form(max_files=1000, max_fields=1000): + ... +``` + +!!! info + These limits are for security reasons: allowing an unlimited number of fields or files could lead to a denial-of-service attack by consuming a lot of CPU and memory parsing too many empty fields. 
+ When you call `async with request.form() as form` you receive a `starlette.datastructures.FormData` which is an immutable multidict, containing both file uploads and text input. File upload items are represented as instances of `starlette.datastructures.UploadFile`. diff --git a/starlette/formparsers.py b/starlette/formparsers.py index eb76c6f10..a2baa7a90 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -1,4 +1,5 @@ import typing +from dataclasses import dataclass, field from enum import Enum from tempfile import SpooledTemporaryFile from urllib.parse import unquote_plus @@ -21,15 +22,13 @@ class FormMessage(Enum): END = 5 -class MultiPartMessage(Enum): - PART_BEGIN = 1 - PART_DATA = 2 - PART_END = 3 - HEADER_FIELD = 4 - HEADER_VALUE = 5 - HEADER_END = 6 - HEADERS_FINISHED = 7 - END = 8 +@dataclass +class MultipartPart: + content_disposition: typing.Optional[bytes] = None + field_name: str = "" + data: bytes = b"" + file: typing.Optional[UploadFile] = None + item_headers: typing.List[typing.Tuple[bytes, bytes]] = field(default_factory=list) def _user_safe_decode(src: bytes, codec: str) -> str: @@ -120,46 +119,107 @@ class MultiPartParser: max_file_size = 1024 * 1024 def __init__( - self, headers: Headers, stream: typing.AsyncGenerator[bytes, None] + self, + headers: Headers, + stream: typing.AsyncGenerator[bytes, None], + *, + max_files: typing.Union[int, float] = 1000, + max_fields: typing.Union[int, float] = 1000, ) -> None: assert ( multipart is not None ), "The `python-multipart` library must be installed to use form parsing." 
self.headers = headers self.stream = stream - self.messages: typing.List[typing.Tuple[MultiPartMessage, bytes]] = [] + self.max_files = max_files + self.max_fields = max_fields + self.items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = [] + self._current_files = 0 + self._current_fields = 0 + self._current_partial_header_name: bytes = b"" + self._current_partial_header_value: bytes = b"" + self._current_part = MultipartPart() + self._charset = "" + self._file_parts_to_write: typing.List[typing.Tuple[MultipartPart, bytes]] = [] + self._file_parts_to_finish: typing.List[MultipartPart] = [] def on_part_begin(self) -> None: - message = (MultiPartMessage.PART_BEGIN, b"") - self.messages.append(message) + self._current_part = MultipartPart() def on_part_data(self, data: bytes, start: int, end: int) -> None: - message = (MultiPartMessage.PART_DATA, data[start:end]) - self.messages.append(message) + message_bytes = data[start:end] + if self._current_part.file is None: + self._current_part.data += message_bytes + else: + self._file_parts_to_write.append((self._current_part, message_bytes)) def on_part_end(self) -> None: - message = (MultiPartMessage.PART_END, b"") - self.messages.append(message) + if self._current_part.file is None: + self.items.append( + ( + self._current_part.field_name, + _user_safe_decode(self._current_part.data, self._charset), + ) + ) + else: + self._file_parts_to_finish.append(self._current_part) + # The file can be added to the items right now even though it's not + # finished yet, because it will be finished in the `parse()` method, before + # self.items is used in the return value. 
+ self.items.append((self._current_part.field_name, self._current_part.file)) def on_header_field(self, data: bytes, start: int, end: int) -> None: - message = (MultiPartMessage.HEADER_FIELD, data[start:end]) - self.messages.append(message) + self._current_partial_header_name += data[start:end] def on_header_value(self, data: bytes, start: int, end: int) -> None: - message = (MultiPartMessage.HEADER_VALUE, data[start:end]) - self.messages.append(message) + self._current_partial_header_value += data[start:end] def on_header_end(self) -> None: - message = (MultiPartMessage.HEADER_END, b"") - self.messages.append(message) + field = self._current_partial_header_name.lower() + if field == b"content-disposition": + self._current_part.content_disposition = self._current_partial_header_value + self._current_part.item_headers.append( + (field, self._current_partial_header_value) + ) + self._current_partial_header_name = b"" + self._current_partial_header_value = b"" def on_headers_finished(self) -> None: - message = (MultiPartMessage.HEADERS_FINISHED, b"") - self.messages.append(message) + disposition, options = parse_options_header( + self._current_part.content_disposition + ) + try: + self._current_part.field_name = _user_safe_decode( + options[b"name"], self._charset + ) + except KeyError: + raise MultiPartException( + 'The Content-Disposition header field "name" must be ' "provided." + ) + if b"filename" in options: + self._current_files += 1 + if self._current_files > self.max_files: + raise MultiPartException( + f"Too many files. Maximum number of files is {self.max_files}." 
+ ) + filename = _user_safe_decode(options[b"filename"], self._charset) + tempfile = SpooledTemporaryFile(max_size=self.max_file_size) + self._current_part.file = UploadFile( + file=tempfile, # type: ignore[arg-type] + size=0, + filename=filename, + headers=Headers(raw=self._current_part.item_headers), + ) + else: + self._current_fields += 1 + if self._current_fields > self.max_fields: + raise MultiPartException( + f"Too many fields. Maximum number of fields is {self.max_fields}." + ) + self._current_part.file = None def on_end(self) -> None: - message = (MultiPartMessage.END, b"") - self.messages.append(message) + pass async def parse(self) -> FormData: # Parse the Content-Type header to get the multipart boundary. @@ -167,6 +227,7 @@ async def parse(self) -> FormData: charset = params.get(b"charset", "utf-8") if type(charset) == bytes: charset = charset.decode("latin-1") + self._charset = charset try: boundary = params[b"boundary"] except KeyError: @@ -186,68 +247,21 @@ async def parse(self) -> FormData: # Create the parser. parser = multipart.MultipartParser(boundary, callbacks) - header_field = b"" - header_value = b"" - content_disposition = None - field_name = "" - data = b"" - file: typing.Optional[UploadFile] = None - - items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = [] - item_headers: typing.List[typing.Tuple[bytes, bytes]] = [] - # Feed the parser with data from the request. 
async for chunk in self.stream: parser.write(chunk) - messages = list(self.messages) - self.messages.clear() - for message_type, message_bytes in messages: - if message_type == MultiPartMessage.PART_BEGIN: - content_disposition = None - data = b"" - item_headers = [] - elif message_type == MultiPartMessage.HEADER_FIELD: - header_field += message_bytes - elif message_type == MultiPartMessage.HEADER_VALUE: - header_value += message_bytes - elif message_type == MultiPartMessage.HEADER_END: - field = header_field.lower() - if field == b"content-disposition": - content_disposition = header_value - item_headers.append((field, header_value)) - header_field = b"" - header_value = b"" - elif message_type == MultiPartMessage.HEADERS_FINISHED: - disposition, options = parse_options_header(content_disposition) - try: - field_name = _user_safe_decode(options[b"name"], charset) - except KeyError: - raise MultiPartException( - 'The Content-Disposition header field "name" must be ' - "provided." - ) - if b"filename" in options: - filename = _user_safe_decode(options[b"filename"], charset) - tempfile = SpooledTemporaryFile(max_size=self.max_file_size) - file = UploadFile( - file=tempfile, # type: ignore[arg-type] - size=0, - filename=filename, - headers=Headers(raw=item_headers), - ) - else: - file = None - elif message_type == MultiPartMessage.PART_DATA: - if file is None: - data += message_bytes - else: - await file.write(message_bytes) - elif message_type == MultiPartMessage.PART_END: - if file is None: - items.append((field_name, _user_safe_decode(data, charset))) - else: - await file.seek(0) - items.append((field_name, file)) + # Write file data, it needs to use await with the UploadFile methods that + # call the corresponding file methods *in a threadpool*, otherwise, if + # they were called directly in the callback methods above (regular, + # non-async functions), that would block the event loop in the main thread. 
+ for part, data in self._file_parts_to_write: + assert part.file # for type checkers + await part.file.write(data) + for part in self._file_parts_to_finish: + assert part.file # for type checkers + await part.file.seek(0) + self._file_parts_to_write.clear() + self._file_parts_to_finish.clear() parser.finalize() - return FormData(items) + return FormData(self.items) diff --git a/starlette/requests.py b/starlette/requests.py index d924b501a..0c7264a56 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -244,7 +244,12 @@ async def json(self) -> typing.Any: self._json = json.loads(body) return self._json - async def _get_form(self) -> FormData: + async def _get_form( + self, + *, + max_files: typing.Union[int, float] = 1000, + max_fields: typing.Union[int, float] = 1000, + ) -> FormData: if self._form is None: assert ( parse_options_header is not None @@ -254,7 +259,12 @@ async def _get_form(self) -> FormData: content_type, _ = parse_options_header(content_type_header) if content_type == b"multipart/form-data": try: - multipart_parser = MultiPartParser(self.headers, self.stream()) + multipart_parser = MultiPartParser( + self.headers, + self.stream(), + max_files=max_files, + max_fields=max_fields, + ) self._form = await multipart_parser.parse() except MultiPartException as exc: if "app" in self.scope: @@ -267,8 +277,15 @@ async def _get_form(self) -> FormData: self._form = FormData() return self._form - def form(self) -> AwaitableOrContextManager[FormData]: - return AwaitableOrContextManagerWrapper(self._get_form()) + def form( + self, + *, + max_files: typing.Union[int, float] = 1000, + max_fields: typing.Union[int, float] = 1000, + ) -> AwaitableOrContextManager[FormData]: + return AwaitableOrContextManagerWrapper( + self._get_form(max_files=max_files, max_fields=max_fields) + ) async def close(self) -> None: if self._form is not None: diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 804ce8d26..502f7809f 100644 --- 
a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -98,6 +98,29 @@ async def app_read_body(scope, receive, send): await response(scope, receive, send) +def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000): + async def app(scope, receive, send): + request = Request(scope, receive) + data = await request.form(max_files=max_files, max_fields=max_fields) + output: typing.Dict[str, typing.Any] = {} + for key, value in data.items(): + if isinstance(value, UploadFile): + content = await value.read() + output[key] = { + "filename": value.filename, + "size": value.size, + "content": content.decode(), + "content_type": value.content_type, + } + else: + output[key] = value + await request.close() + response = JSONResponse(output) + await response(scope, receive, send) + + return app + + def test_multipart_request_data(tmpdir, test_client_factory): client = test_client_factory(app) response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART) @@ -460,3 +483,202 @@ def test_missing_name_parameter_on_content_disposition( assert ( res.text == 'The Content-Disposition header field "name" must be provided.' ) + + +@pytest.mark.parametrize( + "app,expectation", + [ + (app, pytest.raises(MultiPartException)), + (Starlette(routes=[Mount("/", app=app)]), does_not_raise()), + ], +) +def test_too_many_fields_raise(app, expectation, test_client_factory): + client = test_client_factory(app) + fields = [] + for i in range(1001): + fields.append( + "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n" + ) + data = "".join(fields).encode("utf-8") + with expectation: + res = client.post( + "/", + data=data, + headers={"Content-Type": ("multipart/form-data; boundary=B")}, + ) + assert res.status_code == 400 + assert res.text == "Too many fields. Maximum number of fields is 1000." 
+ + +@pytest.mark.parametrize( + "app,expectation", + [ + (app, pytest.raises(MultiPartException)), + (Starlette(routes=[Mount("/", app=app)]), does_not_raise()), + ], +) +def test_too_many_files_raise(app, expectation, test_client_factory): + client = test_client_factory(app) + fields = [] + for i in range(1001): + fields.append( + "--B\r\n" + f'Content-Disposition: form-data; name="N{i}"; filename="F{i}";\r\n\r\n' + "\r\n" + ) + data = "".join(fields).encode("utf-8") + with expectation: + res = client.post( + "/", + data=data, + headers={"Content-Type": ("multipart/form-data; boundary=B")}, + ) + assert res.status_code == 400 + assert res.text == "Too many files. Maximum number of files is 1000." + + +@pytest.mark.parametrize( + "app,expectation", + [ + (app, pytest.raises(MultiPartException)), + (Starlette(routes=[Mount("/", app=app)]), does_not_raise()), + ], +) +def test_too_many_files_single_field_raise(app, expectation, test_client_factory): + client = test_client_factory(app) + fields = [] + for i in range(1001): + # This uses the same field name "N" for all files, equivalent to a + # multifile upload form field + fields.append( + "--B\r\n" + f'Content-Disposition: form-data; name="N"; filename="F{i}";\r\n\r\n' + "\r\n" + ) + data = "".join(fields).encode("utf-8") + with expectation: + res = client.post( + "/", + data=data, + headers={"Content-Type": ("multipart/form-data; boundary=B")}, + ) + assert res.status_code == 400 + assert res.text == "Too many files. Maximum number of files is 1000." 
+ + +@pytest.mark.parametrize( + "app,expectation", + [ + (app, pytest.raises(MultiPartException)), + (Starlette(routes=[Mount("/", app=app)]), does_not_raise()), + ], +) +def test_too_many_files_and_fields_raise(app, expectation, test_client_factory): + client = test_client_factory(app) + fields = [] + for i in range(1001): + fields.append( + "--B\r\n" + f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' + "\r\n" + ) + fields.append( + "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n" + ) + data = "".join(fields).encode("utf-8") + with expectation: + res = client.post( + "/", + data=data, + headers={"Content-Type": ("multipart/form-data; boundary=B")}, + ) + assert res.status_code == 400 + assert res.text == "Too many files. Maximum number of files is 1000." + + +@pytest.mark.parametrize( + "app,expectation", + [ + (make_app_max_parts(max_fields=1), pytest.raises(MultiPartException)), + ( + Starlette(routes=[Mount("/", app=make_app_max_parts(max_fields=1))]), + does_not_raise(), + ), + ], +) +def test_max_fields_is_customizable_low_raises(app, expectation, test_client_factory): + client = test_client_factory(app) + fields = [] + for i in range(2): + fields.append( + "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n" + ) + data = "".join(fields).encode("utf-8") + with expectation: + res = client.post( + "/", + data=data, + headers={"Content-Type": ("multipart/form-data; boundary=B")}, + ) + assert res.status_code == 400 + assert res.text == "Too many fields. Maximum number of fields is 1." 
+ + +@pytest.mark.parametrize( + "app,expectation", + [ + (make_app_max_parts(max_files=1), pytest.raises(MultiPartException)), + ( + Starlette(routes=[Mount("/", app=make_app_max_parts(max_files=1))]), + does_not_raise(), + ), + ], +) +def test_max_files_is_customizable_low_raises(app, expectation, test_client_factory): + client = test_client_factory(app) + fields = [] + for i in range(2): + fields.append( + "--B\r\n" + f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' + "\r\n" + ) + data = "".join(fields).encode("utf-8") + with expectation: + res = client.post( + "/", + data=data, + headers={"Content-Type": ("multipart/form-data; boundary=B")}, + ) + assert res.status_code == 400 + assert res.text == "Too many files. Maximum number of files is 1." + + +def test_max_fields_is_customizable_high(test_client_factory): + client = test_client_factory(make_app_max_parts(max_fields=2000, max_files=2000)) + fields = [] + for i in range(2000): + fields.append( + "--B\r\n" f'Content-Disposition: form-data; name="N{i}";\r\n\r\n' "\r\n" + ) + fields.append( + "--B\r\n" + f'Content-Disposition: form-data; name="F{i}"; filename="F{i}";\r\n\r\n' + "\r\n" + ) + data = "".join(fields).encode("utf-8") + data += b"--B--\r\n" + res = client.post( + "/", + data=data, + headers={"Content-Type": ("multipart/form-data; boundary=B")}, + ) + assert res.status_code == 200 + res_data = res.json() + assert res_data["N1999"] == "" + assert res_data["F1999"] == { + "filename": "F1999", + "size": 0, + "content": "", + "content_type": None, + }