Skip to content

Commit

Permalink
Fix handling of multipart/form-data (#8280) (#8302)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamsorcerer authored Apr 7, 2024
1 parent 270ae9c commit cebe526
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 120 deletions.
1 change: 1 addition & 0 deletions CHANGES/8280.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed ``multipart/form-data`` compliance with :rfc:`7578` -- by :user:`Dreamsorcerer`.
2 changes: 2 additions & 0 deletions CHANGES/8280.deprecation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Deprecated ``content_transfer_encoding`` parameter in :py:meth:`FormData.add_field()
<aiohttp.FormData.add_field>` -- by :user:`Dreamsorcerer`.
12 changes: 11 additions & 1 deletion aiohttp/formdata.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import io
import warnings
from typing import Any, Iterable, List, Optional
from urllib.parse import urlencode

Expand Down Expand Up @@ -53,7 +54,12 @@ def add_field(
if isinstance(value, io.IOBase):
self._is_multipart = True
elif isinstance(value, (bytes, bytearray, memoryview)):
msg = (
"In v4, passing bytes will no longer create a file field. "
"Please explicitly use the filename parameter or pass a BytesIO object."
)
if filename is None and content_transfer_encoding is None:
warnings.warn(msg, DeprecationWarning)
filename = name

type_options: MultiDict[str] = MultiDict({"name": name})
Expand Down Expand Up @@ -81,7 +87,11 @@ def add_field(
"content_transfer_encoding must be an instance"
" of str. Got: %s" % content_transfer_encoding
)
headers[hdrs.CONTENT_TRANSFER_ENCODING] = content_transfer_encoding
msg = (
"content_transfer_encoding is deprecated. "
"To maintain compatibility with v4 please pass a BytesPayload."
)
warnings.warn(msg, DeprecationWarning)
self._is_multipart = True

self._fields.append((type_options, headers, value))
Expand Down
121 changes: 80 additions & 41 deletions aiohttp/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,13 +256,22 @@ class BodyPartReader:
chunk_size = 8192

def __init__(
self, boundary: bytes, headers: "CIMultiDictProxy[str]", content: StreamReader
self,
boundary: bytes,
headers: "CIMultiDictProxy[str]",
content: StreamReader,
*,
subtype: str = "mixed",
default_charset: Optional[str] = None,
) -> None:
self.headers = headers
self._boundary = boundary
self._content = content
self._default_charset = default_charset
self._at_eof = False
length = self.headers.get(CONTENT_LENGTH, None)
self._is_form_data = subtype == "form-data"
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
length = None if self._is_form_data else self.headers.get(CONTENT_LENGTH, None)
self._length = int(length) if length is not None else None
self._read_bytes = 0
self._unread: Deque[bytes] = deque()
Expand Down Expand Up @@ -329,6 +338,8 @@ async def _read_chunk_from_length(self, size: int) -> bytes:
assert self._length is not None, "Content-Length required for chunked read"
chunk_size = min(size, self._length - self._read_bytes)
chunk = await self._content.read(chunk_size)
if self._content.at_eof():
self._at_eof = True
return chunk

async def _read_chunk_from_stream(self, size: int) -> bytes:
Expand Down Expand Up @@ -449,7 +460,8 @@ def decode(self, data: bytes) -> bytes:
"""
if CONTENT_TRANSFER_ENCODING in self.headers:
data = self._decode_content_transfer(data)
if CONTENT_ENCODING in self.headers:
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
if not self._is_form_data and CONTENT_ENCODING in self.headers:
return self._decode_content(data)
return data

Expand Down Expand Up @@ -483,7 +495,7 @@ def get_charset(self, default: str) -> str:
"""Returns charset parameter from Content-Type header or default."""
ctype = self.headers.get(CONTENT_TYPE, "")
mimetype = parse_mimetype(ctype)
return mimetype.parameters.get("charset", default)
return mimetype.parameters.get("charset", self._default_charset or default)

@reify
def name(self) -> Optional[str]:
Expand Down Expand Up @@ -538,9 +550,17 @@ class MultipartReader:
part_reader_cls = BodyPartReader

def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None:
self._mimetype = parse_mimetype(headers[CONTENT_TYPE])
assert self._mimetype.type == "multipart", "multipart/* content type expected"
if "boundary" not in self._mimetype.parameters:
raise ValueError(
"boundary missed for Content-Type: %s" % headers[CONTENT_TYPE]
)

self.headers = headers
self._boundary = ("--" + self._get_boundary()).encode()
self._content = content
self._default_charset: Optional[str] = None
self._last_part: Optional[Union["MultipartReader", BodyPartReader]] = None
self._at_eof = False
self._at_bof = True
Expand Down Expand Up @@ -592,7 +612,24 @@ async def next(
await self._read_boundary()
if self._at_eof: # we just read the last boundary, nothing to do there
return None
self._last_part = await self.fetch_next_part()

part = await self.fetch_next_part()
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.6
if (
self._last_part is None
and self._mimetype.subtype == "form-data"
and isinstance(part, BodyPartReader)
):
_, params = parse_content_disposition(part.headers.get(CONTENT_DISPOSITION))
if params.get("name") == "_charset_":
# Longest encoding in https://encoding.spec.whatwg.org/encodings.json
# is 19 characters, so 32 should be more than enough for any valid encoding.
charset = await part.read_chunk(32)
if len(charset) > 31:
raise RuntimeError("Invalid default charset")
self._default_charset = charset.strip().decode()
part = await self.fetch_next_part()
self._last_part = part
return self._last_part

async def release(self) -> None:
Expand Down Expand Up @@ -628,19 +665,16 @@ def _get_part_reader(
return type(self)(headers, self._content)
return self.multipart_reader_cls(headers, self._content)
else:
return self.part_reader_cls(self._boundary, headers, self._content)

def _get_boundary(self) -> str:
mimetype = parse_mimetype(self.headers[CONTENT_TYPE])

assert mimetype.type == "multipart", "multipart/* content type expected"

if "boundary" not in mimetype.parameters:
raise ValueError(
"boundary missed for Content-Type: %s" % self.headers[CONTENT_TYPE]
return self.part_reader_cls(
self._boundary,
headers,
self._content,
subtype=self._mimetype.subtype,
default_charset=self._default_charset,
)

boundary = mimetype.parameters["boundary"]
def _get_boundary(self) -> str:
boundary = self._mimetype.parameters["boundary"]
if len(boundary) > 70:
raise ValueError("boundary %r is too long (70 chars max)" % boundary)

Expand Down Expand Up @@ -731,6 +765,7 @@ def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> No
super().__init__(None, content_type=ctype)

self._parts: List[_Part] = []
self._is_form_data = subtype == "form-data"

def __enter__(self) -> "MultipartWriter":
return self
Expand Down Expand Up @@ -808,32 +843,36 @@ def append(self, obj: Any, headers: Optional[Mapping[str, str]] = None) -> Paylo

def append_payload(self, payload: Payload) -> Payload:
"""Adds a new body part to multipart writer."""
# compression
encoding: Optional[str] = payload.headers.get(
CONTENT_ENCODING,
"",
).lower()
if encoding and encoding not in ("deflate", "gzip", "identity"):
raise RuntimeError(f"unknown content encoding: {encoding}")
if encoding == "identity":
encoding = None

# te encoding
te_encoding: Optional[str] = payload.headers.get(
CONTENT_TRANSFER_ENCODING,
"",
).lower()
if te_encoding not in ("", "base64", "quoted-printable", "binary"):
raise RuntimeError(
"unknown content transfer encoding: {}" "".format(te_encoding)
encoding: Optional[str] = None
te_encoding: Optional[str] = None
if self._is_form_data:
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.7
# https://datatracker.ietf.org/doc/html/rfc7578#section-4.8
assert CONTENT_DISPOSITION in payload.headers
assert "name=" in payload.headers[CONTENT_DISPOSITION]
assert (
not {CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TRANSFER_ENCODING}
& payload.headers.keys()
)
if te_encoding == "binary":
te_encoding = None

# size
size = payload.size
if size is not None and not (encoding or te_encoding):
payload.headers[CONTENT_LENGTH] = str(size)
else:
# compression
encoding = payload.headers.get(CONTENT_ENCODING, "").lower()
if encoding and encoding not in ("deflate", "gzip", "identity"):
raise RuntimeError(f"unknown content encoding: {encoding}")
if encoding == "identity":
encoding = None

# te encoding
te_encoding = payload.headers.get(CONTENT_TRANSFER_ENCODING, "").lower()
if te_encoding not in ("", "base64", "quoted-printable", "binary"):
raise RuntimeError(f"unknown content transfer encoding: {te_encoding}")
if te_encoding == "binary":
te_encoding = None

# size
size = payload.size
if size is not None and not (encoding or te_encoding):
payload.headers[CONTENT_LENGTH] = str(size)

self._parts.append((payload, encoding, te_encoding)) # type: ignore[arg-type]
return payload
Expand Down
44 changes: 1 addition & 43 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,48 +1317,6 @@ async def handler(request):
resp.close()


async def test_POST_DATA_with_context_transfer_encoding(aiohttp_client) -> None:
async def handler(request):
data = await request.post()
assert data["name"] == "text"
return web.Response(text=data["name"])

app = web.Application()
app.router.add_post("/", handler)
client = await aiohttp_client(app)

form = aiohttp.FormData()
form.add_field("name", "text", content_transfer_encoding="base64")

resp = await client.post("/", data=form)
assert 200 == resp.status
content = await resp.text()
assert content == "text"
resp.close()


async def test_POST_DATA_with_content_type_context_transfer_encoding(aiohttp_client):
async def handler(request):
data = await request.post()
assert data["name"] == "text"
return web.Response(body=data["name"])

app = web.Application()
app.router.add_post("/", handler)
client = await aiohttp_client(app)

form = aiohttp.FormData()
form.add_field(
"name", "text", content_type="text/plain", content_transfer_encoding="base64"
)

resp = await client.post("/", data=form)
assert 200 == resp.status
content = await resp.text()
assert content == "text"
resp.close()


async def test_POST_MultiDict(aiohttp_client) -> None:
async def handler(request):
data = await request.post()
Expand Down Expand Up @@ -1410,7 +1368,7 @@ async def handler(request):

with fname.open("rb") as f:
async with client.post(
"/", data={"some": f, "test": b"data"}, chunked=True
"/", data={"some": f, "test": io.BytesIO(b"data")}, chunked=True
) as resp:
assert 200 == resp.status

Expand Down
Loading

0 comments on commit cebe526

Please sign in to comment.