diff --git a/torchaudio/io/_stream_reader.py b/torchaudio/io/_stream_reader.py index 942492979eb..0ed18fc9b3b 100644 --- a/torchaudio/io/_stream_reader.py +++ b/torchaudio/io/_stream_reader.py @@ -199,6 +199,138 @@ def _get_vfilter_desc( return ",".join(descs) if descs else None +class _TorchBindBackend: + def __init__(self, src: str, format: Optional[str], option: Optional[Dict[str, str]]): + self._s = torch.ops.torchaudio.ffmpeg_streamer_init(src, format, option) + i = torch.ops.torchaudio.ffmpeg_streamer_find_best_audio_stream(self._s) + self.default_audio_stream = None if i < 0 else i + i = torch.ops.torchaudio.ffmpeg_streamer_find_best_video_stream(self._s) + self.default_video_stream = None if i < 0 else i + + @property + def num_src_streams(self) -> int: + return torch.ops.torchaudio.ffmpeg_streamer_num_src_streams(self._s) + + @property + def num_out_streams(self) -> int: + return torch.ops.torchaudio.ffmpeg_streamer_num_out_streams(self._s) + + def get_src_stream_info(self, i: int) -> torchaudio.io.StreamReaderSourceStream: + return torch.ops.torchaudio.ffmpeg_streamer_get_src_stream_info(self._s, i) + + def get_out_stream_info(self, i: int) -> torchaudio.io.StreamReaderOutputStream: + return torch.ops.torchaudio.ffmpeg_streamer_get_out_stream_info(self._s, i) + + def seek(self, timestamp: float): + torch.ops.torchaudio.ffmpeg_streamer_seek(self._s, timestamp) + + def add_basic_audio_stream( + self, + frames_per_chunk: int, + buffer_chunk_size: int, + stream_index: Optional[int], + sample_rate: Optional[int], + dtype: torch.dtype, + ): + i = self.default_audio_stream if stream_index is None else stream_index + if i is None: + raise RuntimeError("There is no audio stream.") + torch.ops.torchaudio.ffmpeg_streamer_add_audio_stream( + self._s, + i, + frames_per_chunk, + buffer_chunk_size, + _get_afilter_desc(sample_rate, dtype), + None, + None, + ) + + def add_basic_video_stream( + self, + frames_per_chunk: int, + buffer_chunk_size: int, + stream_index: Optional[int], + frame_rate: Optional[int], + width: Optional[int], + height: Optional[int], + format: str = "RGB", + ): + i = self.default_video_stream if stream_index is None else stream_index + if i is None: + raise RuntimeError("There is no video stream.") + torch.ops.torchaudio.ffmpeg_streamer_add_video_stream( + self._s, + i, + frames_per_chunk, + buffer_chunk_size, + _get_vfilter_desc(frame_rate, width, height, format), + None, + None, + None, + ) + + def add_audio_stream( + self, + frames_per_chunk: int, + buffer_chunk_size: int, + stream_index: Optional[int], + filter_desc: Optional[str], + decoder: Optional[str], + decoder_options: Optional[Dict[str, str]], + ): + i = self.default_audio_stream if stream_index is None else stream_index + if i is None: + raise RuntimeError("There is no audio stream.") + torch.ops.torchaudio.ffmpeg_streamer_add_audio_stream( + self._s, + i, + frames_per_chunk, + buffer_chunk_size, + filter_desc, + decoder, + decoder_options, + ) + + def add_video_stream( + self, + frames_per_chunk: int, + buffer_chunk_size: int, + stream_index: Optional[int], + filter_desc: Optional[str], + decoder: Optional[str], + decoder_options: Optional[Dict[str, str]], + hw_accel: Optional[str], + ): + i = self.default_video_stream if stream_index is None else stream_index + if i is None: + raise RuntimeError("There is no video stream.") + torch.ops.torchaudio.ffmpeg_streamer_add_video_stream( + self._s, + i, + frames_per_chunk, + buffer_chunk_size, + filter_desc, + decoder, + decoder_options, + hw_accel, + ) + + def remove_stream(self, i: int): + torch.ops.torchaudio.ffmpeg_streamer_remove_stream(self._s, i) + + def process_packet(self, timeout: Optional[float], backoff: float): + return torch.ops.torchaudio.ffmpeg_streamer_process_packet(self._s, timeout, backoff) + + def process_all_packets(self): + torch.ops.torchaudio.ffmpeg_streamer_process_all_packets(self._s) + + def is_buffer_ready(self) -> bool: + return torch.ops.torchaudio.ffmpeg_streamer_is_buffer_ready(self._s) + + def pop_chunks(self) -> Tuple[Optional[torch.Tensor]]: + return torch.ops.torchaudio.ffmpeg_streamer_pop_chunks(self._s) + + class StreamReader: """Fetch and decode audio/video streams chunk by chunk. @@ -246,11 +378,7 @@ def __init__( format: Optional[str] = None, option: Optional[Dict[str, str]] = None, ): - self._s = torch.ops.torchaudio.ffmpeg_streamer_init(src, format, option) - i = torch.ops.torchaudio.ffmpeg_streamer_find_best_audio_stream(self._s) - self._i_audio = None if i < 0 else i - i = torch.ops.torchaudio.ffmpeg_streamer_find_best_video_stream(self._s) - self._i_video = None if i < 0 else i + self._be = _TorchBindBackend(src, format, option) @property def num_src_streams(self): @@ -258,7 +386,7 @@ def num_src_streams(self): :type: int """ - return torch.ops.torchaudio.ffmpeg_streamer_num_src_streams(self._s) + return self._be.num_src_streams @property def num_out_streams(self): @@ -266,7 +394,7 @@ def num_out_streams(self): :type: int """ - return torch.ops.torchaudio.ffmpeg_streamer_num_out_streams(self._s) + return self._be.num_out_streams @property def default_audio_stream(self): @@ -274,7 +402,7 @@ def default_audio_stream(self): :type: Optional[int] """ - return self._i_audio + return self._be.default_audio_stream @property def default_video_stream(self): @@ -282,7 +410,7 @@ def default_video_stream(self): :type: Optional[int] """ - return self._i_video + return self._be.default_video_stream def get_src_stream_info(self, i: int) -> torchaudio.io.StreamReaderSourceStream: """Get the metadata of source stream @@ -292,7 +420,7 @@ def get_src_stream_info(self, i: int) -> torchaudio.io.StreamReaderSourceStream: Returns: SourceStream """ - return _parse_si(torch.ops.torchaudio.ffmpeg_streamer_get_src_stream_info(self._s, i)) + return _parse_si(self._be.get_src_stream_info(i)) def get_out_stream_info(self, i: int) -> torchaudio.io.StreamReaderOutputStream: """Get the metadata of output stream @@ -302,7 +430,7 @@ def get_out_stream_info(self, i: int) -> torchaudio.io.StreamReaderOutputStream: Returns: OutputStream """ - return _parse_oi(torch.ops.torchaudio.ffmpeg_streamer_get_out_stream_info(self._s, i)) + return _parse_oi(self._be.get_out_stream_info(i)) def seek(self, timestamp: float): """Seek the stream to the given timestamp [second] @@ -310,7 +438,7 @@ def seek(self, timestamp: float): Args: timestamp (float): Target time in second. """ - torch.ops.torchaudio.ffmpeg_streamer_seek(self._s, timestamp) + self._be.seek(timestamp) def add_basic_audio_stream( self, @@ -341,16 +469,7 @@ def add_basic_audio_stream( If floating point, then the sample value range is `[-1, 1]`. """ - i = self.default_audio_stream if stream_index is None else stream_index - torch.ops.torchaudio.ffmpeg_streamer_add_audio_stream( - self._s, - i, - frames_per_chunk, - buffer_chunk_size, - _get_afilter_desc(sample_rate, dtype), - None, - None, - ) + self._be.add_basic_audio_stream(frames_per_chunk, buffer_chunk_size, stream_index, sample_rate, dtype) def add_basic_video_stream( self, @@ -388,17 +507,14 @@ def add_basic_video_stream( - `YUV`: 8 bits * 3 channels - `GRAY`: 8 bits * 1 channels """ - i = self.default_video_stream if stream_index is None else stream_index - torch.ops.torchaudio.ffmpeg_streamer_add_video_stream( - self._s, - i, + self._be.add_basic_video_stream( frames_per_chunk, buffer_chunk_size, - _get_vfilter_desc(frame_rate, width, height, format), - None, - None, - None, - ) + stream_index, + frame_rate, + width, + height, + format) def add_audio_stream( self, @@ -435,12 +551,10 @@ def add_audio_stream( decoder_options (dict or None, optional): Options passed to decoder. Mapping from str to str. """ - i = self.default_audio_stream if stream_index is None else stream_index - torch.ops.torchaudio.ffmpeg_streamer_add_audio_stream( - self._s, - i, + self._be.add_audio_stream( frames_per_chunk, buffer_chunk_size, + stream_index, filter_desc, decoder, decoder_options, @@ -522,12 +636,10 @@ def add_video_stream( >>> print(chunk.dtype) ... cuda:1 """ - i = self.default_video_stream if stream_index is None else stream_index - torch.ops.torchaudio.ffmpeg_streamer_add_video_stream( - self._s, - i, + self._be.add_video_stream( frames_per_chunk, buffer_chunk_size, + stream_index, filter_desc, decoder, decoder_options, @@ -540,7 +652,7 @@ def remove_stream(self, i: int): Args: i (int): Index of the output stream to be removed. """ - torch.ops.torchaudio.ffmpeg_streamer_remove_stream(self._s, i) + self._be.remove_stream(i) def process_packet(self, timeout: Optional[float] = None, backoff: float = 10.0) -> int: """Read the source media and process one packet. @@ -599,15 +711,15 @@ def process_packet(self, timeout: Optional[float] = None, backoff: float = 10.0) flushed the pending frames. The caller should stop calling this method. """ - return torch.ops.torchaudio.ffmpeg_streamer_process_packet(self._s, timeout, backoff) + return self._be.process_packet(timeout, backoff) def process_all_packets(self): """Process packets until it reaches EOF.""" - torch.ops.torchaudio.ffmpeg_streamer_process_all_packets(self._s) + self._be.process_all_packets() def is_buffer_ready(self) -> bool: """Returns true if all the output streams have at least one chunk filled.""" - return torch.ops.torchaudio.ffmpeg_streamer_is_buffer_ready(self._s) + return self._be.is_buffer_ready() def pop_chunks(self) -> Tuple[Optional[torch.Tensor]]: """Pop one chunk from all the output stream buffers. @@ -617,7 +729,7 @@ def pop_chunks(self) -> Tuple[Optional[torch.Tensor]]: Buffer contents. If a buffer does not contain any frame, then `None` is returned instead. """ - return torch.ops.torchaudio.ffmpeg_streamer_pop_chunks(self._s) + return self._be.pop_chunks() def _fill_buffer(self, timeout: Optional[float], backoff: float) -> int: """Keep processing packets until all buffers have at least one chunk