From 8c40a25dffe2d8811cd4657a497726694c5a6c49 Mon Sep 17 00:00:00 2001 From: RJT1990 Date: Mon, 1 Jul 2019 11:40:39 +0100 Subject: [PATCH] Using datasets.utils extract function; accounting for train_extra dataset unzipping --- torchvision/datasets/cityscapes.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/torchvision/datasets/cityscapes.py b/torchvision/datasets/cityscapes.py index a102b96c7cf..540c83f8f7b 100644 --- a/torchvision/datasets/cityscapes.py +++ b/torchvision/datasets/cityscapes.py @@ -3,6 +3,7 @@ from collections import namedtuple import zipfile +from .utils import extract_archive from .vision import VisionDataset from PIL import Image @@ -126,7 +127,10 @@ def __init__(self, root, split='train', mode='fine', target_type='instance', ' or "color"') if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir): - image_dir_zip = os.path.join(self.root, 'leftImg8bit') + '_trainvaltest.zip' + if split == 'train_extra': + image_dir_zip = os.path.join(self.root, 'leftImg8bit') + '_trainextra.zip' + else: + image_dir_zip = os.path.join(self.root, 'leftImg8bit') + '_trainvaltest.zip' if self.mode == 'gtFine': target_dir_zip = os.path.join(self.root, self.mode) + '_trainvaltest.zip' @@ -134,8 +138,8 @@ def __init__(self, root, split='train', mode='fine', target_type='instance', target_dir_zip = os.path.join(self.root, self.mode) if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip): - extract_cityscapes_zip(zip_location=image_dir_zip, root=self.root) - extract_cityscapes_zip(zip_location=target_dir_zip, root=self.root) + extract_archive(from_path=image_dir_zip, to_path=self.root) + extract_archive(from_path=target_dir_zip, to_path=self.root) else: raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the' ' specified "split" and "mode" are inside the "root" directory') @@ -201,9 +205,3 @@ def _get_target_suffix(self, mode, target_type): return '{}_color.png'.format(mode) else: return '{}_polygons.json'.format(mode) - - -def extract_cityscapes_zip(zip_location, root): - zip_file = zipfile.ZipFile(zip_location, 'r') - zip_file.extractall(root) - zip_file.close()