Skip to content

Commit

Permalink
Consider FileResponse.chunk_size when handling multiple ranges (#2703)
Browse files Browse the repository at this point in the history
* Take in consideration the `FileResponse.chunk_size` on multiple ranges

* Update starlette/responses.py

* Update starlette/responses.py

* Update starlette/responses.py

Co-authored-by: Frost Ming <mianghong@gmail.com>

---------

Co-authored-by: Frost Ming <mianghong@gmail.com>
  • Loading branch information
Kludex and frostming authored Sep 25, 2024
1 parent 4fbf766 commit b8139f9
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 14 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,5 @@ exclude_lines = [
"pragma: nocover",
"if typing.TYPE_CHECKING:",
"@typing.overload",
"raise NotImplementedError",
]
16 changes: 6 additions & 10 deletions starlette/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,13 +374,7 @@ async def _handle_simple(self, send: Send, send_header_only: bool) -> None:
while more_body:
chunk = await file.read(self.chunk_size)
more_body = len(chunk) == self.chunk_size
await send(
{
"type": "http.response.body",
"body": chunk,
"more_body": more_body,
}
)
await send({"type": "http.response.body", "body": chunk, "more_body": more_body})

async def _handle_single_range(
self, send: Send, start: int, end: int, file_size: int, send_header_only: bool
Expand Down Expand Up @@ -419,10 +413,12 @@ async def _handle_multiple_ranges(
else:
async with await anyio.open_file(self.path, mode="rb") as file:
for start, end in ranges:
await file.seek(start)
chunk = await file.read(min(self.chunk_size, end - start))
await send({"type": "http.response.body", "body": header_generator(start, end), "more_body": True})
await send({"type": "http.response.body", "body": chunk, "more_body": True})
await file.seek(start)
while start < end:
chunk = await file.read(min(self.chunk_size, end - start))
start += len(chunk)
await send({"type": "http.response.body", "body": chunk, "more_body": True})
await send({"type": "http.response.body", "body": b"\n", "more_body": True})
await send(
{
Expand Down
6 changes: 3 additions & 3 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ async def passthrough(
}

async def receive() -> Message:
raise NotImplementedError("Should not be called!") # pragma: no cover
raise NotImplementedError("Should not be called!")

async def send(message: Message) -> None:
if message["type"] == "http.response.body":
Expand Down Expand Up @@ -330,7 +330,7 @@ async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> R
}

async def receive() -> Message:
raise NotImplementedError("Should not be called!") # pragma: no cover
raise NotImplementedError("Should not be called!")

async def send(message: Message) -> None:
if message["type"] == "http.response.body":
Expand Down Expand Up @@ -403,7 +403,7 @@ async def passthrough(
}

async def receive() -> Message:
raise NotImplementedError("Should not be called!") # pragma: no cover
raise NotImplementedError("Should not be called!")

async def send(message: Message) -> None:
if message["type"] == "http.response.body":
Expand Down
57 changes: 56 additions & 1 deletion tests/test_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
from http.cookies import SimpleCookie
from pathlib import Path
from typing import AsyncIterator, Iterator
from typing import Any, AsyncIterator, Iterator

import anyio
import pytest
Expand Down Expand Up @@ -682,3 +682,58 @@ def test_file_response_insert_ranges(file_response_client: TestClient) -> None:
"",
f"--{boundary}--",
]


@pytest.mark.anyio
async def test_file_response_multi_small_chunk_size(readme_file: Path) -> None:
class SmallChunkSizeFileResponse(FileResponse):
chunk_size = 10

app = SmallChunkSizeFileResponse(path=str(readme_file))

received_chunks: list[bytes] = []
start_message: dict[str, Any] = {}

async def receive() -> Message:
raise NotImplementedError("Should not be called!")

async def send(message: Message) -> None:
if message["type"] == "http.response.start":
start_message.update(message)
elif message["type"] == "http.response.body":
received_chunks.append(message["body"])

await app({"type": "http", "method": "get", "headers": [(b"range", b"bytes=0-15,20-35,35-50")]}, receive, send)
assert start_message["status"] == 206

headers = Headers(raw=start_message["headers"])
assert headers.get("content-type") == "text/plain; charset=utf-8"
assert headers.get("accept-ranges") == "bytes"
assert "content-length" in headers
assert "last-modified" in headers
assert "etag" in headers
assert headers["content-range"].startswith("multipart/byteranges; boundary=")
boundary = headers["content-range"].split("boundary=")[1]

assert received_chunks == [
# Send the part headers.
f"--{boundary}\nContent-Type: text/plain; charset=utf-8\nContent-Range: bytes 0-15/526\n\n".encode(),
# Send the first chunk (10 bytes).
b"# B\xc3\xa1iZ\xc3\xa9\n",
# Send the second chunk (6 bytes).
b"\nPower",
# Send the new line to separate the parts.
b"\n",
# Send the part headers. We merge the ranges 20-35 and 35-50 into a single part.
f"--{boundary}\nContent-Type: text/plain; charset=utf-8\nContent-Range: bytes 20-50/526\n\n".encode(),
# Send the first chunk (10 bytes).
b"and exquis",
# Send the second chunk (10 bytes).
b"ite WSGI/A",
# Send the third chunk (10 bytes).
b"SGI framew",
# Send the last chunk (1 byte).
b"o",
b"\n",
f"\n--{boundary}--\n".encode(),
]

0 comments on commit b8139f9

Please sign in to comment.