diff --git a/CHANGES/8653.bugfix.rst b/CHANGES/8653.bugfix.rst
new file mode 100644
index 00000000000..5c4d66c181f
--- /dev/null
+++ b/CHANGES/8653.bugfix.rst
@@ -0,0 +1 @@
+Fixed multipart reading when stream buffer splits the boundary over several read() calls -- by :user:`Dreamsorcerer`.
diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py
index 71fc2654a1c..26780e3060c 100644
--- a/aiohttp/multipart.py
+++ b/aiohttp/multipart.py
@@ -266,6 +266,7 @@ def __init__(
     ) -> None:
         self.headers = headers
         self._boundary = boundary
+        self._boundary_len = len(boundary) + 2  # Boundary + \r\n
         self._content = content
         self._default_charset = default_charset
         self._at_eof = False
@@ -346,15 +347,25 @@ async def _read_chunk_from_stream(self, size: int) -> bytes:
         # Reads content chunk of body part with unknown length.
         # The Content-Length header for body part is not necessary.
         assert (
-            size >= len(self._boundary) + 2
+            size >= self._boundary_len
         ), "Chunk size must be greater or equal than boundary length + 2"
         first_chunk = self._prev_chunk is None
         if first_chunk:
             self._prev_chunk = await self._content.read(size)
 
-        chunk = await self._content.read(size)
-        self._content_eof += int(self._content.at_eof())
-        assert self._content_eof < 3, "Reading after EOF"
+        chunk = b""
+        # content.read() may return less than size, so we need to loop to ensure
+        # we have enough data to detect the boundary.
+        while len(chunk) < self._boundary_len:
+            chunk += await self._content.read(size)
+            self._content_eof += int(self._content.at_eof())
+            assert self._content_eof < 3, "Reading after EOF"
+            if self._content_eof:
+                break
+        if len(chunk) > size:
+            self._content.unread_data(chunk[size:])
+            chunk = chunk[:size]
+
         assert self._prev_chunk is not None
         window = self._prev_chunk + chunk
         sub = b"\r\n" + self._boundary
diff --git a/tests/test_multipart.py b/tests/test_multipart.py
index 436b70957fa..6fc9fe573ec 100644
--- a/tests/test_multipart.py
+++ b/tests/test_multipart.py
@@ -2,6 +2,7 @@
 import io
 import json
 import pathlib
+import sys
 import zlib
 from unittest import mock
 
@@ -754,6 +755,66 @@ async def test_invalid_boundary(self) -> None:
         with pytest.raises(ValueError):
             await reader.next()
 
+    @pytest.mark.skipif(sys.version_info < (3, 10), reason="Needs anext()")
+    async def test_read_boundary_across_chunks(self) -> None:
+        class SplitBoundaryStream:
+            def __init__(self) -> None:
+                self.content = [
+                    b"--foobar\r\n\r\n",
+                    b"Hello,\r\n-",
+                    b"-fo",
+                    b"ob",
+                    b"ar\r\n",
+                    b"\r\nwor",
+                    b"ld!",
+                    b"\r\n--f",
+                    b"oobar--",
+                ]
+
+            async def read(self, size=None) -> bytes:
+                chunk = self.content.pop(0)
+                assert len(chunk) <= size
+                return chunk
+
+            def at_eof(self) -> bool:
+                return not self.content
+
+            async def readline(self) -> bytes:
+                line = b""
+                while self.content and b"\n" not in line:
+                    line += self.content.pop(0)
+                line, *extra = line.split(b"\n", maxsplit=1)
+                if extra and extra[0]:
+                    self.content.insert(0, extra[0])
+                return line + b"\n"
+
+            def unread_data(self, data: bytes) -> None:
+                if self.content:
+                    self.content[0] = data + self.content[0]
+                else:
+                    self.content.append(data)
+
+        stream = SplitBoundaryStream()
+        reader = aiohttp.MultipartReader(
+            {CONTENT_TYPE: 'multipart/related;boundary="foobar"'}, stream
+        )
+        part = await anext(reader)
+        result = await part.read_chunk(10)
+        assert result == b"Hello,"
+        result = await part.read_chunk(10)
+        assert result == b""
+        assert part.at_eof()
+
+        part = await anext(reader)
+        result = await part.read_chunk(10)
+        assert result == b"world!"
+        result = await part.read_chunk(10)
+        assert result == b""
+        assert part.at_eof()
+
+        with pytest.raises(StopAsyncIteration):
+            await anext(reader)
+
     async def test_release(self) -> None:
         with Stream(
             newline.join(