fix(requests): Fix multithreaded downloads
For some reason, moving the download speed calculation code from the requests() function to the download() function makes downloads actually run multi-threaded instead of sequentially.
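One Python behaviour worth knowing when reading the "for some reason" above: download() is a generator function, so pool.submit(download, ...) completes almost immediately in the worker thread — it only creates the generator object — and the body (the actual network I/O) runs wherever the generator is iterated. A minimal, self-contained sketch of that pitfall, independent of this commit (fake_download and the timings are illustrative, not from this repository):

    import time
    from concurrent.futures import as_completed
    from concurrent.futures.thread import ThreadPoolExecutor


    def fake_download(n: int):
        # Generator function: nothing below runs at submit() time; it runs
        # only when the returned generator object is iterated.
        time.sleep(1)  # stand-in for network I/O
        yield dict(file_downloaded=f"segment_{n}.mp4")


    start = time.time()
    with ThreadPoolExecutor(max_workers=4) as pool:
        for future in as_completed([pool.submit(fake_download, n) for n in range(4)]):
            for update in future.result():  # the sleep happens here, in this thread
                pass
    print(f"elapsed: {time.time() - start:.1f}s")  # ~4s, not ~1s: iteration is serial

Whether this laziness is the whole story for the fix below is not confirmed by the commit; the author himself leaves the cause open.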
rlaphoenix committed Apr 4, 2024
1 parent 5d1b54b commit 994ab15
Showing 1 changed file with 77 additions and 80 deletions.
157 changes: 77 additions & 80 deletions devine/core/downloaders/requests.py
@@ -1,7 +1,7 @@
 import math
 import os
 import time
-from concurrent import futures
+from concurrent.futures import as_completed
 from concurrent.futures.thread import ThreadPoolExecutor
 from http.cookiejar import CookieJar
 from pathlib import Path
@@ -19,11 +19,14 @@
 CHUNK_SIZE = 1024
 PROGRESS_WINDOW = 5
 
+DOWNLOAD_SIZES = []
+LAST_SPEED_REFRESH = time.time()
 
 def download(
     url: str,
     save_path: Path,
     session: Optional[Session] = None,
+    segmented: bool = False,
     **kwargs: Any
 ) -> Generator[dict[str, Any], None, None]:
     """
@@ -48,10 +51,13 @@ def download(
         session: The Requests Session to make HTTP requests with. Useful to set Header,
             Cookie, and Proxy data. Connections are saved and re-used with the session
             so long as the server keeps the connection alive.
+        segmented: If downloads are segments or parts of one bigger file.
         kwargs: Any extra keyword arguments to pass to the session.get() call. Use this
             for one-time request changes like a header, cookie, or proxy. For example,
             to request Byte-ranges use e.g., `headers={"Range": "bytes=0-128"}`.
     """
+    global LAST_SPEED_REFRESH
+
     session = session or Session()
 
     save_dir = save_path.parent
@@ -69,6 +75,7 @@
             file_downloaded=save_path,
             written=save_path.stat().st_size
         )
+        # TODO: This should return, potential recovery bug
 
     # TODO: Design a control file format so we know how much of the file is missing
     control_file.write_bytes(b"")
@@ -77,47 +84,59 @@ def download(
     try:
         while True:
             written = 0
 
             # these are for single-url speed calcs only
             download_sizes = []
             last_speed_refresh = time.time()
 
             try:
                 stream = session.get(url, stream=True, **kwargs)
                 stream.raise_for_status()
 
-                try:
-                    content_length = int(stream.headers.get("Content-Length", "0"))
-                except ValueError:
-                    content_length = 0
+                if not segmented:
+                    try:
+                        content_length = int(stream.headers.get("Content-Length", "0"))
+                    except ValueError:
+                        content_length = 0
 
-                if content_length > 0:
-                    yield dict(total=math.ceil(content_length / CHUNK_SIZE))
-                else:
-                    # we have no data to calculate total chunks
-                    yield dict(total=None)  # indeterminate mode
+                    if content_length > 0:
+                        yield dict(total=math.ceil(content_length / CHUNK_SIZE))
+                    else:
+                        # we have no data to calculate total chunks
+                        yield dict(total=None)  # indeterminate mode
 
                 with open(save_path, "wb") as f:
                     for chunk in stream.iter_content(chunk_size=CHUNK_SIZE):
                         download_size = len(chunk)
                         f.write(chunk)
                         written += download_size
 
-                        yield dict(advance=1)
-
-                        now = time.time()
-                        time_since = now - last_speed_refresh
-
-                        download_sizes.append(download_size)
-                        if time_since > PROGRESS_WINDOW or download_size < CHUNK_SIZE:
-                            data_size = sum(download_sizes)
-                            download_speed = math.ceil(data_size / (time_since or 1))
-                            yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
-                            last_speed_refresh = now
-                            download_sizes.clear()
-
-                yield dict(
-                    file_downloaded=save_path,
-                    written=written
-                )
+                        if not segmented:
+                            yield dict(advance=1)
+
+                            now = time.time()
+                            time_since = now - last_speed_refresh
+
+                            download_sizes.append(download_size)
+                            if time_since > PROGRESS_WINDOW or download_size < CHUNK_SIZE:
+                                data_size = sum(download_sizes)
+                                download_speed = math.ceil(data_size / (time_since or 1))
+                                yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
+                                last_speed_refresh = now
+                                download_sizes.clear()
+
+                yield dict(file_downloaded=save_path, written=written)
+
+                if segmented:
+                    yield dict(advance=1)
+
+                    now = time.time()
+                    time_since = now - LAST_SPEED_REFRESH
+
+                    if written:  # no size == skipped dl
+                        DOWNLOAD_SIZES.append(written)
+
+                    if DOWNLOAD_SIZES and time_since > PROGRESS_WINDOW:
+                        data_size = sum(DOWNLOAD_SIZES)
+                        download_speed = math.ceil(data_size / (time_since or 1))
+                        yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
+                        LAST_SPEED_REFRESH = now
+                        DOWNLOAD_SIZES.clear()
+
                 break
             except Exception as e:
                 save_path.unlink(missing_ok=True)
@@ -237,59 +256,37 @@ def requests(
 
     yield dict(total=len(urls))
 
-    download_sizes = []
-    last_speed_refresh = time.time()
-
-    with ThreadPoolExecutor(max_workers=max_workers) as pool:
-        for i, future in enumerate(futures.as_completed((
-            pool.submit(
-                download,
-                session=session,
-                **url
-            )
-            for url in urls
-        ))):
-            file_path, download_size = None, None
-            try:
-                for status_update in future.result():
-                    if status_update.get("file_downloaded") and status_update.get("written"):
-                        file_path = status_update["file_downloaded"]
-                        download_size = status_update["written"]
-                    elif len(urls) == 1:
-                        # these are per-chunk updates, only useful if it's one big file
-                        yield status_update
-            except KeyboardInterrupt:
-                DOWNLOAD_CANCELLED.set()  # skip pending track downloads
-                yield dict(downloaded="[yellow]CANCELLING")
-                pool.shutdown(wait=True, cancel_futures=True)
-                yield dict(downloaded="[yellow]CANCELLED")
-                # tell dl that it was cancelled
-                # the pool is already shut down, so exiting loop is fine
-                raise
-            except Exception:
-                DOWNLOAD_CANCELLED.set()  # skip pending track downloads
-                yield dict(downloaded="[red]FAILING")
-                pool.shutdown(wait=True, cancel_futures=True)
-                yield dict(downloaded="[red]FAILED")
-                # tell dl that it failed
-                # the pool is already shut down, so exiting loop is fine
-                raise
-            else:
-                yield dict(file_downloaded=file_path, written=download_size)
-                yield dict(advance=1)
-
-                now = time.time()
-                time_since = now - last_speed_refresh
-
-                if download_size:  # no size == skipped dl
-                    download_sizes.append(download_size)
-
-                if download_sizes and (time_since > PROGRESS_WINDOW or i == len(urls)):
-                    data_size = sum(download_sizes)
-                    download_speed = math.ceil(data_size / (time_since or 1))
-                    yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
-                    last_speed_refresh = now
-                    download_sizes.clear()
+    try:
+        with ThreadPoolExecutor(max_workers=max_workers) as pool:
+            for future in as_completed(
+                pool.submit(
+                    download,
+                    session=session,
+                    segmented=True,
+                    **url
+                )
+                for url in urls
+            ):
+                try:
+                    yield from future.result()
+                except KeyboardInterrupt:
+                    DOWNLOAD_CANCELLED.set()  # skip pending track downloads
+                    yield dict(downloaded="[yellow]CANCELLING")
+                    pool.shutdown(wait=True, cancel_futures=True)
+                    yield dict(downloaded="[yellow]CANCELLED")
+                    # tell dl that it was cancelled
+                    # the pool is already shut down, so exiting loop is fine
+                    raise
+                except Exception:
+                    DOWNLOAD_CANCELLED.set()  # skip pending track downloads
+                    yield dict(downloaded="[red]FAILING")
+                    pool.shutdown(wait=True, cancel_futures=True)
+                    yield dict(downloaded="[red]FAILED")
+                    # tell dl that it failed
+                    # the pool is already shut down, so exiting loop is fine
+                    raise
+    finally:
+        DOWNLOAD_SIZES.clear()
 
 
 __all__ = ("requests",)
