diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 72da9b7889d..6f43d5d263f 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -349,46 +349,53 @@ def pad(img, padding, fill=0, padding_mode='constant'): return Image.fromarray(img) -def crop(img, i, j, h, w): +def crop(img, top, left, height, width): """Crop the given PIL Image. - Args: - img (PIL Image): Image to be cropped. - i (int): i in (i,j) i.e coordinates of the upper left corner. - j (int): j in (i,j) i.e coordinates of the upper left corner. - h (int): Height of the cropped image. - w (int): Width of the cropped image. - + img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image. + top (int): Vertical component of the top left corner of the crop box. + left (int): Horizontal component of the top left corner of the crop box. + height (int): Height of the crop box. + width (int): Width of the crop box. Returns: PIL Image: Cropped image. """ if not _is_pil_image(img): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) - return img.crop((j, i, j + w, i + h)) + return img.crop((left, top, left + width, top + height)) def center_crop(img, output_size): + """Crop the given PIL Image at the center. + + Args: + img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image. + output_size (sequence or int): (height, width) of the crop box. If int, + it is used for both directions + Returns: + PIL Image: Cropped image. 
+ """ if isinstance(output_size, numbers.Number): output_size = (int(output_size), int(output_size)) - w, h = img.size - th, tw = output_size - i = int(round((h - th) / 2.)) - j = int(round((w - tw) / 2.)) - return crop(img, i, j, th, tw) + image_width, image_height = img.size + crop_height, crop_width = output_size + crop_top = int(round((image_height - crop_height) / 2.)) + crop_left = int(round((image_width - crop_width) / 2.)) + return crop(img, crop_top, crop_left, crop_height, crop_width) -def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): +def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINEAR): """Crop the given PIL Image and resize it to desired size. Notably used in :class:`~torchvision.transforms.RandomResizedCrop`. Args: - img (PIL Image): Image to be cropped. - i (int): i in (i,j) i.e coordinates of the upper left corner - j (int): j in (i,j) i.e coordinates of the upper left corner - h (int): Height of the cropped image. - w (int): Width of the cropped image. + img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image. + top (int): Vertical component of the top left corner of the crop box. + left (int): Horizontal component of the top left corner of the crop box. + height (int): Height of the crop box. + width (int): Width of the crop box. size (sequence or int): Desired output size. Same semantics as ``resize``. interpolation (int, optional): Desired interpolation. Default is ``PIL.Image.BILINEAR``. @@ -396,7 +403,7 @@ def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): PIL Image: Cropped image. """ assert _is_pil_image(img), 'img should be PIL Image' - img = crop(img, i, j, h, w) + img = crop(img, top, left, height, width) img = resize(img, size, interpolation) return img @@ -495,16 +502,18 @@ def five_crop(img, size): else: assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 
- w, h = img.size - crop_h, crop_w = size - if crop_w > w or crop_h > h: - raise ValueError("Requested crop size {} is bigger than input size {}".format(size, - (h, w))) - tl = img.crop((0, 0, crop_w, crop_h)) - tr = img.crop((w - crop_w, 0, w, crop_h)) - bl = img.crop((0, h - crop_h, crop_w, h)) - br = img.crop((w - crop_w, h - crop_h, w, h)) - center = center_crop(img, (crop_h, crop_w)) + image_width, image_height = img.size + crop_height, crop_width = size + if crop_width > image_width or crop_height > image_height: + msg = "Requested crop size {} is bigger than input size {}" + raise ValueError(msg.format(size, (image_height, image_width))) + + tl = img.crop((0, 0, crop_width, crop_height)) + tr = img.crop((image_width - crop_width, 0, image_width, crop_height)) + bl = img.crop((0, image_height - crop_height, crop_width, image_height)) + br = img.crop((image_width - crop_width, image_height - crop_height, + image_width, image_height)) + center = center_crop(img, (crop_height, crop_width)) return (tl, tr, bl, br, center)