Skip to content

Commit

Permalink
Add test for fileobj
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed May 18, 2022
1 parent 9372591 commit 4cccfc8
Showing 1 changed file with 57 additions and 24 deletions.
81 changes: 57 additions & 24 deletions test/torchaudio_unittest/io/stream_reader_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from parameterized import parameterized
from parameterized import parameterized, parameterized_class
from torchaudio_unittest.common_utils import (
get_asset_path,
get_image,
Expand All @@ -22,12 +22,46 @@
)


def get_video_asset(file="nasa_13013.mp4"):
return get_asset_path(file)
################################################################################
# Helper decorator and Mixin to duplicate the tests for fileobj
_TEST_FILEOBJ = "src_is_fileobj"

def _class_name(cls, _, params):
return f'{cls.__name__}{"_fileobj" if params[_TEST_FILEOBJ] else ""}'


_media_source = parameterized_class(
(_TEST_FILEOBJ, ),
[(False, ), (True, )],
class_name_func=_class_name
)


class _MediaSourceMixin:
def setUp(self):
super().setUp()
self.src = None

def get_video_asset(self, file="nasa_13013.mp4"):
if self.src is not None:
raise ValueError("get_video_asset can be called only once.")

path = get_asset_path(file)
if getattr(self, _TEST_FILEOBJ):
self.src = open(path, "rb")
return self.src
return path

def tearDown(self):
if self.src is not None:
self.src.close()
super().tearDown()
################################################################################


@skipIfNoFFmpeg
class StreamReaderInterfaceTest(TempDirMixin, TorchaudioTestCase):
@_media_source
class StreamReaderInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase):
"""Test suite for interface behaviors around StreamReader"""

def test_streamer_invalid_input(self):
Expand All @@ -48,14 +82,13 @@ def test_streamer_invalid_input(self):
def test_streamer_invalide_option(self, invalid_keys, options):
"""When invalid options are given, StreamReader raises an exception with these keys"""
options.update({k: k for k in invalid_keys})
src = get_video_asset()
with self.assertRaises(RuntimeError) as ctx:
StreamReader(src, option=options)
StreamReader(self.get_video_asset(), option=options)
assert all(f'"{k}"' in str(ctx.exception) for k in invalid_keys)

def test_src_info(self):
"""`get_src_stream_info` properly fetches information"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
assert s.num_src_streams == 6

expected = [
Expand Down Expand Up @@ -117,30 +150,30 @@ def test_src_info(self):

def test_src_info_invalid_index(self):
"""`get_src_stream_info` does not segfault but raise an exception when input is invalid"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
for i in [-1, 6, 7, 8]:
with self.assertRaises(IndexError):
s.get_src_stream_info(i)

def test_default_streams(self):
"""default stream is not None"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
assert s.default_audio_stream is not None
assert s.default_video_stream is not None

def test_default_audio_stream_none(self):
"""default audio stream is None for video without audio"""
s = StreamReader(get_video_asset("nasa_13013_no_audio.mp4"))
s = StreamReader(self.get_video_asset("nasa_13013_no_audio.mp4"))
assert s.default_audio_stream is None

def test_default_video_stream_none(self):
"""default video stream is None for video with only audio"""
s = StreamReader(get_video_asset("nasa_13013_no_video.mp4"))
s = StreamReader(self.get_video_asset("nasa_13013_no_video.mp4"))
assert s.default_video_stream is None

def test_num_out_stream(self):
"""num_out_streams gives the correct count of output streams"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
n, m = 6, 4
for i in range(n):
assert s.num_out_streams == i
Expand All @@ -158,7 +191,7 @@ def test_num_out_stream(self):

def test_basic_audio_stream(self):
"""`add_basic_audio_stream` constructs a correct filter."""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
s.add_basic_audio_stream(frames_per_chunk=-1, dtype=None)
s.add_basic_audio_stream(frames_per_chunk=-1, sample_rate=8000)
s.add_basic_audio_stream(frames_per_chunk=-1, dtype=torch.int16)
Expand All @@ -177,7 +210,7 @@ def test_basic_audio_stream(self):

def test_basic_video_stream(self):
"""`add_basic_video_stream` constructs a correct filter."""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
s.add_basic_video_stream(frames_per_chunk=-1, format=None)
s.add_basic_video_stream(frames_per_chunk=-1, width=3, height=5)
s.add_basic_video_stream(frames_per_chunk=-1, frame_rate=7)
Expand All @@ -201,7 +234,7 @@ def test_basic_video_stream(self):

def test_remove_streams(self):
"""`remove_stream` removes the correct output stream"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
s.add_basic_audio_stream(frames_per_chunk=-1, sample_rate=24000)
s.add_basic_video_stream(frames_per_chunk=-1, width=16, height=16)
s.add_basic_audio_stream(frames_per_chunk=-1, sample_rate=8000)
Expand All @@ -221,7 +254,7 @@ def test_remove_streams(self):

def test_remove_stream_invalid(self):
"""Attempt to remove invalid output streams raises IndexError"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
for i in range(-3, 3):
with self.assertRaises(IndexError):
s.remove_stream(i)
Expand All @@ -235,7 +268,7 @@ def test_remove_stream_invalid(self):

def test_process_packet(self):
"""`process_packet` method returns 0 while there is a packet in source stream"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
# nasa_1013.mp3 contains 1023 packets.
for _ in range(1023):
code = s.process_packet()
Expand All @@ -246,19 +279,19 @@ def test_process_packet(self):

def test_pop_chunks_no_output_stream(self):
"""`pop_chunks` method returns empty list when there is no output stream"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
assert s.pop_chunks() == []

def test_pop_chunks_empty_buffer(self):
"""`pop_chunks` method returns None when a buffer is empty"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
s.add_basic_audio_stream(frames_per_chunk=-1)
s.add_basic_video_stream(frames_per_chunk=-1)
assert s.pop_chunks() == [None, None]

def test_pop_chunks_exhausted_stream(self):
"""`pop_chunks` method returns None when the source stream is exhausted"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
# video is 16.57 seconds.
# audio streams per 10 second chunk
# video streams per 20 second chunk
Expand All @@ -284,14 +317,14 @@ def test_pop_chunks_exhausted_stream(self):

def test_stream_empty(self):
"""`stream` fails when no output stream is configured"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
with self.assertRaises(RuntimeError):
next(s.stream())

def test_stream_smoke_test(self):
"""`stream` streams chunks fine"""
w, h = 256, 198
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
s.add_basic_audio_stream(frames_per_chunk=2000, sample_rate=8000)
s.add_basic_video_stream(frames_per_chunk=15, frame_rate=60, width=w, height=h)
for i, (achunk, vchunk) in enumerate(s.stream()):
Expand All @@ -302,7 +335,7 @@ def test_stream_smoke_test(self):

def test_seek(self):
"""Calling `seek` multiple times should not segfault"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
for i in range(10):
s.seek(i)
for _ in range(0):
Expand All @@ -312,7 +345,7 @@ def test_seek(self):

def test_seek_negative(self):
"""Calling `seek` with negative value should raise an exception"""
s = StreamReader(get_video_asset())
s = StreamReader(self.get_video_asset())
with self.assertRaises(ValueError):
s.seek(-1.0)

Expand Down

0 comments on commit 4cccfc8

Please sign in to comment.