Skip to content

Commit

Permalink
[fbsync] Use torch.testing.assert_close in test_transforms_tensor.py (#…
Browse files Browse the repository at this point in the history
…3885)

Summary: Co-authored-by: Philip Meier <github.pmeier@posteo.de>

Reviewed By: vincentqb, cpuhrsch

Differential Revision: D28679970

fbshipit-source-id: ca3d0f527ba0c029b4fc64ba7883c89d561b67b8
  • Loading branch information
datumbox authored and facebook-github-bot committed May 25, 2021
1 parent d855eb4 commit 3e87b01
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Sequence

from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes
from _assert_utils import assert_equal


NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
Expand Down Expand Up @@ -38,7 +39,7 @@ def _test_transform_vs_scripted(self, transform, s_transform, tensor, msg=None):
out1 = transform(tensor)
torch.manual_seed(12)
out2 = s_transform(tensor)
self.assertTrue(out1.equal(out2), msg=msg)
assert_equal(out1, out2, msg=msg)

def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_tensors, msg=None):
torch.manual_seed(12)
Expand All @@ -48,11 +49,11 @@ def _test_transform_vs_scripted_on_batch(self, transform, s_transform, batch_ten
img_tensor = batch_tensors[i, ...]
torch.manual_seed(12)
transformed_img = transform(img_tensor)
self.assertTrue(transformed_img.equal(transformed_batch[i, ...]), msg=msg)
assert_equal(transformed_img, transformed_batch[i, ...], msg=msg)

torch.manual_seed(12)
s_transformed_batch = s_transform(batch_tensors)
self.assertTrue(transformed_batch.equal(s_transformed_batch), msg=msg)
assert_equal(transformed_batch, s_transformed_batch, msg=msg)

def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **match_kwargs):
if meth_kwargs is None:
Expand All @@ -75,7 +76,7 @@ def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **matc

torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script))
assert_equal(transformed_tensor, transformed_tensor_script)

batch_tensors = self._create_data_batch(height=23, width=34, channels=3, num_samples=4, device=self.device)
self._test_transform_vs_scripted_on_batch(f, scripted_fn, batch_tensors)
Expand Down Expand Up @@ -270,8 +271,11 @@ def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kw
self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
self.assertEqual(len(transformed_t_list_script), out_length)
for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
self.assertTrue(transformed_tensor.equal(transformed_tensor_script),
msg="{} vs {}".format(transformed_tensor, transformed_tensor_script))
assert_equal(
transformed_tensor,
transformed_tensor_script,
msg="{} vs {}".format(transformed_tensor, transformed_tensor_script),
)

# test for class interface
fn = getattr(T, method)(**meth_kwargs)
Expand All @@ -289,8 +293,11 @@ def _test_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kw
torch.manual_seed(12)
transformed_img_list = fn(img_tensor)
for transformed_img, transformed_batch in zip(transformed_img_list, transformed_batch_list):
self.assertTrue(transformed_img.equal(transformed_batch[i, ...]),
msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]))
assert_equal(
transformed_img,
transformed_batch[i, ...],
msg="{} vs {}".format(transformed_img, transformed_batch[i, ...]),
)

with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_op_list_{}.pt".format(method)))
Expand Down Expand Up @@ -505,7 +512,7 @@ def test_linear_transformation(self):
transformed_batch = fn(batch_tensors)
torch.manual_seed(12)
s_transformed_batch = scripted_fn(batch_tensors)
self.assertTrue(transformed_batch.equal(s_transformed_batch))
assert_equal(transformed_batch, s_transformed_batch)

with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_norm.pt"))
Expand All @@ -525,7 +532,7 @@ def test_compose(self):
transformed_tensor = transforms(tensor)
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script), msg="{}".format(transforms))
assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))

t = T.Compose([
lambda x: x,
Expand All @@ -551,7 +558,7 @@ def test_random_apply(self):
transformed_tensor = transforms(tensor)
torch.manual_seed(12)
transformed_tensor_script = scripted_fn(tensor)
self.assertTrue(transformed_tensor.equal(transformed_tensor_script), msg="{}".format(transforms))
assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms))

if torch.device(self.device).type == "cpu":
# Can't check this twice, otherwise
Expand Down

0 comments on commit 3e87b01

Please sign in to comment.