Skip to content

Commit

Permalink
[fbsync] [pytest port] test{crop_five, crop_ten, resize, resized_crop…
Browse files Browse the repository at this point in the history
…} in test_transforms_tensor (#4010)

Reviewed By: fmassa

Differential Revision: D29097743

fbshipit-source-id: 0c8900d69c8092abfdf00d6abd3bee6c8db1d8e5
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Jun 14, 2021
1 parent add5e09 commit 4516390
Showing 1 changed file with 117 additions and 144 deletions.
261 changes: 117 additions & 144 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
_assert_equal_tensor_to_pil,
_assert_approx_equal_tensor_to_pil,
cpu_and_gpu,
cpu_only
)
from _assert_utils import assert_equal

Expand Down Expand Up @@ -142,150 +143,6 @@ def test_random_autocontrast(self):
def test_random_equalize(self):
    # Smoke-test T.RandomEqualize against the functional F.equalize on this
    # Tester's device.  NOTE(review): relies on the shared _test_op helper
    # (defined elsewhere in this file) for tensor/PIL parity and
    # scriptability checks — confirm its contract there.
    _test_op(F.equalize, T.RandomEqualize, device=self.device)

def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
    """Check an op that returns a list of images (e.g. five_crop/ten_crop).

    Verifies, in order: the functional op agrees between tensor and PIL
    inputs; the torch.jit-scripted functional agrees with eager mode; the
    class-based transform scripts and yields the expected number of
    outputs; batched input matches per-image application under the same
    RNG seed; and the scripted transform can be saved to disk.
    """
    fn_kwargs = fn_kwargs or {}
    meth_kwargs = meth_kwargs or {}

    functional_op = getattr(F, func)
    scripted_op = torch.jit.script(functional_op)

    tensor, pil_img = _create_data(height=20, width=20, device=self.device)
    tensor_results = functional_op(tensor, **fn_kwargs)
    pil_results = functional_op(pil_img, **fn_kwargs)
    self.assertEqual(len(tensor_results), len(pil_results))
    self.assertEqual(len(tensor_results), out_length)
    for t_res, p_res in zip(tensor_results, pil_results):
        _assert_equal_tensor_to_pil(t_res, p_res)

    scripted_results = scripted_op(tensor.detach().clone(), **fn_kwargs)
    self.assertEqual(len(tensor_results), len(scripted_results))
    self.assertEqual(len(scripted_results), out_length)
    for t_res, s_res in zip(tensor_results, scripted_results):
        assert_equal(t_res, s_res, msg="{} vs {}".format(t_res, s_res))

    # Class interface must also script and produce the same number of outputs.
    transform = getattr(T, method)(**meth_kwargs)
    scripted_transform = torch.jit.script(transform)
    output = scripted_transform(tensor)
    self.assertEqual(len(output), len(scripted_results))

    # Batched input must agree element-wise with per-image application
    # when both start from the same manual seed.
    batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
    torch.manual_seed(12)
    batch_results = transform(batch_tensors)

    for i in range(len(batch_tensors)):
        torch.manual_seed(12)
        single_results = transform(batch_tensors[i, ...])
        for single_img, batch_img in zip(single_results, batch_results):
            assert_equal(
                single_img,
                batch_img[i, ...],
                msg="{} vs {}".format(single_img, batch_img[i, ...]),
            )

    with get_tmp_dir() as tmp_dir:
        scripted_transform.save(os.path.join(tmp_dir, "t_op_list_{}.pt".format(method)))

def test_five_crop(self):
    """five_crop/FiveCrop must accept every size spec: 1-tuple, 1-list,
    2-tuple, and 2-list."""
    for size in [(5,), [5, ], (4, 5), [4, 5]]:
        kwargs = {"size": size}
        self._test_op_list_output(
            "five_crop", "FiveCrop", out_length=5, fn_kwargs=kwargs, meth_kwargs=kwargs
        )

def test_ten_crop(self):
    """ten_crop/TenCrop must accept every size spec: 1-tuple, 1-list,
    2-tuple, and 2-list."""
    for size in [(5,), [5, ], (4, 5), [4, 5]]:
        kwargs = {"size": size}
        self._test_op_list_output(
            "ten_crop", "TenCrop", out_length=10, fn_kwargs=kwargs, meth_kwargs=kwargs
        )

def test_resize(self):
    """T.Resize must behave identically in eager and scripted mode across
    size specs, dtypes, max_size values, and interpolation modes."""

    # TODO: Minimal check for bug-fix, improve this later
    img = torch.rand(3, 32, 46)
    resized = T.Resize(size=38)(img)
    # An int size matches the smaller edge.  Height (32) is the smaller
    # edge here, so it becomes 38 and the width scales proportionally.
    self.assertTrue(isinstance(resized, torch.Tensor))
    self.assertEqual(resized.shape[1], 38)
    self.assertEqual(resized.shape[2], int(38 * 46 / 32))

    tensor, _ = _create_data(height=34, width=36, device=self.device)
    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

    for dt in [None, torch.float32, torch.float64]:
        if dt is not None:
            # This is a trivial cast to float of uint8 data to test all cases
            tensor = tensor.to(dt)
        for size in [32, 34, [32, ], [32, 32], (32, 32), [34, 35]]:
            # Scripting requires a list; wrap bare ints once, outside the
            # inner loops.
            script_size = [size, ] if isinstance(size, int) else size
            for max_size in (None, 35, 1000):
                if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
                    continue  # Not supported
                for interpolation in [BILINEAR, BICUBIC, NEAREST]:
                    transform = T.Resize(size=script_size, interpolation=interpolation, max_size=max_size)
                    s_transform = torch.jit.script(transform)
                    _test_transform_vs_scripted(transform, s_transform, tensor)
                    _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    with get_tmp_dir() as tmp_dir:
        s_transform.save(os.path.join(tmp_dir, "t_resize.pt"))

def test_resized_crop(self):
    """T.RandomResizedCrop must behave identically in eager and scripted
    mode over all scale/ratio/size/interpolation combinations."""
    tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
    batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

    scales = [(0.7, 1.2), [0.7, 1.2]]
    ratios = [(0.75, 1.333), [0.75, 1.333]]
    sizes = [(32, ), [44, ], [32, ], [32, 32], (32, 32), [44, 55]]
    interpolations = [NEAREST, BILINEAR, BICUBIC]

    for scale in scales:
        for ratio in ratios:
            for size in sizes:
                for interpolation in interpolations:
                    transform = T.RandomResizedCrop(
                        size=size, scale=scale, ratio=ratio, interpolation=interpolation
                    )
                    s_transform = torch.jit.script(transform)
                    _test_transform_vs_scripted(transform, s_transform, tensor)
                    _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    with get_tmp_dir() as tmp_dir:
        s_transform.save(os.path.join(tmp_dir, "t_resized_crop.pt"))

def test_normalize(self):
fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
tensor, _ = _create_data(26, 34, device=self.device)
Expand Down Expand Up @@ -634,6 +491,122 @@ def test_center_crop(device):
scripted_fn.save(os.path.join(tmp_dir, "t_center_crop.pt"))


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('fn, method, out_length', [
    # test_five_crop
    (F.five_crop, T.FiveCrop, 5),
    # test_ten_crop
    (F.ten_crop, T.TenCrop, 10)
])
@pytest.mark.parametrize('size', [(5, ), [5, ], (4, 5), [4, 5]])
def test_x_crop(fn, method, out_length, size, device):
    """five_crop/ten_crop: functional tensor-vs-PIL parity, scripted
    parity, class-interface output length, and batch consistency."""
    kwargs = {'size': size}
    scripted_fn = torch.jit.script(fn)

    tensor, pil_img = _create_data(height=20, width=20, device=device)
    tensor_results = fn(tensor, **kwargs)
    pil_results = fn(pil_img, **kwargs)
    assert len(tensor_results) == len(pil_results) == out_length
    for t_res, p_res in zip(tensor_results, pil_results):
        _assert_equal_tensor_to_pil(t_res, p_res)

    scripted_results = scripted_fn(tensor.detach().clone(), **kwargs)
    assert len(tensor_results) == len(scripted_results)
    assert len(scripted_results) == out_length
    for t_res, s_res in zip(tensor_results, scripted_results):
        assert_equal(t_res, s_res)

    # Class interface must also script and yield the same number of crops.
    transform = method(**kwargs)
    scripted_transform = torch.jit.script(transform)
    assert len(scripted_transform(tensor)) == len(scripted_results)

    # Batched input must agree element-wise with per-image application
    # when both start from the same manual seed.
    batch_tensors = _create_data_batch(height=23, width=34, channels=3, num_samples=4, device=device)
    torch.manual_seed(12)
    batch_results = transform(batch_tensors)

    for i in range(len(batch_tensors)):
        torch.manual_seed(12)
        single_results = transform(batch_tensors[i, ...])
        for single_img, batch_img in zip(single_results, batch_results):
            assert_equal(single_img, batch_img[i, ...])

@cpu_only
@pytest.mark.parametrize('method', ["FiveCrop", "TenCrop"])
def test_x_crop_save(method):
    """The scripted FiveCrop/TenCrop transforms must be serializable."""
    scripted = torch.jit.script(getattr(T, method)(size=[5, ]))
    with get_tmp_dir() as tmp_dir:
        scripted.save(os.path.join(tmp_dir, "t_op_list_{}.pt".format(method)))


class TestResize:
    """pytest port of the Resize / RandomResizedCrop transform tests."""

    @cpu_only
    @pytest.mark.parametrize('size', [32, 34, 35, 36, 38])
    def test_resize_int(self, size):
        # TODO: Minimal check for bug-fix, improve this later
        x = torch.rand(3, 32, 46)
        t = T.Resize(size=size)
        y = t(x)
        # If size is an int, the smaller edge of the image is matched to it.
        # Here height (32) < width (46), so the output height is `size` and
        # the width scales proportionally.
        assert isinstance(y, torch.Tensor)
        assert y.shape[1] == size
        assert y.shape[2] == int(size * 46 / 32)

    @pytest.mark.parametrize('device', cpu_and_gpu())
    @pytest.mark.parametrize('dt', [None, torch.float32, torch.float64])
    @pytest.mark.parametrize('size', [[32, ], [32, 32], (32, 32), [34, 35]])
    @pytest.mark.parametrize('max_size', [None, 35, 1000])
    @pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC, NEAREST])
    def test_resize_scripted(self, dt, size, max_size, interpolation, device):
        """T.Resize must behave identically in eager and scripted mode."""
        if max_size is not None and len(size) != 1:
            # Resize only accepts max_size when size specifies just the
            # smaller edge, i.e. a single-element sequence.  (The previous
            # reason string stated the opposite.)  Bail out before building
            # any fixtures.
            pytest.xfail("with max_size, size must be a sequence with a single element")

        tensor, _ = _create_data(height=34, width=36, device=device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

        if dt is not None:
            # This is a trivial cast to float of uint8 data to test all cases
            tensor = tensor.to(dt)

        transform = T.Resize(size=size, interpolation=interpolation, max_size=max_size)
        s_transform = torch.jit.script(transform)
        _test_transform_vs_scripted(transform, s_transform, tensor)
        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    @cpu_only
    def test_resize_save(self):
        """The scripted Resize transform must be serializable."""
        transform = T.Resize(size=[32, ])
        s_transform = torch.jit.script(transform)
        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_resize.pt"))

    @pytest.mark.parametrize('device', cpu_and_gpu())
    @pytest.mark.parametrize('scale', [(0.7, 1.2), [0.7, 1.2]])
    @pytest.mark.parametrize('ratio', [(0.75, 1.333), [0.75, 1.333]])
    @pytest.mark.parametrize('size', [(32, ), [44, ], [32, ], [32, 32], (32, 32), [44, 55]])
    @pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR, BICUBIC])
    def test_resized_crop(self, scale, ratio, size, interpolation, device):
        """T.RandomResizedCrop must behave identically in eager and
        scripted mode over all parameter combinations."""
        tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
        batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
        transform = T.RandomResizedCrop(size=size, scale=scale, ratio=ratio, interpolation=interpolation)
        s_transform = torch.jit.script(transform)
        _test_transform_vs_scripted(transform, s_transform, tensor)
        _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

    @cpu_only
    def test_resized_crop_save(self):
        """The scripted RandomResizedCrop transform must be serializable."""
        transform = T.RandomResizedCrop(size=[32, ])
        s_transform = torch.jit.script(transform)
        with get_tmp_dir() as tmp_dir:
            s_transform.save(os.path.join(tmp_dir, "t_resized_crop.pt"))


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):

Expand Down

0 comments on commit 4516390

Please sign in to comment.