From 8d6962f34e79b13533dad8a3a215aa178c7ce16e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 21 May 2021 08:39:43 +0200 Subject: [PATCH] remove obsolete test_datasets_transforms --- test/test_datasets_transforms.py | 72 -------------------------------- 1 file changed, 72 deletions(-) delete mode 100644 test/test_datasets_transforms.py diff --git a/test/test_datasets_transforms.py b/test/test_datasets_transforms.py deleted file mode 100644 index 6cffd4f76a9..00000000000 --- a/test/test_datasets_transforms.py +++ /dev/null @@ -1,72 +0,0 @@ -import os -import shutil -import contextlib -import tempfile -import unittest -from torchvision.datasets import ImageFolder - -FAKEDATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), - 'assets', 'fakedata') - - -@contextlib.contextmanager -def tmp_dir(src=None, **kwargs): - tmp_dir = tempfile.mkdtemp(**kwargs) - if src is not None: - os.rmdir(tmp_dir) - shutil.copytree(src, tmp_dir) - try: - yield tmp_dir - finally: - shutil.rmtree(tmp_dir) - - -def mock_transform(return_value, arg_list): - def mock(arg): - arg_list.append(arg) - return return_value - return mock - - -class Tester(unittest.TestCase): - def test_transform(self): - with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root: - class_a_image_files = [os.path.join(root, 'a', file) - for file in ('a1.png', 'a2.png', 'a3.png')] - class_b_image_files = [os.path.join(root, 'b', file) - for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')] - return_value = os.path.join(root, 'a', 'a1.png') - args = [] - transform = mock_transform(return_value, args) - dataset = ImageFolder(root, loader=lambda x: x, transform=transform) - - outputs = [dataset[i][0] for i in range(len(dataset))] - self.assertEqual([return_value] * len(outputs), outputs) - - imgs = sorted(class_a_image_files + class_b_image_files) - self.assertEqual(imgs, sorted(args)) - - def test_target_transform(self): - with tmp_dir(src=os.path.join(FAKEDATA_DIR, 'imagefolder')) as root: - class_a_image_files = [os.path.join(root, 'a', file) - for file in ('a1.png', 'a2.png', 'a3.png')] - class_b_image_files = [os.path.join(root, 'b', file) - for file in ('b1.png', 'b2.png', 'b3.png', 'b4.png')] - return_value = os.path.join(root, 'a', 'a1.png') - args = [] - target_transform = mock_transform(return_value, args) - dataset = ImageFolder(root, loader=lambda x: x, - target_transform=target_transform) - - outputs = [dataset[i][1] for i in range(len(dataset))] - self.assertEqual([return_value] * len(outputs), outputs) - - class_a_idx = dataset.class_to_idx['a'] - class_b_idx = dataset.class_to_idx['b'] - targets = sorted([class_a_idx] * len(class_a_image_files) + - [class_b_idx] * len(class_b_image_files)) - self.assertEqual(targets, sorted(args)) - - -if __name__ == '__main__': - unittest.main()