From 61994a08f25a8842b2ce57848de5fe79eb9bfe18 Mon Sep 17 00:00:00 2001 From: Martin <1500595+bmmtstb@users.noreply.github.com> Date: Thu, 18 Jul 2024 19:18:39 +0200 Subject: [PATCH] Allow write_video to work with video_arrays on cuda devices --- test/test_io.py | 6 +++++- torchvision/io/video.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/test/test_io.py b/test/test_io.py index c45180571f0..0f57da29c27 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -255,10 +255,14 @@ def test_read_video_partially_corrupted_file(self): assert_equal(video, data) @pytest.mark.skipif(sys.platform == "win32", reason="temporarily disabled on Windows") - def test_write_video_with_audio(self, tmpdir): + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + def test_write_video_with_audio(self, device, tmpdir): f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4") video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec") + video_tensor = video_tensor.to(device) + audio_tensor = audio_tensor.to(device) + out_f_name = os.path.join(tmpdir, "testing.mp4") io.video.write_video( out_f_name, diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 9b2eacbab11..3bccf414a71 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -80,7 +80,7 @@ def write_video( if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(write_video) _check_av_available() - video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy() + video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy(force=True) # PyAV does not support floating point numbers with decimal point # and will throw OverflowException in case this is not the case