Skip to content

Commit

Permalink
Miscellaneous dataset fixes (#1174)
Browse files Browse the repository at this point in the history
* fix stl10

* fix lsun
  • Loading branch information
Philip Meier authored and fmassa committed Jul 26, 2019
1 parent 8102158 commit 59c97d7
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 35 deletions.
52 changes: 30 additions & 22 deletions torchvision/datasets/lsun.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import six
import string
import sys
from collections import Iterable

if sys.version_info[0] == 2:
import cPickle as pickle
Expand Down Expand Up @@ -72,6 +73,24 @@ class LSUN(VisionDataset):
def __init__(self, root, classes='train', transform=None, target_transform=None):
    """Build one ``LSUNClass`` sub-dataset per requested class.

    Args:
        root: Directory that holds the ``<class>_lmdb`` folders.
        classes: Class selection; validated by ``_verify_classes`` and
            stored as ``self.classes``.
        transform: Passed to the base class and to every sub-dataset.
        target_transform: Passed to the base class only.
    """
    super(LSUN, self).__init__(root, transform=transform,
                               target_transform=target_transform)
    self.classes = self._verify_classes(classes)

    # One sub-dataset per class, each rooted at "<root>/<class>_lmdb".
    self.dbs = [
        LSUNClass(root=root + '/' + name + '_lmdb', transform=transform)
        for name in self.classes
    ]

    # indices[i] is the cumulative sample count of dbs[0] .. dbs[i];
    # the final running total is the length of the whole dataset.
    running_total = 0
    boundaries = []
    for sub_dataset in self.dbs:
        running_total += len(sub_dataset)
        boundaries.append(running_total)

    self.indices = boundaries
    self.length = running_total

def _verify_classes(self, classes):
categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
'conference_room', 'dining_room', 'kitchen',
'living_room', 'restaurant', 'tower']
Expand All @@ -84,39 +103,28 @@ def __init__(self, root, classes='train', transform=None, target_transform=None)
else:
classes = [c + '_' + classes for c in categories]
except ValueError:
# TODO: Should this check for Iterable instead of list?
if not isinstance(classes, list):
raise ValueError
if not isinstance(classes, Iterable):
msg = ("Expected type str or Iterable for argument classes, "
"but got type {}.")
raise ValueError(msg.format(type(classes)))

classes = list(classes)
msg_fmtstr = ("Expected type str for elements in argument classes, "
"but got type {}.")
for c in classes:
# TODO: This assumes each item is a str (or subclass). Should this
# also be checked?
verify_str_arg(c, custom_msg=msg_fmtstr.format(type(c)))
c_short = c.split('_')
category, dset_opt = '_'.join(c_short[:-1]), c_short[-1]
msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."

msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}."
msg = msg_fmtstr.format(category, "LSUN class",
iterable_to_str(categories))
verify_str_arg(category, valid_values=categories, custom_msg=msg)

msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts))
verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg)
finally:
self.classes = classes

# for each class, create an LSUNClassDataset
self.dbs = []
for c in self.classes:
self.dbs.append(LSUNClass(
root=root + '/' + c + '_lmdb',
transform=transform))

self.indices = []
count = 0
for db in self.dbs:
count += len(db)
self.indices.append(count)

self.length = count
return classes

def __getitem__(self, index):
"""
Expand Down
35 changes: 22 additions & 13 deletions torchvision/datasets/stl10.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, root, split='train', folds=None, transform=None,
super(STL10, self).__init__(root, transform=transform,
target_transform=target_transform)
self.split = verify_str_arg(split, "split", self.splits)
self.folds = folds # one of the 10 pre-defined folds or the full dataset
self.folds = self._verify_folds(folds)

if download:
self.download()
Expand Down Expand Up @@ -89,6 +89,19 @@ def __init__(self, root, split='train', folds=None, transform=None,
with open(class_file) as f:
self.classes = f.read().splitlines()

def _verify_folds(self, folds):
if folds is None:
return folds
elif isinstance(folds, int):
if folds in range(10):
return folds
msg = ("Value for argument folds should be in the range [0, 10), "
"but got {}.")
raise ValueError(msg.format(folds))
else:
msg = "Expected type None or int for argument folds, but got type {}."
raise ValueError(msg.format(type(folds)))

def __getitem__(self, index):
"""
Args:
Expand Down Expand Up @@ -154,15 +167,11 @@ def extra_repr(self):

def __load_folds(self, folds):
# loads one of the folds if specified
if isinstance(folds, int):
if folds >= 0 and folds < 10:
path_to_folds = os.path.join(
self.root, self.base_folder, self.folds_list_file)
with open(path_to_folds, 'r') as f:
str_idx = f.read().splitlines()[folds]
list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ')
self.data, self.labels = self.data[list_idx, :, :, :], self.labels[list_idx]
else:
# FIXME: docstring allows None for folds (it is even the default value)
# Is this intended?
raise ValueError('Folds "{}" not found. Valid splits are: 0-9.'.format(folds))
if folds is None:
return
path_to_folds = os.path.join(
self.root, self.base_folder, self.folds_list_file)
with open(path_to_folds, 'r') as f:
str_idx = f.read().splitlines()[folds]
list_idx = np.fromstring(str_idx, dtype=np.uint8, sep=' ')
self.data, self.labels = self.data[list_idx, :, :, :], self.labels[list_idx]

0 comments on commit 59c97d7

Please sign in to comment.