Skip to content

Commit

Permalink
Automatically send video to CPU in io.write_video (#8537)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
  • Loading branch information
bmmtstb and NicolasHug authored Jul 25, 2024
1 parent 4a1cb63 commit 3e60dbd
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,10 +255,14 @@ def test_read_video_partially_corrupted_file(self):
assert_equal(video, data)

@pytest.mark.skipif(sys.platform == "win32", reason="temporarily disabled on Windows")
def test_write_video_with_audio(self, tmpdir):
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_write_video_with_audio(self, device, tmpdir):
f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4")
video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec")

video_tensor = video_tensor.to(device)
audio_tensor = audio_tensor.to(device)

out_f_name = os.path.join(tmpdir, "testing.mp4")
io.video.write_video(
out_f_name,
Expand Down
2 changes: 1 addition & 1 deletion torchvision/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def write_video(
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(write_video)
_check_av_available()
video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()
video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy(force=True)

# PyAV does not support floating point numbers with decimal point
# and will throw OverflowException in case this is not the case
Expand Down

0 comments on commit 3e60dbd

Please sign in to comment.