Skip to content

Commit

Permalink
add gdown as optional requirement for dataset GDrive download (pytorc…
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored and NicolasHug committed Feb 8, 2024
1 parent b2383d4 commit 157b613
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 67 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests-schedule.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
run: pip install --no-build-isolation --editable .

- name: Install all optional dataset requirements
run: pip install scipy pycocotools lmdb requests
run: pip install scipy pycocotools lmdb gdown

- name: Install tests requirements
run: pip install pytest
Expand Down
4 changes: 4 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,7 @@ ignore_missing_imports = True
[mypy-h5py.*]

ignore_missing_imports = True

[mypy-gdown.*]

ignore_missing_imports = True
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def write_version_file():

requirements = [
"numpy",
"requests",
pytorch_dep,
]

Expand Down
4 changes: 4 additions & 0 deletions torchvision/datasets/caltech.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class Caltech101(VisionDataset):
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
.. warning::
To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
"""

def __init__(
Expand Down
4 changes: 4 additions & 0 deletions torchvision/datasets/celeba.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ class CelebA(VisionDataset):
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
.. warning::
To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
"""

base_folder = "celeba"
Expand Down
4 changes: 4 additions & 0 deletions torchvision/datasets/pcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ class PCAM(VisionDataset):
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If
dataset is already downloaded, it is not downloaded again.
.. warning::
To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
"""

_FILES = {
Expand Down
74 changes: 9 additions & 65 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import bz2
import contextlib
import gzip
import hashlib
import itertools
import lzma
import os
import os.path
Expand All @@ -13,13 +11,11 @@
import urllib
import urllib.error
import urllib.request
import warnings
import zipfile
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar
from urllib.parse import urlparse

import numpy as np
import requests
import torch
from torch.utils.model_zoo import tqdm

Expand Down Expand Up @@ -187,22 +183,6 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
return files


def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]:
content = response.iter_content(chunk_size)
first_chunk = None
# filter out keep-alive new chunks
while not first_chunk:
first_chunk = next(content)
content = itertools.chain([first_chunk], content)

try:
match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
api_response = match["api_response"] if match is not None else None
except UnicodeDecodeError:
api_response = None
return api_response, content


def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
"""Download a Google Drive file from and place it in root.
Expand All @@ -212,7 +192,12 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
filename (str, optional): Name to save the file under. If None, use the id of the file.
md5 (str, optional): MD5 checksum of the download. If None, do not check
"""
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
try:
import gdown
except ModuleNotFoundError:
raise RuntimeError(
"To download files from GDrive, 'gdown' is required. You can install it with 'pip install gdown'."
)

root = os.path.expanduser(root)
if not filename:
Expand All @@ -225,51 +210,10 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
return

url = "https://drive.google.com/uc"
params = dict(id=file_id, export="download")
with requests.Session() as session:
response = session.get(url, params=params, stream=True)
gdown.download(id=file_id, output=fpath, quiet=False, user_agent=USER_AGENT)

for key, value in response.cookies.items():
if key.startswith("download_warning"):
token = value
break
else:
api_response, content = _extract_gdrive_api_response(response)
token = "t" if api_response == "Virus scan warning" else None

if token is not None:
response = session.get(url, params=dict(params, confirm=token), stream=True)
api_response, content = _extract_gdrive_api_response(response)

if api_response == "Quota exceeded":
raise RuntimeError(
f"The daily quota of the file {filename} is exceeded and it "
f"can't be downloaded. This is a limitation of Google Drive "
f"and can only be overcome by trying again later."
)

_save_response_content(content, fpath)

# In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
if os.stat(fpath).st_size < 10 * 1024:
with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
text = fh.read()
# Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
warnings.warn(
f"We detected some HTML elements in the downloaded file. "
f"This most likely means that the download triggered an unhandled API response by GDrive. "
f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
f"the response:\n\n{text}"
)

if md5 and not check_md5(fpath, md5):
raise RuntimeError(
f"The MD5 checksum of the download file {fpath} does not match the one on record."
f"Please delete the file and try again. "
f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
)
if not check_integrity(fpath, md5):
raise RuntimeError("File not found or corrupted.")


def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
Expand Down
4 changes: 4 additions & 0 deletions torchvision/datasets/widerface.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class WIDERFace(VisionDataset):
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
.. warning::
To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
"""

BASE_FOLDER = "widerface"
Expand Down

0 comments on commit 157b613

Please sign in to comment.