Skip to content

Commit

Permalink
Add tests for the PhotoTour dataset (#3486)
Browse files Browse the repository at this point in the history
* add tests for PhotoTour dataset

* fix grayscale image generation

* fix test_feature_types for a examples of a single feature

* make image size variable instead of hard coding it

* make dataset length variable instead of hard coding it

* replace numpy with torch

* fix typo
  • Loading branch information
pmeier authored Mar 2, 2021
1 parent f0f5ee0 commit f637c63
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 14 deletions.
27 changes: 18 additions & 9 deletions test/datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,14 +436,17 @@ def test_feature_types(self, config):
with self.create_dataset(config) as (dataset, _):
example = dataset[0]

actual = len(example)
expected = len(self.FEATURE_TYPES)
self.assertEqual(
actual,
expected,
f"The number of the returned features does not match the the number of elements in in FEATURE_TYPES: "
f"{actual} != {expected}",
)
if len(self.FEATURE_TYPES) > 1:
actual = len(example)
expected = len(self.FEATURE_TYPES)
self.assertEqual(
actual,
expected,
f"The number of the returned features does not match the the number of elements in FEATURE_TYPES: "
f"{actual} != {expected}",
)
else:
example = (example,)

for idx, (feature, expected_feature_type) in enumerate(zip(example, self.FEATURE_TYPES)):
with self.subTest(idx=idx):
Expand Down Expand Up @@ -586,7 +589,13 @@ def create_image_file(

image = create_image_or_video_tensor(size)
file = pathlib.Path(root) / name
PIL.Image.fromarray(image.permute(2, 1, 0).numpy()).save(file, **kwargs)

# torch (num_channels x height x width) -> PIL (width x height x num_channels)
image = image.permute(2, 1, 0)
# For grayscale images PIL doesn't use a channel dimension
if image.shape[2] == 1:
image = torch.squeeze(image, 2)
PIL.Image.fromarray(image.numpy()).save(file, **kwargs)
return file


Expand Down
79 changes: 79 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import torch.nn.functional as F
import string
import io
import zipfile


try:
Expand Down Expand Up @@ -1275,5 +1276,83 @@ def test_not_found_or_corrupted(self):
self.skipTest("The data is generated at creation and thus cannot be non-existent or corrupted.")


class PhotoTourTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.PhotoTour

# The PhotoTour dataset returns examples with different features with respect to the 'train' parameter. Thus,
# we overwrite 'FEATURE_TYPES' with a dummy value to satisfy the initial checks of the base class. Furthermore, we
# overwrite the 'test_feature_types()' method to select the correct feature types before the test is run.
FEATURE_TYPES = ()
_TRAIN_FEATURE_TYPES = (torch.Tensor,)
_TEST_FEATURE_TYPES = (torch.Tensor, torch.Tensor, torch.Tensor)

CONFIGS = datasets_utils.combinations_grid(train=(True, False))

_NAME = "liberty"

def dataset_args(self, tmpdir, config):
return tmpdir, self._NAME

def inject_fake_data(self, tmpdir, config):
tmpdir = pathlib.Path(tmpdir)

# In contrast to the original data, the fake images injected here comprise only a single patch. Thus,
# num_images == num_patches.
num_patches = 5

image_files = self._create_images(tmpdir, self._NAME, num_patches)
point_ids, info_file = self._create_info_file(tmpdir / self._NAME, num_patches)
num_matches, matches_file = self._create_matches_file(tmpdir / self._NAME, num_patches, point_ids)

self._create_archive(tmpdir, self._NAME, *image_files, info_file, matches_file)

return num_patches if config["train"] else num_matches

def _create_images(self, root, name, num_images):
# The images in the PhotoTour dataset comprises of multiple grayscale patches of 64 x 64 pixels. Thus, the
# smallest fake image is 64 x 64 pixels and comprises a single patch.
return datasets_utils.create_image_folder(
root, name, lambda idx: f"patches{idx:04d}.bmp", num_images, size=(1, 64, 64)
)

def _create_info_file(self, root, num_images):
point_ids = torch.randint(num_images, size=(num_images,)).tolist()

file = root / "info.txt"
with open(file, "w") as fh:
fh.writelines([f"{point_id} 0\n" for point_id in point_ids])

return point_ids, file

def _create_matches_file(self, root, num_patches, point_ids):
lines = [
f"{patch_id1} {point_ids[patch_id1]} 0 {patch_id2} {point_ids[patch_id2]} 0\n"
for patch_id1, patch_id2 in itertools.combinations(range(num_patches), 2)
]

file = root / "m50_100000_100000_0.txt"
with open(file, "w") as fh:
fh.writelines(lines)

return len(lines), file

def _create_archive(self, root, name, *files):
archive = root / f"{name}.zip"
with zipfile.ZipFile(archive, "w") as zip:
for file in files:
zip.write(file, arcname=file.relative_to(root))

return archive

@datasets_utils.test_all_configs
def test_feature_types(self, config):
feature_types = self.FEATURE_TYPES
self.FEATURE_TYPES = self._TRAIN_FEATURE_TYPES if config["train"] else self._TEST_FEATURE_TYPES
try:
super().test_feature_types.__wrapped__(self, config)
finally:
self.FEATURE_TYPES = feature_types


if __name__ == "__main__":
unittest.main()
8 changes: 3 additions & 5 deletions torchvision/datasets/phototour.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,7 @@ def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.T
return data1, data2, m[2]

def __len__(self) -> int:
if self.train:
return self.lens[self.name]
return len(self.matches)
return len(self.data if self.train else self.matches)

def _check_datafile_exists(self) -> bool:
return os.path.exists(self.data_file)
Expand Down Expand Up @@ -194,8 +192,8 @@ def find_files(_data_dir: str, _image_ext: str) -> List[str]:

for fpath in list_files:
img = Image.open(fpath)
for y in range(0, 1024, 64):
for x in range(0, 1024, 64):
for y in range(0, img.height, 64):
for x in range(0, img.width, 64):
patch = img.crop((x, y, x + 64, y + 64))
patches.append(PIL2array(patch))
return torch.ByteTensor(np.array(patches[:n]))
Expand Down

0 comments on commit f637c63

Please sign in to comment.