diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py
index 66299cd9418..6ad54f323b0 100644
--- a/torchvision/datasets/mnist.py
+++ b/torchvision/datasets/mnist.py
@@ -132,10 +132,14 @@ def download(self):
         makedir_exist_ok(self.raw_folder)
         makedir_exist_ok(self.processed_folder)
 
+        # create fake header, see: https://github.com/pytorch/vision/issues/1938
+        header = [('User-agent', 'Mozilla/5.0')]
+
         # download files
         for url, md5 in self.resources:
             filename = url.rpartition('/')[2]
-            download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
+            download_and_extract_archive(
+                url, download_root=self.raw_folder, filename=filename, md5=md5, header=header)
 
         # process and save as torch files
         print('Processing...')
diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py
index aa61237a6d2..e8bd29109a3 100644
--- a/torchvision/datasets/utils.py
+++ b/torchvision/datasets/utils.py
@@ -56,7 +56,7 @@ def makedir_exist_ok(dirpath):
             raise
 
 
-def download_url(url, root, filename=None, md5=None):
+def download_url(url, root, filename=None, md5=None, header=None):
     """Download a file from a url and place it in root.
 
     Args:
@@ -64,6 +64,7 @@ def download_url(url, root, filename=None, md5=None):
         root (str): Directory to place downloaded file in
         filename (str, optional): Name to save the file under. If None, use the basename of the URL
         md5 (str, optional): MD5 checksum of the download. If None, do not check
+        header (list(tuples), optional): (name, value) pairs set as the opener's addheaders. If None, do not set
     """
     from six.moves import urllib
 
@@ -78,6 +79,11 @@ def download_url(url, root, filename=None, md5=None):
     if check_integrity(fpath, md5):
         print('Using downloaded and verified file: ' + fpath)
     else:   # download the file
+        # set header, see #1938
+        if header is not None:
+            opener = urllib.request.build_opener()
+            opener.addheaders = header
+            urllib.request.install_opener(opener)
         try:
             print('Downloading ' + url + ' to ' + fpath)
             urllib.request.urlretrieve(
@@ -98,6 +104,9 @@ def download_url(url, root, filename=None, md5=None):
+        # reset header before the integrity check and only if one was set, see #1938
+        if header is not None:
+            urllib.request.install_opener(urllib.request.build_opener())
         # check integrity of downloaded file
         if not check_integrity(fpath, md5):
             raise RuntimeError("File not found or corrupted.")
 
 
 def list_dir(root, prefix=False):
@@ -254,14 +263,14 @@ def extract_archive(from_path, to_path=None, remove_finished=False):
 
 
 def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
-                                 md5=None, remove_finished=False):
+                                 md5=None, remove_finished=False, header=None):
     download_root = os.path.expanduser(download_root)
     if extract_root is None:
         extract_root = download_root
     if not filename:
         filename = os.path.basename(url)
 
-    download_url(url, download_root, filename, md5)
+    download_url(url, download_root, filename, md5, header)
 
     archive = os.path.join(download_root, filename)
     print("Extracting {} to {}".format(archive, extract_root))