From d97b58452269f617835d675e7416520e2cf1d5d5 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 22 Jul 2019 09:02:34 +0200 Subject: [PATCH] bug fixes --- torchvision/datasets/imagenet.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index 85bb7c759be..d6eead13a35 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -38,7 +38,7 @@ class ImageNet(ImageFolder): loader (callable, optional): A function to load an image given its path. Attributes: - classes (list): List of the class names. + classes (list): List of the class name tuples. class_to_idx (dict): Dict with items (class_name, class_index). wnids (list): List of the WordNet IDs. wnid_to_idx (dict): Dict with items (wordnet_id, class_index). @@ -57,12 +57,11 @@ def __init__(self, root, split='train', download=False, **kwargs): super(ImageNet, self).__init__(self.split_folder, **kwargs) self.root = root - idcs = [idx for _, idx in self.imgs] self.wnids = self.classes - self.wnid_to_idx = {wnid: idx for idx, wnid in zip(idcs, self.wnids)} + self.wnid_to_idx = self.class_to_idx self.classes = [wnid_to_classes[wnid] for wnid in self.wnids] self.class_to_idx = {cls: idx - for clss, idx in zip(self.classes, idcs) + for idx, clss in enumerate(self.classes) for cls in clss} def download(self):