Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify video backend #1514

Merged
merged 7 commits into from
Oct 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 31 additions & 47 deletions test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,6 @@
except ImportError:
av = None

_video_backend = get_video_backend()


def _read_video(filename, start_pts=0, end_pts=None):
if _video_backend == "pyav":
return io.read_video(filename, start_pts, end_pts)
else:
if end_pts is None:
end_pts = -1
return io._read_video_from_file(
filename,
video_pts_range=(start_pts, end_pts),
)


def _create_video_frames(num_frames, height, width):
y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
Expand All @@ -61,7 +47,7 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None,
options = {'crf': '0'}

if video_codec is None:
if _video_backend == "pyav":
if get_video_backend() == "pyav":
video_codec = 'libx264'
else:
# when video_codec is not set, we assume it is libx264rgb which accepts
Expand All @@ -76,16 +62,18 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None,
yield f.name, data


@unittest.skipIf(get_video_backend() != "pyav" and not io._HAS_VIDEO_OPT,
"video_reader backend not available")
@unittest.skipIf(av is None, "PyAV unavailable")
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
@unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
class Tester(unittest.TestCase):
# compression adds artifacts, thus we add a tolerance of
# 6 in 0-255 range
TOLERANCE = 6

def test_write_read_video(self):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
lv, _, info = _read_video(f_name)
lv, _, info = io.read_video(f_name)
self.assertTrue(data.equal(lv))
self.assertEqual(info["video_fps"], 5)

Expand All @@ -107,10 +95,7 @@ def test_probe_video_from_memory(self):

def test_read_timestamps(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
if _video_backend == "pyav":
pts, _ = io.read_video_timestamps(f_name)
else:
pts, _, _ = io._read_video_timestamps_from_file(f_name)
pts, _ = io.read_video_timestamps(f_name)
# note: not all formats/codecs provide accurate information for computing the
# timestamps. For the format that we use here, this information is available,
# so we use it as a baseline
Expand All @@ -124,42 +109,41 @@ def test_read_timestamps(self):

def test_read_partial_video(self):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
if _video_backend == "pyav":
pts, _ = io.read_video_timestamps(f_name)
else:
pts, _, _ = io._read_video_timestamps_from_file(f_name)
pts, _ = io.read_video_timestamps(f_name)
for start in range(5):
for l in range(1, 4):
lv, _, _ = _read_video(f_name, pts[start], pts[start + l - 1])
lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
s_data = data[start:(start + l)]
self.assertEqual(len(lv), l)
self.assertTrue(s_data.equal(lv))

if _video_backend == "pyav":
if get_video_backend() == "pyav":
# for "video_reader" backend, we don't decode the closest early frame
# when the given start pts is not matching any frame pts
lv, _, _ = _read_video(f_name, pts[4] + 1, pts[7])
lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
self.assertEqual(len(lv), 4)
self.assertTrue(data[4:8].equal(lv))

def test_read_partial_video_bframes(self):
# do not use lossless encoding, to test the presence of B-frames
options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
if _video_backend == "pyav":
pts, _ = io.read_video_timestamps(f_name)
else:
pts, _, _ = io._read_video_timestamps_from_file(f_name)
pts, _ = io.read_video_timestamps(f_name)
for start in range(0, 80, 20):
for l in range(1, 4):
lv, _, _ = _read_video(f_name, pts[start], pts[start + l - 1])
lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
s_data = data[start:(start + l)]
self.assertEqual(len(lv), l)
self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE)

lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
self.assertEqual(len(lv), 4)
self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)
# TODO fix this
if get_video_backend() == 'pyav':
self.assertEqual(len(lv), 4)
self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)
else:
self.assertEqual(len(lv), 3)
self.assertTrue((data[5:8].float() - lv.float()).abs().max() < self.TOLERANCE)

def test_read_packed_b_frames_divx_file(self):
with get_tmp_dir() as temp_dir:
Expand All @@ -168,11 +152,7 @@ def test_read_packed_b_frames_divx_file(self):
url = "https://download.pytorch.org/vision_tests/io/" + name
try:
utils.download_url(url, temp_dir)
if _video_backend == "pyav":
pts, fps = io.read_video_timestamps(f_name)
else:
pts, _, info = io._read_video_timestamps_from_file(f_name)
fps = info["video_fps"]
pts, fps = io.read_video_timestamps(f_name)

self.assertEqual(pts, sorted(pts))
self.assertEqual(fps, 30)
Expand All @@ -183,10 +163,7 @@ def test_read_packed_b_frames_divx_file(self):

def test_read_timestamps_from_packet(self):
with temp_video(10, 300, 300, 5, video_codec='mpeg4') as (f_name, data):
if _video_backend == "pyav":
pts, _ = io.read_video_timestamps(f_name)
else:
pts, _, _ = io._read_video_timestamps_from_file(f_name)
pts, _ = io.read_video_timestamps(f_name)
# note: not all formats/codecs provide accurate information for computing the
# timestamps. For the format that we use here, this information is available,
# so we use it as a baseline
Expand Down Expand Up @@ -235,8 +212,11 @@ def test_read_partial_video_pts_unit_sec(self):
lv, _, _ = io.read_video(f_name,
int(pts[4] * (1.0 / stream.time_base) + 1) * stream.time_base, pts[7],
pts_unit='sec')
self.assertEqual(len(lv), 4)
self.assertTrue(data[4:8].equal(lv))
if get_video_backend() == "pyav":
# for "video_reader" backend, we don't decode the closest early frame
# when the given start pts is not matching any frame pts
self.assertEqual(len(lv), 4)
self.assertTrue(data[4:8].equal(lv))

def test_read_video_corrupted_file(self):
with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
Expand Down Expand Up @@ -267,7 +247,11 @@ def test_read_video_partially_corrupted_file(self):
# this exercises the container.decode assertion check
video, audio, info = io.read_video(f.name, pts_unit='sec')
# check that size is not equal to 5, but 3
self.assertEqual(len(video), 3)
# TODO fix this
if get_video_backend() == 'pyav':
self.assertEqual(len(video), 3)
else:
self.assertEqual(len(video), 4)
# but the valid decoded content is still correct
self.assertTrue(video[:3].equal(data[:3]))
# and the last few frames are wrong
Expand Down
11 changes: 11 additions & 0 deletions test/test_io_opt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import unittest
from torchvision import set_video_backend
import test_io


set_video_backend('video_reader')


if __name__ == '__main__':
suite = unittest.TestLoader().loadTestsFromModule(test_io)
unittest.TextTestRunner(verbosity=1).run(suite)
2 changes: 1 addition & 1 deletion test/test_video_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from urllib.error import URLError


from torchvision.io._video_opt import _HAS_VIDEO_OPT
from torchvision.io import _HAS_VIDEO_OPT


VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
Expand Down
3 changes: 1 addition & 2 deletions torchvision/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from .video import write_video, read_video, read_video_timestamps
from .video import write_video, read_video, read_video_timestamps, _HAS_VIDEO_OPT
from ._video_opt import (
_read_video_from_file,
_read_video_timestamps_from_file,
_probe_video_from_file,
_read_video_from_memory,
_read_video_timestamps_from_memory,
_probe_video_from_memory,
_HAS_VIDEO_OPT,
)


Expand Down
76 changes: 64 additions & 12 deletions torchvision/io/_video_opt.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,10 @@
from fractions import Fraction
import math
import numpy as np
import os
import torch
import imp
import warnings


_HAS_VIDEO_OPT = False

try:
lib_dir = os.path.join(os.path.dirname(__file__), '..')
_, path, description = imp.find_module("video_reader", [lib_dir])
torch.ops.load_library(path)
_HAS_VIDEO_OPT = True
except (ImportError, OSError):
warnings.warn("video reader based on ffmpeg c++ ops not available")

default_timebase = Fraction(0, 1)


Expand Down Expand Up @@ -356,3 +345,66 @@ def _probe_video_from_memory(video_data):
vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
return info


def _read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'):
if end_pts is None:
end_pts = float("inf")

if pts_unit == 'pts':
warnings.warn("The pts_unit 'pts' gives wrong results and will be removed in a " +
"follow-up version. Please use pts_unit 'sec'.")

info = _probe_video_from_file(filename)

has_video = 'video_timebase' in info
has_audio = 'audio_timebase' in info

def get_pts(time_base):
start_offset = start_pts
end_offset = end_pts
if pts_unit == 'sec':
start_offset = int(math.floor(start_pts * (1 / time_base)))
if end_offset != float("inf"):
end_offset = int(math.ceil(end_pts * (1 / time_base)))
if end_offset == float("inf"):
end_offset = -1
return start_offset, end_offset

video_pts_range = (0, -1)
video_timebase = default_timebase
if has_video:
video_timebase = info['video_timebase']
video_pts_range = get_pts(video_timebase)

audio_pts_range = (0, -1)
audio_timebase = default_timebase
if has_audio:
audio_timebase = info['audio_timebase']
audio_pts_range = get_pts(audio_timebase)

return _read_video_from_file(
filename,
read_video_stream=True,
video_pts_range=video_pts_range,
video_timebase=video_timebase,
read_audio_stream=True,
audio_pts_range=audio_pts_range,
audio_timebase=audio_timebase,
)


def _read_video_timestamps(filename, pts_unit='pts'):
if pts_unit == 'pts':
warnings.warn("The pts_unit 'pts' gives wrong results and will be removed in a " +
"follow-up version. Please use pts_unit 'sec'.")

pts, _, info = _read_video_timestamps_from_file(filename)

if pts_unit == 'sec':
video_time_base = info['video_timebase']
pts = [x * video_time_base for x in pts]

video_fps = info.get('video_fps', None)

return pts, video_fps
25 changes: 25 additions & 0 deletions torchvision/io/video.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
import re
import imp
import gc
import os
import torch
import numpy as np
import math
import warnings

from . import _video_opt


_HAS_VIDEO_OPT = False

try:
lib_dir = os.path.join(os.path.dirname(__file__), '..')
_, path, description = imp.find_module("video_reader", [lib_dir])
torch.ops.load_library(path)
_HAS_VIDEO_OPT = True
except (ImportError, OSError):
pass


try:
import av
av.logging.set_level(av.logging.ERROR)
Expand Down Expand Up @@ -190,6 +206,11 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'):
metadata for the video and audio. Can contain the fields video_fps (float)
and audio_fps (int)
"""

from torchvision import get_video_backend
if get_video_backend() != "pyav":
return _video_opt._read_video(filename, start_pts, end_pts, pts_unit)

_check_av_available()

if end_pts is None:
Expand Down Expand Up @@ -273,6 +294,10 @@ def read_video_timestamps(filename, pts_unit='pts'):
the frame rate for the video

"""
from torchvision import get_video_backend
if get_video_backend() != "pyav":
return _video_opt._read_video_timestamps(filename, pts_unit)

_check_av_available()

video_frames = []
Expand Down