diff --git a/torchvision/datasets/coco.py b/torchvision/datasets/coco.py
index 80f371c63d5..5ce42ddad72 100644
--- a/torchvision/datasets/coco.py
+++ b/torchvision/datasets/coco.py
@@ -43,10 +43,8 @@ class CocoCaptions(VisionDataset):
     """
 
-    def __init__(self, root, annFile, transform=None, target_transform=None):
-        super(CocoCaptions, self).__init__(root)
-        self.transform = transform
-        self.target_transform = target_transform
+    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
+        super(CocoCaptions, self).__init__(root, transforms, transform, target_transform)
         from pycocotools.coco import COCO
         self.coco = COCO(annFile)
         self.ids = list(sorted(self.coco.imgs.keys()))
@@ -68,11 +66,9 @@ def __getitem__(self, index):
         path = coco.loadImgs(img_id)[0]['file_name']
 
         img = Image.open(os.path.join(self.root, path)).convert('RGB')
-        if self.transform is not None:
-            img = self.transform(img)
 
-        if self.target_transform is not None:
-            target = self.target_transform(target)
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
 
         return img, target
 
@@ -92,10 +88,8 @@ class CocoDetection(VisionDataset):
         target and transforms it.
     """
 
-    def __init__(self, root, annFile, transform=None, target_transform=None):
-        super(CocoDetection, self).__init__(root)
-        self.transform = transform
-        self.target_transform = target_transform
+    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
+        super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
         from pycocotools.coco import COCO
         self.coco = COCO(annFile)
         self.ids = list(sorted(self.coco.imgs.keys()))
@@ -116,11 +110,8 @@ def __getitem__(self, index):
         path = coco.loadImgs(img_id)[0]['file_name']
 
         img = Image.open(os.path.join(self.root, path)).convert('RGB')
-        if self.transform is not None:
-            img = self.transform(img)
-
-        if self.target_transform is not None:
-            target = self.target_transform(target)
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
 
         return img, target
 
diff --git a/torchvision/datasets/sbd.py b/torchvision/datasets/sbd.py
index eb585971760..901e45d0c83 100644
--- a/torchvision/datasets/sbd.py
+++ b/torchvision/datasets/sbd.py
@@ -54,7 +54,7 @@ def __init__(self,
                  image_set='train',
                  mode='boundaries',
                  download=False,
-                 xy_transform=None, **kwargs):
+                 transforms=None):
 
         try:
             from scipy.io import loadmat
@@ -63,12 +63,11 @@ def __init__(self,
             raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: "
                                "pip install scipy")
 
-        super(SBDataset, self).__init__(root)
+        super(SBDataset, self).__init__(root, transforms)
 
         if mode not in ("segmentation", "boundaries"):
             raise ValueError("Argument mode should be 'segmentation' or 'boundaries'")
 
-        self.xy_transform = xy_transform
         self.image_set = image_set
         self.mode = mode
         self.num_classes = 20
@@ -120,8 +119,8 @@ def __getitem__(self, index):
         img = Image.open(self.images[index]).convert('RGB')
         target = self._get_target(self.masks[index])
 
-        if self.xy_transform is not None:
-            img, target = self.xy_transform(img, target)
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
 
         return img, target
 
diff --git a/torchvision/datasets/vision.py b/torchvision/datasets/vision.py
index 168388aadde..b3bb523b51b 100644
--- a/torchvision/datasets/vision.py
+++ b/torchvision/datasets/vision.py
@@ -6,11 +6,25 @@ class VisionDataset(data.Dataset):
     _repr_indent = 4
 
-    def __init__(self, root):
+    def __init__(self, root, transforms=None, transform=None, target_transform=None):
         if isinstance(root, torch._six.string_classes):
             root = os.path.expanduser(root)
         self.root = root
 
+        has_transforms = transforms is not None
+        has_separate_transform = transform is not None or target_transform is not None
+        if has_transforms and has_separate_transform:
+            raise ValueError("Only transforms or transform/target_transform can "
+                             "be passed as argument")
+
+        # for backwards-compatibility
+        self.transform = transform
+        self.target_transform = target_transform
+
+        if has_separate_transform:
+            transforms = StandardTransform(transform, target_transform)
+        self.transforms = transforms
+
     def __getitem__(self, index):
         raise NotImplementedError
 
@@ -23,12 +37,8 @@ def __repr__(self):
         if self.root is not None:
             body.append("Root location: {}".format(self.root))
         body += self.extra_repr().splitlines()
-        if hasattr(self, 'transform') and self.transform is not None:
-            body += self._format_transform_repr(self.transform,
-                                                "Transforms: ")
-        if hasattr(self, 'target_transform') and self.target_transform is not None:
-            body += self._format_transform_repr(self.target_transform,
-                                                "Target transforms: ")
+        if self.transforms is not None:
+            body += [repr(self.transforms)]
         lines = [head] + [" " * self._repr_indent + line for line in body]
         return '\n'.join(lines)
 
@@ -39,3 +49,32 @@ def _format_transform_repr(self, transform, head):
 
     def extra_repr(self):
         return ""
+
+
+class StandardTransform(object):
+    def __init__(self, transform=None, target_transform=None):
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __call__(self, input, target):
+        if self.transform is not None:
+            input = self.transform(input)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return input, target
+
+    def _format_transform_repr(self, transform, head):
+        lines = transform.__repr__().splitlines()
+        return (["{}{}".format(head, lines[0])] +
+                ["{}{}".format(" " * len(head), line) for line in lines[1:]])
+
+    def __repr__(self):
+        body = [self.__class__.__name__]
+        if self.transform is not None:
+            body += self._format_transform_repr(self.transform,
+                                                "Transform: ")
+        if self.target_transform is not None:
+            body += self._format_transform_repr(self.target_transform,
+                                                "Target transform: ")
+
+        return '\n'.join(body)
diff --git a/torchvision/datasets/voc.py b/torchvision/datasets/voc.py
index 96b96b459d4..47f28d5c619 100644
--- a/torchvision/datasets/voc.py
+++ b/torchvision/datasets/voc.py
@@ -74,10 +74,9 @@ def __init__(self,
                  image_set='train',
                  download=False,
                  transform=None,
-                 target_transform=None):
-        super(VOCSegmentation, self).__init__(root)
-        self.transform = transform
-        self.target_transform = target_transform
+                 target_transform=None,
+                 transforms=None):
+        super(VOCSegmentation, self).__init__(root, transforms, transform, target_transform)
         self.year = year
         self.url = DATASET_YEAR_DICT[year]['url']
         self.filename = DATASET_YEAR_DICT[year]['filename']
@@ -122,11 +121,8 @@ def __getitem__(self, index):
         img = Image.open(self.images[index]).convert('RGB')
         target = Image.open(self.masks[index])
 
-        if self.transform is not None:
-            img = self.transform(img)
-
-        if self.target_transform is not None:
-            target = self.target_transform(target)
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
 
         return img, target
 
@@ -157,10 +153,9 @@ def __init__(self,
                  image_set='train',
                  download=False,
                  transform=None,
-                 target_transform=None):
-        super(VOCDetection, self).__init__(root)
-        self.transform = transform
-        self.target_transform = target_transform
+                 target_transform=None,
+                 transforms=None):
+        super(VOCDetection, self).__init__(root, transforms, transform, target_transform)
         self.year = year
         self.url = DATASET_YEAR_DICT[year]['url']
         self.filename = DATASET_YEAR_DICT[year]['filename']
@@ -208,11 +203,8 @@ def __getitem__(self, index):
         img = Image.open(self.images[index]).convert('RGB')
         target = self.parse_voc_xml(
             ET.parse(self.annotations[index]).getroot())
 
-        if self.transform is not None:
-            img = self.transform(img)
-
-        if self.target_transform is not None:
-            target = self.target_transform(target)
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)
 
         return img, target
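As a minimal sketch of what the unified interface enables (the JointRandomHorizontalFlip class and the 'data/VOCdevkit' root below are hypothetical illustrations, not part of the diff): a dataset can now receive a single transforms callable that sees the image and target together, while the old transform/target_transform pair keeps working because VisionDataset wraps it in a StandardTransform.

import random

import torchvision.transforms.functional as F
from torchvision import datasets


class JointRandomHorizontalFlip(object):
    # Hypothetical joint transform: flips image and segmentation mask
    # together. The separate transform/target_transform pair cannot
    # coordinate this random decision; a single transforms callable can.
    def __call__(self, image, target):
        if random.random() < 0.5:
            image = F.hflip(image)
            target = F.hflip(target)
        return image, target


# New style: one callable receives and returns (image, target).
ds = datasets.VOCSegmentation('data/VOCdevkit', image_set='train',
                              transforms=JointRandomHorizontalFlip())

# Old style keeps working: VisionDataset wraps the pair in a
# StandardTransform that applies each callable independently.
ds_legacy = datasets.VOCSegmentation('data/VOCdevkit', image_set='train',
                                     transform=F.to_tensor)

# Passing transforms together with transform/target_transform raises
# ValueError in VisionDataset.__init__.

Note that VisionDataset still assigns self.transform and self.target_transform (marked "for backwards-compatibility" in the diff), so downstream code that introspects those attributes continues to work.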