Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace get_tmp_dir() with tmpdir fixture in pytest #4280

Merged
merged 21 commits into from
Aug 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from numbers import Number
from torch._six import string_classes
from collections import OrderedDict
from torchvision import io

import numpy as np
from PIL import Image
Expand Down Expand Up @@ -147,6 +148,25 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)


def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):
names = []
for i in range(num_videos):
if sizes is None:
size = 5 * (i + 1)
else:
size = sizes[i]
if fps is None:
f = 5
else:
f = fps[i]
data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
name = os.path.join(tmpdir, "{}.mp4".format(i))
names.append(name)
io.write_video(name, data, fps=f)

return names


def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
Expand Down
19 changes: 8 additions & 11 deletions test/test_datasets_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
USER_AGENT,
)

from common_utils import get_tmp_dir


def limit_requests_per_time(min_secs_between_requests=2.0):
last_requests = {}
Expand Down Expand Up @@ -166,16 +164,15 @@ def assert_url_is_accessible(url, timeout=5.0):
urlopen(request, timeout=timeout)


def assert_file_downloads_correctly(url, md5, timeout=5.0):
with get_tmp_dir() as root:
file = path.join(root, path.basename(url))
with assert_server_response_ok():
with open(file, "wb") as fh:
request = Request(url, headers={"User-Agent": USER_AGENT})
response = urlopen(request, timeout=timeout)
fh.write(response.read())
def assert_file_downloads_correctly(url, md5, tmpdir, timeout=5.0):
file = path.join(tmpdir, path.basename(url))
with assert_server_response_ok():
with open(file, "wb") as fh:
request = Request(url, headers={"User-Agent": USER_AGENT})
response = urlopen(request, timeout=timeout)
fh.write(response.read())

assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"
assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"


class DownloadConfig:
Expand Down
155 changes: 67 additions & 88 deletions test/test_datasets_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,104 +13,83 @@
from torchvision.datasets.video_utils import VideoClips, unfold
from torchvision import get_video_backend

from common_utils import get_tmp_dir, assert_equal


@contextlib.contextmanager
def get_list_of_videos(num_videos=5, sizes=None, fps=None):
with get_tmp_dir() as tmp_dir:
names = []
for i in range(num_videos):
if sizes is None:
size = 5 * (i + 1)
else:
size = sizes[i]
if fps is None:
f = 5
else:
f = fps[i]
data = torch.randint(0, 256, (size, 300, 400, 3), dtype=torch.uint8)
name = os.path.join(tmp_dir, "{}.mp4".format(i))
names.append(name)
io.write_video(name, data, fps=f)

yield names
from common_utils import get_list_of_videos, assert_equal


@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
class TestDatasetsSamplers:
def test_random_clip_sampler(self):
with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
video_clips = VideoClips(video_list, 5, 5)
sampler = RandomClipSampler(video_clips, 3)
assert len(sampler) == 3 * 3
indices = torch.tensor(list(iter(sampler)))
videos = torch.div(indices, 5, rounding_mode='floor')
v_idxs, count = torch.unique(videos, return_counts=True)
assert_equal(v_idxs, torch.tensor([0, 1, 2]))
assert_equal(count, torch.tensor([3, 3, 3]))
def test_random_clip_sampler(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
video_clips = VideoClips(video_list, 5, 5)
sampler = RandomClipSampler(video_clips, 3)
assert len(sampler) == 3 * 3
indices = torch.tensor(list(iter(sampler)))
videos = torch.div(indices, 5, rounding_mode='floor')
v_idxs, count = torch.unique(videos, return_counts=True)
assert_equal(v_idxs, torch.tensor([0, 1, 2]))
assert_equal(count, torch.tensor([3, 3, 3]))

def test_random_clip_sampler_unequal(self):
with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
video_clips = VideoClips(video_list, 5, 5)
sampler = RandomClipSampler(video_clips, 3)
assert len(sampler) == 2 + 3 + 3
indices = list(iter(sampler))
assert 0 in indices
assert 1 in indices
# remove elements of the first video, to simplify testing
indices.remove(0)
indices.remove(1)
indices = torch.tensor(indices) - 2
videos = torch.div(indices, 5, rounding_mode='floor')
v_idxs, count = torch.unique(videos, return_counts=True)
assert_equal(v_idxs, torch.tensor([0, 1]))
assert_equal(count, torch.tensor([3, 3]))
def test_random_clip_sampler_unequal(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
video_clips = VideoClips(video_list, 5, 5)
sampler = RandomClipSampler(video_clips, 3)
assert len(sampler) == 2 + 3 + 3
indices = list(iter(sampler))
assert 0 in indices
assert 1 in indices
# remove elements of the first video, to simplify testing
indices.remove(0)
indices.remove(1)
indices = torch.tensor(indices) - 2
videos = torch.div(indices, 5, rounding_mode='floor')
v_idxs, count = torch.unique(videos, return_counts=True)
assert_equal(v_idxs, torch.tensor([0, 1]))
assert_equal(count, torch.tensor([3, 3]))

def test_uniform_clip_sampler(self):
with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
video_clips = VideoClips(video_list, 5, 5)
sampler = UniformClipSampler(video_clips, 3)
assert len(sampler) == 3 * 3
indices = torch.tensor(list(iter(sampler)))
videos = torch.div(indices, 5, rounding_mode='floor')
v_idxs, count = torch.unique(videos, return_counts=True)
assert_equal(v_idxs, torch.tensor([0, 1, 2]))
assert_equal(count, torch.tensor([3, 3, 3]))
assert_equal(indices, torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14]))
def test_uniform_clip_sampler(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
video_clips = VideoClips(video_list, 5, 5)
sampler = UniformClipSampler(video_clips, 3)
assert len(sampler) == 3 * 3
indices = torch.tensor(list(iter(sampler)))
videos = torch.div(indices, 5, rounding_mode='floor')
v_idxs, count = torch.unique(videos, return_counts=True)
assert_equal(v_idxs, torch.tensor([0, 1, 2]))
assert_equal(count, torch.tensor([3, 3, 3]))
assert_equal(indices, torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14]))

def test_uniform_clip_sampler_insufficient_clips(self):
with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
video_clips = VideoClips(video_list, 5, 5)
sampler = UniformClipSampler(video_clips, 3)
assert len(sampler) == 3 * 3
indices = torch.tensor(list(iter(sampler)))
assert_equal(indices, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]))
def test_uniform_clip_sampler_insufficient_clips(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[10, 25, 25])
video_clips = VideoClips(video_list, 5, 5)
sampler = UniformClipSampler(video_clips, 3)
assert len(sampler) == 3 * 3
indices = torch.tensor(list(iter(sampler)))
assert_equal(indices, torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11]))

def test_distributed_sampler_and_uniform_clip_sampler(self):
with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
video_clips = VideoClips(video_list, 5, 5)
clip_sampler = UniformClipSampler(video_clips, 3)
def test_distributed_sampler_and_uniform_clip_sampler(self, tmpdir):
video_list = get_list_of_videos(tmpdir, num_videos=3, sizes=[25, 25, 25])
video_clips = VideoClips(video_list, 5, 5)
clip_sampler = UniformClipSampler(video_clips, 3)

distributed_sampler_rank0 = DistributedSampler(
clip_sampler,
num_replicas=2,
rank=0,
group_size=3,
)
indices = torch.tensor(list(iter(distributed_sampler_rank0)))
assert len(distributed_sampler_rank0) == 6
assert_equal(indices, torch.tensor([0, 2, 4, 10, 12, 14]))
distributed_sampler_rank0 = DistributedSampler(
clip_sampler,
num_replicas=2,
rank=0,
group_size=3,
)
indices = torch.tensor(list(iter(distributed_sampler_rank0)))
assert len(distributed_sampler_rank0) == 6
assert_equal(indices, torch.tensor([0, 2, 4, 10, 12, 14]))

distributed_sampler_rank1 = DistributedSampler(
clip_sampler,
num_replicas=2,
rank=1,
group_size=3,
)
indices = torch.tensor(list(iter(distributed_sampler_rank1)))
assert len(distributed_sampler_rank1) == 6
assert_equal(indices, torch.tensor([5, 7, 9, 0, 2, 4]))
distributed_sampler_rank1 = DistributedSampler(
clip_sampler,
num_replicas=2,
rank=1,
group_size=3,
)
indices = torch.tensor(list(iter(distributed_sampler_rank1)))
assert len(distributed_sampler_rank1) == 6
assert_equal(indices, torch.tensor([5, 7, 9, 0, 2, 4]))


if __name__ == '__main__':
Expand Down
49 changes: 22 additions & 27 deletions test/test_datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import lzma
import contextlib

from common_utils import get_tmp_dir
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS


Expand Down Expand Up @@ -113,7 +112,7 @@ def test_detect_file_type_incompatible(self, file):
utils._detect_file_type(file)

@pytest.mark.parametrize('extension', [".bz2", ".gz", ".xz"])
def test_decompress(self, extension):
def test_decompress(self, extension, tmpdir):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}{extension}"
Expand All @@ -124,21 +123,20 @@ def create_compressed(root, content="this is the content"):

return compressed, file, content

with get_tmp_dir() as temp_dir:
compressed, file, content = create_compressed(temp_dir)
compressed, file, content = create_compressed(tmpdir)

utils._decompress(compressed)
utils._decompress(compressed)

assert os.path.exists(file)
assert os.path.exists(file)

with open(file, "r") as fh:
assert fh.read() == content
with open(file, "r") as fh:
assert fh.read() == content

def test_decompress_no_compression(self):
with pytest.raises(RuntimeError):
utils._decompress("foo.tar")

def test_decompress_remove_finished(self):
def test_decompress_remove_finished(self, tmpdir):
def create_compressed(root, content="this is the content"):
file = os.path.join(root, "file")
compressed = f"{file}.gz"
Expand All @@ -148,12 +146,11 @@ def create_compressed(root, content="this is the content"):

return compressed, file, content

with get_tmp_dir() as temp_dir:
compressed, file, content = create_compressed(temp_dir)
compressed, file, content = create_compressed(tmpdir)

utils.extract_archive(compressed, temp_dir, remove_finished=True)
utils.extract_archive(compressed, tmpdir, remove_finished=True)

assert not os.path.exists(compressed)
assert not os.path.exists(compressed)

@pytest.mark.parametrize('extension', [".gz", ".xz"])
@pytest.mark.parametrize('remove_finished', [True, False])
Expand All @@ -166,7 +163,7 @@ def test_extract_archive_defer_to_decompress(self, extension, remove_finished, m

mocked.assert_called_once_with(file, filename, remove_finished=remove_finished)

def test_extract_zip(self):
def test_extract_zip(self, tmpdir):
def create_archive(root, content="this is the content"):
file = os.path.join(root, "dst.txt")
archive = os.path.join(root, "archive.zip")
Expand All @@ -176,19 +173,18 @@ def create_archive(root, content="this is the content"):

return archive, file, content

with get_tmp_dir() as temp_dir:
archive, file, content = create_archive(temp_dir)
archive, file, content = create_archive(tmpdir)

utils.extract_archive(archive, temp_dir)
utils.extract_archive(archive, tmpdir)

assert os.path.exists(file)
assert os.path.exists(file)

with open(file, "r") as fh:
assert fh.read() == content
with open(file, "r") as fh:
assert fh.read() == content

@pytest.mark.parametrize('extension, mode', [
('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz'), ('.tar.xz', 'w:xz')])
def test_extract_tar(self, extension, mode):
def test_extract_tar(self, extension, mode, tmpdir):
def create_archive(root, extension, mode, content="this is the content"):
src = os.path.join(root, "src.txt")
dst = os.path.join(root, "dst.txt")
Expand All @@ -202,15 +198,14 @@ def create_archive(root, extension, mode, content="this is the content"):

return archive, dst, content

with get_tmp_dir() as temp_dir:
archive, file, content = create_archive(temp_dir, extension, mode)
archive, file, content = create_archive(tmpdir, extension, mode)

utils.extract_archive(archive, temp_dir)
utils.extract_archive(archive, tmpdir)

assert os.path.exists(file)
assert os.path.exists(file)

with open(file, "r") as fh:
assert fh.read() == content
with open(file, "r") as fh:
assert fh.read() == content

def test_verify_str_arg(self):
assert "a" == utils.verify_str_arg("a", "arg", ("a",))
Expand Down
Loading