Skip to content

Commit

Permalink
Add HTTP range headers to FileResponse
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Dec 16, 2023
1 parent 9a213c1 commit a3b77d1
Showing 1 changed file with 53 additions and 1 deletion.
54 changes: 53 additions & 1 deletion starlette/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from starlette._compat import md5_hexdigest
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, MutableHeaders
from starlette.datastructures import URL, Headers, MutableHeaders
from starlette.types import Receive, Scope, Send


Expand Down Expand Up @@ -292,6 +292,7 @@ def __init__(
self.media_type = media_type
self.background = background
self.init_headers(headers)
self.headers.setdefault("accept-ranges", "bytes")
if self.filename is not None:
content_disposition_filename = quote(self.filename)
if content_disposition_filename != self.filename:
Expand Down Expand Up @@ -328,6 +329,16 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
mode = stat_result.st_mode
if not stat.S_ISREG(mode):
raise RuntimeError(f"File at path {self.path} is not a file.")
else:
stat_result = self.stat_result

headers = Headers(scope=scope)
http_range = headers.get("range")
# http_if_range = headers.get("if-range")

ranges = self._parse_range_header(http_range, stat_result.st_size)
print(ranges)

await send(
{
"type": "http.response.start",
Expand All @@ -352,3 +363,44 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
)
if self.background is not None:
await self.background()

def _parse_range_header(
self, http_range: typing.Optional[str], file_size: int
) -> typing.List[typing.Tuple[int, int]]:
ranges: typing.List[typing.Tuple[int, int]] = []
if http_range is None:
return ranges

if http_range.strip() == "":
return ranges

units, range_ = http_range.split("=", 1)
units = units.strip().lower()

if units != "bytes":
return ranges

for val in range_.split(","):
val = val.strip()
if "-" not in val:
return []
if val.startswith("-"):
suffix_length = int(val[1:])
if suffix_length == 0:
return []
ranges.append((file_size - suffix_length, file_size))
elif val.endswith("-"):
start = int(val[:-1])
if start >= file_size:
return []
ranges.append((start, file_size))
else:
start, end = [int(v) for v in val.split("-", 1)]
start = int(start)
end = int(end) + 1
if start >= end:
return []
if end > file_size:
return []
ranges.append((start, end))
return ranges

0 comments on commit a3b77d1

Please sign in to comment.