Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BC-breaking] RandomErasing is now scriptable #2386

Merged
merged 2 commits into from
Jul 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 58 additions & 32 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1618,38 +1618,64 @@ def test_random_grayscale(self):

def test_random_erasing(self):
"""Unit tests for random erasing transform"""

img = torch.rand([3, 60, 60])

# Test Set 1: Erasing with int value
img_re = transforms.RandomErasing(value=0.2)
i, j, h, w, v = img_re.get_params(img, scale=img_re.scale, ratio=img_re.ratio, value=img_re.value)
img_output = F.erase(img, i, j, h, w, v)
self.assertEqual(img_output.size(0), 3)

# Test Set 2: Check if the unerased region is preserved
orig_unerased = img.clone()
orig_unerased[:, i:i + h, j:j + w] = 0
output_unerased = img_output.clone()
output_unerased[:, i:i + h, j:j + w] = 0
self.assertTrue(torch.equal(orig_unerased, output_unerased))

# Test Set 3: Erasing with random value
img_re = transforms.RandomErasing(value='random')(img)
self.assertEqual(img_re.size(0), 3)

# Test Set 4: Erasing with tuple value
img_re = transforms.RandomErasing(value=(0.2, 0.2, 0.2))(img)
self.assertEqual(img_re.size(0), 3)

# Test Set 5: Testing the inplace behaviour
img_re = transforms.RandomErasing(value=(0.2), inplace=True)(img)
self.assertTrue(torch.equal(img_re, img))

# Test Set 6: Checking when no erased region is selected
img = torch.rand([3, 300, 1])
img_re = transforms.RandomErasing(ratio=(0.1, 0.2), value='random')(img)
self.assertTrue(torch.equal(img_re, img))
for is_scripted in [False, True]:
torch.manual_seed(12)
img = torch.rand(3, 60, 60)

# Test Set 0: invalid value
random_erasing = transforms.RandomErasing(value=(0.1, 0.2, 0.3, 0.4), p=1.0)
with self.assertRaises(ValueError, msg="If value is a sequence, it should have either a single value or 3"):
img_re = random_erasing(img)

# Test Set 1: Erasing with int value
random_erasing = transforms.RandomErasing(value=0.2)
if is_scripted:
random_erasing = torch.jit.script(random_erasing)

i, j, h, w, v = transforms.RandomErasing.get_params(
img, scale=random_erasing.scale, ratio=random_erasing.ratio, value=[random_erasing.value, ]
)
img_output = F.erase(img, i, j, h, w, v)
self.assertEqual(img_output.size(0), 3)

# Test Set 2: Check if the unerased region is preserved
true_output = img.clone()
true_output[:, i:i + h, j:j + w] = random_erasing.value
self.assertTrue(torch.equal(true_output, img_output))

# Test Set 3: Erasing with random value
random_erasing = transforms.RandomErasing(value="random")
if is_scripted:
random_erasing = torch.jit.script(random_erasing)
img_re = random_erasing(img)

self.assertEqual(img_re.size(0), 3)

# Test Set 4: Erasing with tuple value
random_erasing = transforms.RandomErasing(value=(0.2, 0.2, 0.2))
if is_scripted:
random_erasing = torch.jit.script(random_erasing)
img_re = random_erasing(img)
self.assertEqual(img_re.size(0), 3)
true_output = img.clone()
true_output[:, i:i + h, j:j + w] = torch.tensor(random_erasing.value)[:, None, None]
self.assertTrue(torch.equal(true_output, img_output))

# Test Set 5: Testing the inplace behaviour
random_erasing = transforms.RandomErasing(value=(0.2,), inplace=True)
if is_scripted:
random_erasing = torch.jit.script(random_erasing)

img_re = random_erasing(img)
self.assertTrue(torch.equal(img_re, img))

# Test Set 6: Checking when no erased region is selected
img = torch.rand([3, 300, 1])
random_erasing = transforms.RandomErasing(ratio=(0.1, 0.2), value="random")
if is_scripted:
random_erasing = torch.jit.script(random_erasing)
img_re = random_erasing(img)
self.assertTrue(torch.equal(img_re, img))


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,7 @@ def to_grayscale(img, num_output_channels=1):
return img


def erase(img, i, j, h, w, v, inplace=False):
def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor:
""" Erase the input Tensor Image with given value.

Args:
Expand Down
78 changes: 54 additions & 24 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import random
import warnings
from collections.abc import Sequence, Iterable
from typing import Tuple
from typing import Tuple, List, Optional

import numpy as np
import torch
Expand Down Expand Up @@ -1343,7 +1343,7 @@ def __repr__(self):
return self.__class__.__name__ + '(p={0})'.format(self.p)


class RandomErasing(object):
class RandomErasing(torch.nn.Module):
""" Randomly selects a rectangle region in an image and erases its pixels.
'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/pdf/1708.04896.pdf

Expand All @@ -1370,13 +1370,21 @@ class RandomErasing(object):
"""

def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False):
assert isinstance(value, (numbers.Number, str, tuple, list))
super().__init__()
if not isinstance(value, (numbers.Number, str, tuple, list)):
raise TypeError("Argument value should be either a number or str or a sequence")
if isinstance(value, str) and value != "random":
raise ValueError("If value is str, it should be 'random'")
if not isinstance(scale, (tuple, list)):
raise TypeError("Scale should be a sequence")
if not isinstance(ratio, (tuple, list)):
raise TypeError("Ratio should be a sequence")
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("range should be of kind (min, max)")
warnings.warn("Scale and ratio should be of kind (min, max)")
if scale[0] < 0 or scale[1] > 1:
raise ValueError("range of scale should be between 0 and 1")
raise ValueError("Scale should be between 0 and 1")
if p < 0 or p > 1:
raise ValueError("range of random erasing probability should be between 0 and 1")
raise ValueError("Random erasing probability should be between 0 and 1")

self.p = p
self.scale = scale
Expand All @@ -1385,13 +1393,18 @@ def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace
self.inplace = inplace

@staticmethod
def get_params(img, scale, ratio, value=0):
def get_params(
img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
) -> Tuple[int, int, int, int, Tensor]:
"""Get parameters for ``erase`` for a random erasing.

Args:
img (Tensor): Tensor image of size (C, H, W) to be erased.
scale: range of proportion of erased area against input image.
ratio: range of aspect ratio of erased area.
scale (tuple or list): range of proportion of erased area against input image.
ratio (tuple or list): range of aspect ratio of erased area.
value (list, optional): erasing value. If None, it is interpreted as "random"
(erasing each pixel with random values). If ``len(value)`` is 1, it is interpreted as a number,
i.e. ``value[0]``.

Returns:
tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
Expand All @@ -1400,35 +1413,52 @@ def get_params(img, scale, ratio, value=0):
area = img_h * img_w

for _ in range(10):
erase_area = random.uniform(scale[0], scale[1]) * area
aspect_ratio = random.uniform(ratio[0], ratio[1])
erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
aspect_ratio = torch.empty(1).uniform_(ratio[0], ratio[1]).item()

h = int(round(math.sqrt(erase_area * aspect_ratio)))
w = int(round(math.sqrt(erase_area / aspect_ratio)))
if not (h < img_h and w < img_w):
continue

if value is None:
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
else:
v = torch.tensor(value)[:, None, None]

if h < img_h and w < img_w:
i = random.randint(0, img_h - h)
j = random.randint(0, img_w - w)
if isinstance(value, numbers.Number):
v = value
elif isinstance(value, torch._six.string_classes):
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
elif isinstance(value, (list, tuple)):
v = torch.tensor(value, dtype=torch.float32).view(-1, 1, 1).expand(-1, h, w)
return i, j, h, w, v
i = torch.randint(0, img_h - h, size=(1, )).item()
j = torch.randint(0, img_w - w, size=(1, )).item()
return i, j, h, w, v

# Return original image
return 0, 0, img_h, img_w, img

def __call__(self, img):
def forward(self, img):
"""
Args:
img (Tensor): Tensor image of size (C, H, W) to be erased.

Returns:
img (Tensor): Erased Tensor image.
"""
if random.uniform(0, 1) < self.p:
x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value)
if torch.rand(1) < self.p:

# cast self.value to script acceptable type
if isinstance(self.value, (int, float)):
value = [self.value, ]
elif isinstance(self.value, str):
value = None
elif isinstance(self.value, tuple):
value = list(self.value)
else:
value = self.value

if value is not None and not (len(value) in (1, img.shape[-3])):
raise ValueError(
"If value is a sequence, it should have either a single value or "
"{} (number of input channels)".format(img.shape[-3])
)

x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=value)
return F.erase(img, x, y, h, w, v, self.inplace)
return img