From a3b77d1fc9832316e775029bab3b1cf3e5893a76 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 10 Dec 2023 11:21:34 +0300 Subject: [PATCH] Add HTTP range headers to `FileResponse` --- starlette/responses.py | 54 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/starlette/responses.py b/starlette/responses.py index 507e79b32..a31758d65 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -16,7 +16,7 @@ from starlette._compat import md5_hexdigest from starlette.background import BackgroundTask from starlette.concurrency import iterate_in_threadpool -from starlette.datastructures import URL, MutableHeaders +from starlette.datastructures import URL, Headers, MutableHeaders from starlette.types import Receive, Scope, Send @@ -292,6 +292,7 @@ def __init__( self.media_type = media_type self.background = background self.init_headers(headers) + self.headers.setdefault("accept-ranges", "bytes") if self.filename is not None: content_disposition_filename = quote(self.filename) if content_disposition_filename != self.filename: @@ -328,6 +329,16 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: mode = stat_result.st_mode if not stat.S_ISREG(mode): raise RuntimeError(f"File at path {self.path} is not a file.") + else: + stat_result = self.stat_result + + headers = Headers(scope=scope) + http_range = headers.get("range") + # http_if_range = headers.get("if-range") + + ranges = self._parse_range_header(http_range, stat_result.st_size) + print(ranges) + await send( { "type": "http.response.start", @@ -352,3 +363,44 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ) if self.background is not None: await self.background() + + def _parse_range_header( + self, http_range: typing.Optional[str], file_size: int + ) -> typing.List[typing.Tuple[int, int]]: + ranges: typing.List[typing.Tuple[int, int]] = [] + if http_range is None: + return ranges + + if http_range.strip() == "": + return ranges + + units, range_ = http_range.split("=", 1) + units = units.strip().lower() + + if units != "bytes": + return ranges + + for val in range_.split(","): + val = val.strip() + if "-" not in val: + return [] + if val.startswith("-"): + suffix_length = int(val[1:]) + if suffix_length == 0: + return [] + ranges.append((file_size - suffix_length, file_size)) + elif val.endswith("-"): + start = int(val[:-1]) + if start >= file_size: + return [] + ranges.append((start, file_size)) + else: + start, end = [int(v) for v in val.split("-", 1)] + start = int(start) + end = int(end) + 1 + if start >= end: + return [] + if end > file_size: + return [] + ranges.append((start, end)) + return ranges