[fbsync] Replace get_tmp_dir() with tmpdir fixture in tests (#4280)
Summary:
* Replace in test_datasets*

* Replace in test_image.py

* Replace in test_transforms_tensor.py

* Replace in test_internet.py and test_io.py

* get_list_of_videos is a util function that still uses get_tmp_dir

* Fix get_list_of_videos signature

* Add get_tmp_dir import

* Modify test_datasets_video_utils.py so the tests pass

* Fix indentation

* Replace get_tmp_dir in util functions in test_datasets_samplers.py

* Replace get_tmp_dir in util functions in test_datasets_video_utils.py

* Move get_tmp_dir() to datasets_utils.py and refactor

* Fix pylint, indentation and imports

* Import shutil in common_utils.py

* Fix function signature

* Remove the context-manager version of get_list_of_videos

* Move get_list_of_videos to common_utils.py

* Move get_tmp_dir() back to common_utils.py

* Fix pylint and imports
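
In broad strokes, the change swaps torchvision's get_tmp_dir() context manager for pytest's built-in tmpdir fixture. A minimal sketch of the before/after pattern, with a hypothetical test name and file contents:

import os

# Before (hypothetical test body using the old helper):
#
#     from common_utils import get_tmp_dir
#
#     def test_roundtrip():
#         with get_tmp_dir() as root:
#             path = os.path.join(root, "data.bin")
#             ...

# After: pytest injects a fresh per-test temporary directory.
def test_roundtrip(tmpdir):
    path = os.path.join(str(tmpdir), "data.bin")
    with open(path, "wb") as fh:
        fh.write(b"\x00" * 16)
    assert os.path.getsize(path) == 16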

Reviewed By: NicolasHug

Differential Revision: D30417192

fbshipit-source-id: fd5ae2ad7f21509dbe09f7df85f8d9006b9ed1ea

Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
2 people authored and facebook-github-bot committed Aug 20, 2021
1 parent 6d8f0cb commit 3dd8d65
Showing 9 changed files with 317 additions and 374 deletions.
20 changes: 20 additions & 0 deletions test/common_utils.py
@@ -15,6 +15,7 @@
from numbers import Number
from torch._six import string_classes
from collections import OrderedDict
+from torchvision import io

import numpy as np
from PIL import Image
@@ -147,6 +148,25 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)


+def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
+    names = []
+    for i in range(num_videos):
+        if sizes is None:
+            size = 5 * (i + 1)
+        else:
+            size = sizes[i]
+        if fps is None:
+            f = 5
+        else:
+            f = fps[i]
+        data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
+        name = os.path.join(tmpdir, "{}.mp4".format(i))
+        names.append(name)
+        io.write_video(name, data, fps=f)
+
+    return names


def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
    np_pil_image = np.array(pil_image)
    if np_pil_image.ndim == 2:
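
A possible usage sketch of the shared helper under pytest; the test below is hypothetical and assumes a video backend such as PyAV is available, since the helper writes real .mp4 files:

# Hypothetical test module placed next to test/common_utils.py.
from common_utils import get_list_of_videos


def test_helper_writes_requested_clips(tmpdir):
    # Three clips of 25 frames each, written into the per-test directory.
    video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
    assert len(video_list) == 3
    assert all(name.endswith(".mp4") for name in video_list)
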
19 changes: 8 additions & 11 deletions test/test_datasets_download.py
@@ -22,8 +22,6 @@
    USER_AGENT,
)

-from common_utils import get_tmp_dir


def limit_requests_per_time(min_secs_between_requests=2.0):
    last_requests = {}
@@ -166,16 +164,15 @@ def assert_url_is_accessible(url, timeout=5.0):
        urlopen(request, timeout=timeout)


-def assert_file_downloads_correctly(url, md5, timeout=5.0):
-    with get_tmp_dir() as root:
-        file = path.join(root, path.basename(url))
-        with assert_server_response_ok():
-            with open(file, "wb") as fh:
-                request = Request(url, headers={"User-Agent": USER_AGENT})
-                response = urlopen(request, timeout=timeout)
-                fh.write(response.read())
+def assert_file_downloads_correctly(url, md5, tmpdir, timeout=5.0):
+    file = path.join(tmpdir, path.basename(url))
+    with assert_server_response_ok():
+        with open(file, "wb") as fh:
+            request = Request(url, headers={"User-Agent": USER_AGENT})
+            response = urlopen(request, timeout=timeout)
+            fh.write(response.read())

-        assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
+    assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"


class DownloadConfig:
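
With the temporary directory now an explicit argument, a calling test can hand pytest's fixture straight through. A hypothetical sketch, with placeholder URL and checksum, assuming it sits inside test_datasets_download.py where the helper is defined:

def test_single_download(tmpdir):
    url = "https://example.com/archive.tar.gz"    # placeholder URL
    md5 = "d41d8cd98f00b204e9800998ecf8427e"      # placeholder checksum
    assert_file_downloads_correctly(url, md5, tmpdir, timeout=5.0)
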
155 changes: 67 additions & 88 deletions test/test_datasets_samplers.py
@@ -13,104 +13,83 @@
from torchvision.datasets.video_utils import VideoClips, unfold
from torchvision import get_video_backend

-from common_utils import get_tmp_dir, assert_equal


-@contextlib.contextmanager
-def get_list_of_videos(num_videos=5, sizes=None, fps=None):
-    with get_tmp_dir() as tmp_dir:
-        names = []
-        for i in range(num_videos):
-            if sizes is None:
-                size = 5 * (i + 1)
-            else:
-                size = sizes[i]
-            if fps is None:
-                f = 5
-            else:
-                f = fps[i]
-            data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
-            name = os.path.join(tmp_dir, "{}.mp4".format(i))
-            names.append(name)
-            io.write_video(name, data, fps=f)

-        yield names
+from common_utils import get_list_of_videos, assert_equal


@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
class TestDatasetsSamplers:
-    def test_random_clip_sampler(self):
-        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5)
-            sampler = RandomClipSampler(video_clips, 3)
-            assert len(sampler) == 3 * 3
-            indices = torch.tensor(list(iter(sampler)))
-            videos = torch.div(indices, 5, rounding_mode='floor')
-            v_idxs, count = torch.unique(videos, return_counts=True)
-            assert_equal(v_idxs, torch.tensor([0, 1, 2]))
-            assert_equal(count, torch.tensor([3, 3, 3]))
+    def test_random_clip_sampler(self, tmpdir):
+        video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
+        video_clips = VideoClips(video_list, 5, 5)
+        sampler = RandomClipSampler(video_clips, 3)
+        assert len(sampler) == 3 * 3
+        indices = torch.tensor(list(iter(sampler)))
+        videos = torch.div(indices, 5, rounding_mode='floor')
+        v_idxs, count = torch.unique(videos, return_counts=True)
+        assert_equal(v_idxs, torch.tensor([0, 1, 2]))
+        assert_equal(count, torch.tensor([3, 3, 3]))

-    def test_random_clip_sampler_unequal(self):
-        with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5)
-            sampler = RandomClipSampler(video_clips, 3)
-            assert len(sampler) == 2 + 3 + 3
-            indices = list(iter(sampler))
-            assert 0 in indices
-            assert 1 in indices
-            # remove elements of the first video, to simplify testing
-            indices.remove(0)
-            indices.remove(1)
-            indices = torch.tensor(indices) - 2
-            videos = torch.div(indices, 5, rounding_mode='floor')
-            v_idxs, count = torch.unique(videos, return_counts=True)
-            assert_equal(v_idxs, torch.tensor([0, 1]))
-            assert_equal(count, torch.tensor([3, 3]))
+    def test_random_clip_sampler_unequal(self, tmpdir):
+        video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
+        video_clips = VideoClips(video_list, 5, 5)
+        sampler = RandomClipSampler(video_clips, 3)
+        assert len(sampler) == 2 + 3 + 3
+        indices = list(iter(sampler))
+        assert 0 in indices
+        assert 1 in indices
+        # remove elements of the first video, to simplify testing
+        indices.remove(0)
+        indices.remove(1)
+        indices = torch.tensor(indices) - 2
+        videos = torch.div(indices, 5, rounding_mode='floor')
+        v_idxs, count = torch.unique(videos, return_counts=True)
+        assert_equal(v_idxs, torch.tensor([0, 1]))
+        assert_equal(count, torch.tensor([3, 3]))

-    def test_uniform_clip_sampler(self):
-        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5)
-            sampler = UniformClipSampler(video_clips, 3)
-            assert len(sampler) == 3 * 3
-            indices = torch.tensor(list(iter(sampler)))
-            videos = torch.div(indices, 5, rounding_mode='floor')
-            v_idxs, count = torch.unique(videos, return_counts=True)
-            assert_equal(v_idxs, torch.tensor([0, 1, 2]))
-            assert_equal(count, torch.tensor([3, 3, 3]))
-            assert_equal(indices, torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14]))
+    def test_uniform_clip_sampler(self, tmpdir):
+        video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
+        video_clips = VideoClips(video_list, 5, 5)
+        sampler = UniformClipSampler(video_clips, 3)
+        assert len(sampler) == 3 * 3
+        indices = torch.tensor(list(iter(sampler)))
+        videos = torch.div(indices, 5, rounding_mode='floor')
+        v_idxs, count = torch.unique(videos, return_counts=True)
+        assert_equal(v_idxs, torch.tensor([0, 1, 2]))
+        assert_equal(count, torch.tensor([3, 3, 3]))
+        assert_equal(indices, torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14]))

-    def test_uniform_clip_sampler_insufficient_clips(self):
-        with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5)
-            sampler = UniformClipSampler(video_clips, 3)
-            assert len(sampler) == 3 * 3
-            indices = torch.tensor(list(iter(sampler)))
-            assert_equal(indices, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]))
+    def test_uniform_clip_sampler_insufficient_clips(self, tmpdir):
+        video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
+        video_clips = VideoClips(video_list, 5, 5)
+        sampler = UniformClipSampler(video_clips, 3)
+        assert len(sampler) == 3 * 3
+        indices = torch.tensor(list(iter(sampler)))
+        assert_equal(indices, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]))

-    def test_distributed_sampler_and_uniform_clip_sampler(self):
-        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5)
-            clip_sampler = UniformClipSampler(video_clips, 3)
+    def test_distributed_sampler_and_uniform_clip_sampler(self, tmpdir):
+        video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
+        video_clips = VideoClips(video_list, 5, 5)
+        clip_sampler = UniformClipSampler(video_clips, 3)

-            distributed_sampler_rank0 = DistributedSampler(
-                clip_sampler,
-                num_replicas=2,
-                rank=0,
-                group_size=3,
-            )
-            indices = torch.tensor(list(iter(distributed_sampler_rank0)))
-            assert len(distributed_sampler_rank0) == 6
-            assert_equal(indices, torch.tensor([0, 2, 4, 10, 12, 14]))
+        distributed_sampler_rank0 = DistributedSampler(
+            clip_sampler,
+            num_replicas=2,
+            rank=0,
+            group_size=3,
+        )
+        indices = torch.tensor(list(iter(distributed_sampler_rank0)))
+        assert len(distributed_sampler_rank0) == 6
+        assert_equal(indices, torch.tensor([0, 2, 4, 10, 12, 14]))

-            distributed_sampler_rank1 = DistributedSampler(
-                clip_sampler,
-                num_replicas=2,
-                rank=1,
-                group_size=3,
-            )
-            indices = torch.tensor(list(iter(distributed_sampler_rank1)))
-            assert len(distributed_sampler_rank1) == 6
-            assert_equal(indices, torch.tensor([5, 7, 9, 0, 2, 4]))
+        distributed_sampler_rank1 = DistributedSampler(
+            clip_sampler,
+            num_replicas=2,
+            rank=1,
+            group_size=3,
+        )
+        indices = torch.tensor(list(iter(distributed_sampler_rank1)))
+        assert len(distributed_sampler_rank1) == 6
+        assert_equal(indices, torch.tensor([5, 7, 9, 0, 2, 4]))


if __name__ == '__main__':
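
For context on the asserted numbers: with 25-frame videos, a clip length of 5 and a stride of 5 give 5 clips per video, and sampling 3 clips from each of 3 videos yields the 9 indices checked above. A back-of-the-envelope sketch of that arithmetic (not the samplers' actual implementation):

# Rough arithmetic behind the expected UniformClipSampler indices.
num_videos, frames, clip_len, step, per_video = 3, 25, 5, 5, 3

clips = (frames - clip_len) // step + 1   # 5 clips per 25-frame video
indices = [v * clips + (i * (clips - 1)) // (per_video - 1)
           for v in range(num_videos) for i in range(per_video)]
print(indices)  # [0, 2, 4, 5, 7, 9, 10, 12, 14] -- matches the assertion
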
49 changes: 22 additions & 27 deletions test/test_datasets_utils.py
@@ -12,7 +12,6 @@
import lzma
import contextlib

-from common_utils import get_tmp_dir
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS


@@ -113,7 +112,7 @@ def test_detect_file_type_incompatible(self, file):
            utils._detect_file_type(file)

    @pytest.mark.parametrize('extension', [".bz2", ".gz", ".xz"])
-    def test_decompress(self, extension):
+    def test_decompress(self, extension, tmpdir):
        def create_compressed(root, content="this is the content"):
            file = os.path.join(root, "file")
            compressed = f"{file}{extension}"
@@ -124,21 +123,20 @@ def create_compressed(root, content="this is the content"):

            return compressed, file, content

-        with get_tmp_dir() as temp_dir:
-            compressed, file, content = create_compressed(temp_dir)
+        compressed, file, content = create_compressed(tmpdir)

-            utils._decompress(compressed)
+        utils._decompress(compressed)

-            assert os.path.exists(file)
+        assert os.path.exists(file)

-            with open(file, "r") as fh:
-                assert fh.read() == content
+        with open(file, "r") as fh:
+            assert fh.read() == content

    def test_decompress_no_compression(self):
        with pytest.raises(RuntimeError):
            utils._decompress("foo.tar")

-    def test_decompress_remove_finished(self):
+    def test_decompress_remove_finished(self, tmpdir):
        def create_compressed(root, content="this is the content"):
            file = os.path.join(root, "file")
            compressed = f"{file}.gz"
@@ -148,12 +146,11 @@ def create_compressed(root, content="this is the content"):

            return compressed, file, content

-        with get_tmp_dir() as temp_dir:
-            compressed, file, content = create_compressed(temp_dir)
+        compressed, file, content = create_compressed(tmpdir)

-            utils.extract_archive(compressed, temp_dir, remove_finished=True)
+        utils.extract_archive(compressed, tmpdir, remove_finished=True)

-            assert not os.path.exists(compressed)
+        assert not os.path.exists(compressed)

    @pytest.mark.parametrize('extension', [".gz", ".xz"])
    @pytest.mark.parametrize('remove_finished', [True, False])
@@ -166,7 +163,7 @@ def test_extract_archive_defer_to_decompress(self, extension, remove_finished, m

        mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)

-    def test_extract_zip(self):
+    def test_extract_zip(self, tmpdir):
        def create_archive(root, content="this is the content"):
            file = os.path.join(root, "dst.txt")
            archive = os.path.join(root, "archive.zip")
@@ -176,19 +173,18 @@ def create_archive(root, content="this is the content"):

            return archive, file, content

-        with get_tmp_dir() as temp_dir:
-            archive, file, content = create_archive(temp_dir)
+        archive, file, content = create_archive(tmpdir)

-            utils.extract_archive(archive, temp_dir)
+        utils.extract_archive(archive, tmpdir)

-            assert os.path.exists(file)
+        assert os.path.exists(file)

-            with open(file, "r") as fh:
-                assert fh.read() == content
+        with open(file, "r") as fh:
+            assert fh.read() == content

    @pytest.mark.parametrize('extension, mode', [
        ('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz'), ('.tar.xz', 'w:xz')])
-    def test_extract_tar(self, extension, mode):
+    def test_extract_tar(self, extension, mode, tmpdir):
        def create_archive(root, extension, mode, content="this is the content"):
            src = os.path.join(root, "src.txt")
            dst = os.path.join(root, "dst.txt")
@@ -202,15 +198,14 @@ def create_archive(root, extension, mode, content="this is the content"):

            return archive, dst, content

-        with get_tmp_dir() as temp_dir:
-            archive, file, content = create_archive(temp_dir, extension, mode)
+        archive, file, content = create_archive(tmpdir, extension, mode)

-            utils.extract_archive(archive, temp_dir)
+        utils.extract_archive(archive, tmpdir)

-            assert os.path.exists(file)
+        assert os.path.exists(file)

-            with open(file, "r") as fh:
-                assert fh.read() == content
+        with open(file, "r") as fh:
+            assert fh.read() == content

    def test_verify_str_arg(self):
        assert "a" == utils.verify_str_arg("a", "arg", ("a",))
(Diffs for the remaining changed files are not shown.)
