[RFC] Add support for joint transformations in VisionDataset #872

Merged 2 commits on Apr 26, 2019
25 changes: 8 additions & 17 deletions torchvision/datasets/coco.py
@@ -43,10 +43,8 @@ class CocoCaptions(VisionDataset):

     """

-    def __init__(self, root, annFile, transform=None, target_transform=None):
-        super(CocoCaptions, self).__init__(root)
-        self.transform = transform
-        self.target_transform = target_transform
+    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
+        super(CocoCaptions, self).__init__(root, transforms, transform, target_transform)
         from pycocotools.coco import COCO
         self.coco = COCO(annFile)
         self.ids = list(sorted(self.coco.imgs.keys()))
@@ -68,11 +66,9 @@ def __getitem__(self, index):
         path = coco.loadImgs(img_id)[0]['file_name']

         img = Image.open(os.path.join(self.root, path)).convert('RGB')
-        if self.transform is not None:
-            img = self.transform(img)
-
-        if self.target_transform is not None:
-            target = self.target_transform(target)
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)

         return img, target
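With this change, CocoCaptions can take a single transforms callable that receives and returns the (image, captions) pair together. A minimal sketch of the new calling convention; the dataset paths are hypothetical and keeping only the first caption is an illustrative choice:

import torchvision.transforms.functional as F
from torchvision.datasets import CocoCaptions

def tensor_and_first_caption(img, target):
    # target is the list of caption strings for this image;
    # selecting the first one is purely illustrative.
    return F.to_tensor(img), target[0]

ds = CocoCaptions(root='coco/train2017',                        # hypothetical path
                  annFile='annotations/captions_train2017.json',  # hypothetical path
                  transforms=tensor_and_first_caption)
img, caption = ds[0]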

@@ -92,10 +88,8 @@ class CocoDetection(VisionDataset):
         target and transforms it.
     """

-    def __init__(self, root, annFile, transform=None, target_transform=None):
-        super(CocoDetection, self).__init__(root)
-        self.transform = transform
-        self.target_transform = target_transform
+    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):
+        super(CocoDetection, self).__init__(root, transforms, transform, target_transform)
         from pycocotools.coco import COCO
         self.coco = COCO(annFile)
         self.ids = list(sorted(self.coco.imgs.keys()))
@@ -116,11 +110,8 @@ def __getitem__(self, index):
         path = coco.loadImgs(img_id)[0]['file_name']

         img = Image.open(os.path.join(self.root, path)).convert('RGB')
-        if self.transform is not None:
-            img = self.transform(img)
-
-        if self.target_transform is not None:
-            target = self.target_transform(target)
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)

         return img, target
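For detection, a joint callable can update the annotations whenever the image geometry changes. A sketch of a joint horizontal flip, assuming COCO's [x, y, w, h] box layout; the flip logic and paths are illustrative assumptions, not part of this PR:

import random
import torchvision.transforms.functional as F
from torchvision.datasets import CocoDetection

def joint_hflip(img, target):
    # Flip image and boxes together so they stay aligned.
    if random.random() < 0.5:
        width = img.width
        img = F.hflip(img)
        for ann in target:  # target: list of COCO annotation dicts
            x, y, w, h = ann['bbox']
            ann['bbox'] = [width - x - w, y, w, h]
    return F.to_tensor(img), target

ds = CocoDetection(root='coco/train2017',                         # hypothetical path
                   annFile='annotations/instances_train2017.json',  # hypothetical path
                   transforms=joint_hflip)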
9 changes: 4 additions & 5 deletions torchvision/datasets/sbd.py
@@ -54,7 +54,7 @@ def __init__(self,
                  image_set='train',
                  mode='boundaries',
                  download=False,
-                 xy_transform=None, **kwargs):
+                 transforms=None):

         try:
             from scipy.io import loadmat
@@ -63,12 +63,11 @@ def __init__(self,
             raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: "
                                "pip install scipy")

-        super(SBDataset, self).__init__(root)
+        super(SBDataset, self).__init__(root, transforms)

         if mode not in ("segmentation", "boundaries"):
             raise ValueError("Argument mode should be 'segmentation' or 'boundaries'")

-        self.xy_transform = xy_transform
         self.image_set = image_set
         self.mode = mode
         self.num_classes = 20
@@ -120,8 +119,8 @@ def __getitem__(self, index):
         img = Image.open(self.images[index]).convert('RGB')
         target = self._get_target(self.masks[index])

-        if self.xy_transform is not None:
-            img, target = self.xy_transform(img, target)
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)

         return img, target
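This replaces the dataset-specific xy_transform with the shared transforms mechanism. A sketch of a joint flip, assuming mode='segmentation' (where the target is a PIL image, so the same functional ops apply to both); the path is hypothetical:

import random
import torchvision.transforms.functional as F
from torchvision.datasets import SBDataset

def joint_flip(img, target):
    # One coin flip decides for both image and mask,
    # so the augmentation never desynchronizes them.
    if random.random() < 0.5:
        img, target = F.hflip(img), F.hflip(target)
    return F.to_tensor(img), F.to_tensor(target)

ds = SBDataset('path/to/sbd', image_set='train', mode='segmentation',  # hypothetical path
               transforms=joint_flip)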
53 changes: 46 additions & 7 deletions torchvision/datasets/vision.py
@@ -6,11 +6,25 @@
 class VisionDataset(data.Dataset):
     _repr_indent = 4

-    def __init__(self, root):
+    def __init__(self, root, transforms=None, transform=None, target_transform=None):
         if isinstance(root, torch._six.string_classes):
             root = os.path.expanduser(root)
         self.root = root

+        has_transforms = transforms is not None
+        has_separate_transform = transform is not None or target_transform is not None
+        if has_transforms and has_separate_transform:
+            raise ValueError("Only transforms or transform/target_transform can "
+                             "be passed as argument")
+
+        # for backwards-compatibility
+        self.transform = transform
+        self.target_transform = target_transform
+
+        if has_separate_transform:
+            transforms = StandardTransform(transform, target_transform)
+        self.transforms = transforms
+
     def __getitem__(self, index):
         raise NotImplementedError
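Either the joint transforms callable or the legacy transform/target_transform pair may be passed, never both; the legacy pair is wrapped in a StandardTransform so that __getitem__ implementations only ever need to call self.transforms. A minimal sketch of the contract (MyDataset and its sample loading are hypothetical):

from torchvision.datasets.vision import VisionDataset

class MyDataset(VisionDataset):
    # Hypothetical subclass, shown only to illustrate the new constructor.
    def __init__(self, root, transforms=None, transform=None, target_transform=None):
        super(MyDataset, self).__init__(root, transforms, transform, target_transform)
        self.samples = []  # load (image, target) pairs here

    def __getitem__(self, index):
        img, target = self.samples[index]
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        return len(self.samples)

# MyDataset(root, transform=t)                    -> t is wrapped in a StandardTransform
# MyDataset(root, transforms=joint)               -> joint is used as-is
# MyDataset(root, transforms=joint, transform=t)  -> raises ValueError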

@@ -23,12 +37,8 @@ def __repr__(self):
         if self.root is not None:
             body.append("Root location: {}".format(self.root))
         body += self.extra_repr().splitlines()
-        if hasattr(self, 'transform') and self.transform is not None:
-            body += self._format_transform_repr(self.transform,
-                                                "Transforms: ")
-        if hasattr(self, 'target_transform') and self.target_transform is not None:
-            body += self._format_transform_repr(self.target_transform,
-                                                "Target transforms: ")
+        if self.transforms is not None:
+            body += [repr(self.transforms)]
         lines = [head] + [" " * self._repr_indent + line for line in body]
         return '\n'.join(lines)

@@ -39,3 +49,32 @@ def _format_transform_repr(self, transform, head):

     def extra_repr(self):
         return ""
+
+
+class StandardTransform(object):
+    def __init__(self, transform=None, target_transform=None):
+        self.transform = transform
+        self.target_transform = target_transform
+
+    def __call__(self, input, target):
+        if self.transform is not None:
+            input = self.transform(input)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        return input, target
+
+    def _format_transform_repr(self, transform, head):
+        lines = transform.__repr__().splitlines()
+        return (["{}{}".format(head, lines[0])] +
+                ["{}{}".format(" " * len(head), line) for line in lines[1:]])
+
+    def __repr__(self):
+        body = [self.__class__.__name__]
+        if self.transform is not None:
+            body += self._format_transform_repr(self.transform,
+                                                "Transform: ")
+        if self.target_transform is not None:
+            body += self._format_transform_repr(self.target_transform,
+                                                "Target transform: ")
+
+        return '\n'.join(body)
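StandardTransform is what makes the two calling conventions interchangeable: it applies transform and target_transform independently behind the joint interface. A quick sketch of direct use:

from torchvision import transforms as T
from torchvision.datasets.vision import StandardTransform

joint = StandardTransform(transform=T.ToTensor(), target_transform=None)
# img, target = joint(pil_image, target)  # image is converted, target passes through

print(joint)
# StandardTransform
# Transform: ToTensor()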
28 changes: 10 additions & 18 deletions torchvision/datasets/voc.py
@@ -74,10 +74,9 @@ def __init__(self,
                  image_set='train',
                  download=False,
                  transform=None,
-                 target_transform=None):
-        super(VOCSegmentation, self).__init__(root)
-        self.transform = transform
-        self.target_transform = target_transform
+                 target_transform=None,
+                 transforms=None):
+        super(VOCSegmentation, self).__init__(root, transforms, transform, target_transform)
         self.year = year
         self.url = DATASET_YEAR_DICT[year]['url']
         self.filename = DATASET_YEAR_DICT[year]['filename']
@@ -122,11 +121,8 @@ def __getitem__(self, index):
         img = Image.open(self.images[index]).convert('RGB')
         target = Image.open(self.masks[index])

-        if self.transform is not None:
-            img = self.transform(img)
-
-        if self.target_transform is not None:
-            target = self.target_transform(target)
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)

         return img, target
@@ -157,10 +153,9 @@ def __init__(self,
                  image_set='train',
                  download=False,
                  transform=None,
-                 target_transform=None):
-        super(VOCDetection, self).__init__(root)
-        self.transform = transform
-        self.target_transform = target_transform
+                 target_transform=None,
+                 transforms=None):
+        super(VOCDetection, self).__init__(root, transforms, transform, target_transform)
         self.year = year
         self.url = DATASET_YEAR_DICT[year]['url']
         self.filename = DATASET_YEAR_DICT[year]['filename']
@@ -208,11 +203,8 @@ def __getitem__(self, index):
         target = self.parse_voc_xml(
             ET.parse(self.annotations[index]).getroot())

-        if self.transform is not None:
-            img = self.transform(img)
-
-        if self.target_transform is not None:
-            target = self.target_transform(target)
+        if self.transforms is not None:
+            img, target = self.transforms(img, target)

         return img, target
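The VOC datasets get the same treatment. A sketch of a joint random crop for VOCSegmentation, sampling the crop parameters once and applying them to both image and mask; the output size and path are illustrative:

import torchvision.transforms as T
import torchvision.transforms.functional as F
from torchvision.datasets import VOCSegmentation

def joint_random_crop(img, target):
    # Sample the crop once, then apply it to both image and mask.
    i, j, h, w = T.RandomCrop.get_params(img, output_size=(256, 256))
    return F.to_tensor(F.crop(img, i, j, h, w)), F.crop(target, i, j, h, w)

ds = VOCSegmentation('path/to/VOC', year='2012', image_set='train',  # hypothetical path
                     transforms=joint_random_crop)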