diff --git a/starlette/responses.py b/starlette/responses.py index bc8b23f1ce..e863405ae1 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -10,11 +10,12 @@ from urllib.parse import quote import anyio +import anyio.to_thread from starlette._compat import md5_hexdigest from starlette.background import BackgroundTask from starlette.concurrency import iterate_in_threadpool -from starlette.datastructures import URL, MutableHeaders +from starlette.datastructures import URL, Headers, MutableHeaders from starlette.types import Receive, Scope, Send @@ -286,6 +287,7 @@ def __init__( self.media_type = media_type self.background = background self.init_headers(headers) + self.headers.setdefault("accept-ranges", "bytes") if self.filename is not None: content_disposition_filename = quote(self.filename) if content_disposition_filename != self.filename: @@ -322,6 +324,16 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: mode = stat_result.st_mode if not stat.S_ISREG(mode): raise RuntimeError(f"File at path {self.path} is not a file.") + else: + stat_result = self.stat_result + + headers = Headers(scope=scope) + http_range = headers.get("range") + # http_if_range = headers.get("if-range") + + ranges = self._parse_range_header(http_range, stat_result.st_size) + print(ranges) + await send( { "type": "http.response.start", @@ -346,3 +358,44 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ) if self.background is not None: await self.background() + + def _parse_range_header( + self, http_range: typing.Optional[str], file_size: int + ) -> typing.List[typing.Tuple[int, int]]: + ranges: typing.List[typing.Tuple[int, int]] = [] + if http_range is None: + return ranges + + if http_range.strip() == "": + return ranges + + units, range_ = http_range.split("=", 1) + units = units.strip().lower() + + if units != "bytes": + return ranges + + for val in range_.split(","): + val = val.strip() + if "-" not in val: + return [] + if val.startswith("-"): + suffix_length = int(val[1:]) + if suffix_length == 0: + return [] + ranges.append((file_size - suffix_length, file_size)) + elif val.endswith("-"): + start = int(val[:-1]) + if start >= file_size: + return [] + ranges.append((start, file_size)) + else: + start, end = [int(v) for v in val.split("-", 1)] + start = int(start) + end = int(end) + 1 + if start >= end: + return [] + if end > file_size: + return [] + ranges.append((start, end)) + return ranges diff --git a/tests/test_responses.py b/tests/test_responses.py index 7535fa6412..e3d2138fc3 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -3,6 +3,7 @@ import time import typing from http.cookies import SimpleCookie +from pathlib import Path import anyio import pytest @@ -245,6 +246,28 @@ async def app(scope, receive, send): assert filled_by_bg_task == "6, 7, 8, 9" +def test_file_response_with_range( + tmpdir: Path, test_client_factory: typing.Callable[..., TestClient] +): + path = os.path.join(tmpdir, "xyz") + content = b"" + with open(path, "wb") as file: + file.write(content) + + app = FileResponse(path=path, filename="example.png") + client = test_client_factory(app) + response = client.get("/", headers={"range": "bytes=1-12"}) + expected_disposition = 'attachment; filename="example.png"' + assert response.status_code == status.HTTP_206_PARTIAL_CONTENT + assert response.content == content[1:13] + assert response.headers["content-type"] == "image/png" + assert response.headers["content-disposition"] == expected_disposition + assert response.headers["content-range"] == "bytes 1-12/14" + assert "content-length" in response.headers + assert "last-modified" in response.headers + assert "etag" in response.headers + + def test_file_response_with_directory_raises_error(tmpdir, test_client_factory): app = FileResponse(path=tmpdir, filename="example.png") client = test_client_factory(app)