Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port test_datasets_video_utils.py to pytest #4035

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions test/test_datasets_video_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import contextlib
import os
import torch
import unittest
import pytest

from torchvision import io
from torchvision.datasets.video_utils import VideoClips, unfold
Expand Down Expand Up @@ -31,7 +31,7 @@ def get_list_of_videos(num_videos=5, sizes=None, fps=None):
yield names


class Tester(unittest.TestCase):
class TestVideo:

def test_unfold(self):
a = torch.arange(7)
Expand All @@ -58,7 +58,7 @@ def test_unfold(self):
])
assert_equal(r, expected, check_stride=False)

@unittest.skipIf(not io.video._av_available(), "this test requires av")
@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
def test_video_clips(self):
with get_list_of_videos(num_videos=3) as video_list:
video_clips = VideoClips(video_list, 5, 5, num_workers=2)
Expand All @@ -82,7 +82,7 @@ def test_video_clips(self):
assert video_idx == v_idx
assert clip_idx == c_idx

@unittest.skipIf(not io.video._av_available(), "this test requires av")
@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
def test_video_clips_custom_fps(self):
with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
num_frames = 4
Expand Down Expand Up @@ -124,12 +124,12 @@ def test_compute_clips_for_video(self):
num_frames = 32
orig_fps = 30
new_fps = 13
with self.assertWarns(UserWarning):
with pytest.warns(UserWarning):
clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames,
orig_fps, new_fps)
assert len(clips) == 0
assert len(idxs) == 0


if __name__ == '__main__':
unittest.main()
pytest.main([__file__])