Skip to content

Commit

Permalink
Introduce backend to StreamReader
Browse files Browse the repository at this point in the history
This is the preparation for the file-like object support in StreamReader.

File-like object support is bound via PyBind11, which is different from the
existing Torch-based binding, even though all the internal implementations
are the same.

PyBind11-based bindings are accessed like regular Python objects, as
`torchaudio._torchaudio_ffmpeg`, while Torch-based bindings are accessed like
`torch.ops.torchaudio_ffmpeg`.

The backend will serve as an abstraction layer over these different bindings.
  • Loading branch information
mthrok committed May 18, 2022
1 parent 08ead80 commit 9372591
Showing 1 changed file with 156 additions and 44 deletions.
200 changes: 156 additions & 44 deletions torchaudio/io/_stream_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,138 @@ def _get_vfilter_desc(
return ",".join(descs) if descs else None


class _TorchBindBackend:
    """``StreamReader`` backend that talks to the ``torch.ops.torchaudio`` bindings.

    The instance keeps the handle returned by ``ffmpeg_streamer_init`` in
    ``self._s`` and forwards every call to the matching
    ``torch.ops.torchaudio.ffmpeg_streamer_*`` op.

    Attributes:
        default_audio_stream (int or None): Index reported by
            ``ffmpeg_streamer_find_best_audio_stream``; ``None`` when the op
            reports a negative index.
        default_video_stream (int or None): Same, for
            ``ffmpeg_streamer_find_best_video_stream``.
    """

    def __init__(self, src: str, format: Optional[str], option: Optional[Dict[str, str]]):
        """Create the streamer handle and look up the default audio/video streams.

        Args:
            src (str): Forwarded to ``ffmpeg_streamer_init``.
            format (str or None): Forwarded to ``ffmpeg_streamer_init``.
            option (dict of str to str, or None): Forwarded to ``ffmpeg_streamer_init``.
        """
        self._s = torch.ops.torchaudio.ffmpeg_streamer_init(src, format, option)
        # The "find best" ops signal "not found" with a negative index;
        # expose that case as ``None``.
        best_audio = torch.ops.torchaudio.ffmpeg_streamer_find_best_audio_stream(self._s)
        self.default_audio_stream = None if best_audio < 0 else best_audio
        best_video = torch.ops.torchaudio.ffmpeg_streamer_find_best_video_stream(self._s)
        self.default_video_stream = None if best_video < 0 else best_video

    @staticmethod
    def _resolve_stream_index(stream_index: Optional[int], default: Optional[int], media: str) -> int:
        """Return ``stream_index``, falling back to ``default`` when it is ``None``.

        Args:
            stream_index (int or None): Index requested by the caller.
            default (int or None): Default stream index of the media type.
            media (str): ``"audio"`` or ``"video"``; only used in the error message.

        Raises:
            RuntimeError: If both ``stream_index`` and ``default`` are ``None``.
        """
        index = default if stream_index is None else stream_index
        if index is None:
            raise RuntimeError(f"There is no {media} stream.")
        return index

    @property
    def num_src_streams(self) -> int:
        """Result of ``ffmpeg_streamer_num_src_streams`` for this handle."""
        return torch.ops.torchaudio.ffmpeg_streamer_num_src_streams(self._s)

    @property
    def num_out_streams(self) -> int:
        """Result of ``ffmpeg_streamer_num_out_streams`` for this handle."""
        return torch.ops.torchaudio.ffmpeg_streamer_num_out_streams(self._s)

    def get_src_stream_info(self, i: int):
        """Return the raw output of ``ffmpeg_streamer_get_src_stream_info`` for stream ``i``.

        The value is returned as produced by the op, without any parsing.
        """
        return torch.ops.torchaudio.ffmpeg_streamer_get_src_stream_info(self._s, i)

    def get_out_stream_info(self, i: int):
        """Return the raw output of ``ffmpeg_streamer_get_out_stream_info`` for stream ``i``.

        The value is returned as produced by the op, without any parsing.
        """
        return torch.ops.torchaudio.ffmpeg_streamer_get_out_stream_info(self._s, i)

    def seek(self, timestamp: float):
        """Forward ``timestamp`` to ``ffmpeg_streamer_seek``."""
        torch.ops.torchaudio.ffmpeg_streamer_seek(self._s, timestamp)

    def add_basic_audio_stream(
        self,
        frames_per_chunk: int,
        buffer_chunk_size: int,
        stream_index: Optional[int],
        sample_rate: Optional[int],
        dtype: torch.dtype,
    ):
        """Add an audio output stream whose filter is built from ``sample_rate`` and ``dtype``.

        The filter description comes from ``_get_afilter_desc``; no decoder
        and no decoder options are specified.

        Raises:
            RuntimeError: If ``stream_index`` is ``None`` and there is no
                default audio stream.
        """
        # Resolve the index first so that a missing stream is reported before
        # the filter description is built.
        index = self._resolve_stream_index(stream_index, self.default_audio_stream, "audio")
        self.add_audio_stream(
            frames_per_chunk,
            buffer_chunk_size,
            index,
            _get_afilter_desc(sample_rate, dtype),
            None,
            None,
        )

    def add_basic_video_stream(
        self,
        frames_per_chunk: int,
        buffer_chunk_size: int,
        stream_index: Optional[int],
        frame_rate: Optional[int],
        width: Optional[int],
        height: Optional[int],
        format: str = "RGB",
    ):
        """Add a video output stream whose filter is built from the frame parameters.

        The filter description comes from ``_get_vfilter_desc``; no decoder,
        decoder options or hardware acceleration are specified.

        Raises:
            RuntimeError: If ``stream_index`` is ``None`` and there is no
                default video stream.
        """
        # Resolve the index first so that a missing stream is reported before
        # the filter description is built.
        index = self._resolve_stream_index(stream_index, self.default_video_stream, "video")
        self.add_video_stream(
            frames_per_chunk,
            buffer_chunk_size,
            index,
            _get_vfilter_desc(frame_rate, width, height, format),
            None,
            None,
            None,
        )

    def add_audio_stream(
        self,
        frames_per_chunk: int,
        buffer_chunk_size: int,
        stream_index: Optional[int],
        filter_desc: Optional[str],
        decoder: Optional[str],
        decoder_options: Optional[Dict[str, str]],
    ):
        """Add an audio output stream via ``ffmpeg_streamer_add_audio_stream``.

        All arguments other than ``stream_index`` are forwarded unchanged.

        Raises:
            RuntimeError: If ``stream_index`` is ``None`` and there is no
                default audio stream.
        """
        index = self._resolve_stream_index(stream_index, self.default_audio_stream, "audio")
        torch.ops.torchaudio.ffmpeg_streamer_add_audio_stream(
            self._s,
            index,
            frames_per_chunk,
            buffer_chunk_size,
            filter_desc,
            decoder,
            decoder_options,
        )

    def add_video_stream(
        self,
        frames_per_chunk: int,
        buffer_chunk_size: int,
        stream_index: Optional[int],
        filter_desc: Optional[str],
        decoder: Optional[str],
        decoder_options: Optional[Dict[str, str]],
        hw_accel: Optional[str],
    ):
        """Add a video output stream via ``ffmpeg_streamer_add_video_stream``.

        All arguments other than ``stream_index`` are forwarded unchanged.

        Raises:
            RuntimeError: If ``stream_index`` is ``None`` and there is no
                default video stream.
        """
        index = self._resolve_stream_index(stream_index, self.default_video_stream, "video")
        torch.ops.torchaudio.ffmpeg_streamer_add_video_stream(
            self._s,
            index,
            frames_per_chunk,
            buffer_chunk_size,
            filter_desc,
            decoder,
            decoder_options,
            hw_accel,
        )

    def remove_stream(self, i: int):
        """Forward ``i`` to ``ffmpeg_streamer_remove_stream``."""
        torch.ops.torchaudio.ffmpeg_streamer_remove_stream(self._s, i)

    def process_packet(self, timeout: Optional[float], backoff: float) -> int:
        """Return the result of ``ffmpeg_streamer_process_packet(timeout, backoff)``."""
        return torch.ops.torchaudio.ffmpeg_streamer_process_packet(self._s, timeout, backoff)

    def process_all_packets(self):
        """Call ``ffmpeg_streamer_process_all_packets`` on this handle."""
        torch.ops.torchaudio.ffmpeg_streamer_process_all_packets(self._s)

    def is_buffer_ready(self) -> bool:
        """Return the result of ``ffmpeg_streamer_is_buffer_ready``."""
        return torch.ops.torchaudio.ffmpeg_streamer_is_buffer_ready(self._s)

    def pop_chunks(self) -> Tuple[Optional[torch.Tensor], ...]:
        """Return the result of ``ffmpeg_streamer_pop_chunks``."""
        return torch.ops.torchaudio.ffmpeg_streamer_pop_chunks(self._s)


class StreamReader:
"""Fetch and decode audio/video streams chunk by chunk.
Expand Down Expand Up @@ -246,43 +378,39 @@ def __init__(
format: Optional[str] = None,
option: Optional[Dict[str, str]] = None,
):
self._s = torch.ops.torchaudio.ffmpeg_streamer_init(src, format, option)
i = torch.ops.torchaudio.ffmpeg_streamer_find_best_audio_stream(self._s)
self._i_audio = None if i < 0 else i
i = torch.ops.torchaudio.ffmpeg_streamer_find_best_video_stream(self._s)
self._i_video = None if i < 0 else i
self._be = _TorchBindBackend(src, format, option)

@property
def num_src_streams(self):
"""Number of streams found in the provided media source.
:type: int
"""
return torch.ops.torchaudio.ffmpeg_streamer_num_src_streams(self._s)
return self._be.num_src_streams

@property
def num_out_streams(self):
"""Number of output streams configured by client code.
:type: int
"""
return torch.ops.torchaudio.ffmpeg_streamer_num_out_streams(self._s)
return self._be.num_out_streams

@property
def default_audio_stream(self):
"""The index of default audio stream. ``None`` if there is no audio stream
:type: Optional[int]
"""
return self._i_audio
return self._be.default_audio_stream

@property
def default_video_stream(self):
"""The index of default video stream. ``None`` if there is no video stream
:type: Optional[int]
"""
return self._i_video
return self._be.default_video_stream

def get_src_stream_info(self, i: int) -> torchaudio.io.StreamReaderSourceStream:
"""Get the metadata of source stream
Expand All @@ -292,7 +420,7 @@ def get_src_stream_info(self, i: int) -> torchaudio.io.StreamReaderSourceStream:
Returns:
SourceStream
"""
return _parse_si(torch.ops.torchaudio.ffmpeg_streamer_get_src_stream_info(self._s, i))
return _parse_si(self._be.get_src_stream_info(i))

def get_out_stream_info(self, i: int) -> torchaudio.io.StreamReaderOutputStream:
"""Get the metadata of output stream
Expand All @@ -302,15 +430,15 @@ def get_out_stream_info(self, i: int) -> torchaudio.io.StreamReaderOutputStream:
Returns:
OutputStream
"""
return _parse_oi(torch.ops.torchaudio.ffmpeg_streamer_get_out_stream_info(self._s, i))
return _parse_oi(self._be.get_out_stream_info(i))

def seek(self, timestamp: float):
"""Seek the stream to the given timestamp [second]
Args:
timestamp (float): Target time in second.
"""
torch.ops.torchaudio.ffmpeg_streamer_seek(self._s, timestamp)
self._be.seek(timestamp)

def add_basic_audio_stream(
self,
Expand Down Expand Up @@ -341,16 +469,7 @@ def add_basic_audio_stream(
If floating point, then the sample value range is
`[-1, 1]`.
"""
i = self.default_audio_stream if stream_index is None else stream_index
torch.ops.torchaudio.ffmpeg_streamer_add_audio_stream(
self._s,
i,
frames_per_chunk,
buffer_chunk_size,
_get_afilter_desc(sample_rate, dtype),
None,
None,
)
self._be.add_basic_audio_stream(frames_per_chunk, buffer_chunk_size, stream_index, sample_rate, dtype)

def add_basic_video_stream(
self,
Expand Down Expand Up @@ -388,17 +507,14 @@ def add_basic_video_stream(
- `YUV`: 8 bits * 3 channels
- `GRAY`: 8 bits * 1 channels
"""
i = self.default_video_stream if stream_index is None else stream_index
torch.ops.torchaudio.ffmpeg_streamer_add_video_stream(
self._s,
i,
self._be.add_basic_video_stream(
frames_per_chunk,
buffer_chunk_size,
_get_vfilter_desc(frame_rate, width, height, format),
None,
None,
None,
)
stream_index,
frame_rate,
width,
height,
format)

def add_audio_stream(
self,
Expand Down Expand Up @@ -435,12 +551,10 @@ def add_audio_stream(
decoder_options (dict or None, optional): Options passed to decoder.
Mapping from str to str.
"""
i = self.default_audio_stream if stream_index is None else stream_index
torch.ops.torchaudio.ffmpeg_streamer_add_audio_stream(
self._s,
i,
self._be.add_audio_stream(
frames_per_chunk,
buffer_chunk_size,
stream_index,
filter_desc,
decoder,
decoder_options,
Expand Down Expand Up @@ -522,12 +636,10 @@ def add_video_stream(
>>> print(chunk.dtype)
... cuda:1
"""
i = self.default_video_stream if stream_index is None else stream_index
torch.ops.torchaudio.ffmpeg_streamer_add_video_stream(
self._s,
i,
self._be.add_video_stream(
frames_per_chunk,
buffer_chunk_size,
stream_index,
filter_desc,
decoder,
decoder_options,
Expand All @@ -540,7 +652,7 @@ def remove_stream(self, i: int):
Args:
i (int): Index of the output stream to be removed.
"""
torch.ops.torchaudio.ffmpeg_streamer_remove_stream(self._s, i)
self._be.remove_stream(i)

def process_packet(self, timeout: Optional[float] = None, backoff: float = 10.0) -> int:
"""Read the source media and process one packet.
Expand Down Expand Up @@ -599,15 +711,15 @@ def process_packet(self, timeout: Optional[float] = None, backoff: float = 10.0)
flushed the pending frames. The caller should stop calling
this method.
"""
return torch.ops.torchaudio.ffmpeg_streamer_process_packet(self._s, timeout, backoff)
return self._be.process_packet(timeout, backoff)

def process_all_packets(self):
"""Process packets until it reaches EOF."""
torch.ops.torchaudio.ffmpeg_streamer_process_all_packets(self._s)
self._be.process_all_packets()

def is_buffer_ready(self) -> bool:
"""Returns true if all the output streams have at least one chunk filled."""
return torch.ops.torchaudio.ffmpeg_streamer_is_buffer_ready(self._s)
return self._be.is_buffer_ready()

def pop_chunks(self) -> Tuple[Optional[torch.Tensor]]:
"""Pop one chunk from all the output stream buffers.
Expand All @@ -617,7 +729,7 @@ def pop_chunks(self) -> Tuple[Optional[torch.Tensor]]:
Buffer contents.
If a buffer does not contain any frame, then `None` is returned instead.
"""
return torch.ops.torchaudio.ffmpeg_streamer_pop_chunks(self._s)
return self._be.pop_chunks()

def _fill_buffer(self, timeout: Optional[float], backoff: float) -> int:
"""Keep processing packets until all buffers have at least one chunk
Expand Down

0 comments on commit 9372591

Please sign in to comment.