From 578fdee3b66bafb2da7336f2e83b58409bfaa186 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Tue, 7 Jun 2022 14:47:06 -0700 Subject: [PATCH] Add metadata to source stream info (#2461) Summary: Add metadata, such as ID3 (https://github.com/pytorch/audio/commit/da3ffe9bf146c89585ed32aaf3be7f81bea2c4ab)tag to `StreamReaderSourceAudioStream`. Pull Request resolved: https://github.com/pytorch/audio/pull/2461 Differential Revision: D36985656 Pulled By: mthrok fbshipit-source-id: 8a8fc6fda5c39c84f5a2fe10237f0eb793c968a6 --- .../io/stream_reader_test.py | 34 +++++++++++++++++++ torchaudio/csrc/ffmpeg/pybind/pybind.cpp | 4 ++- torchaudio/csrc/ffmpeg/stream_reader.cpp | 14 ++++++++ .../csrc/ffmpeg/stream_reader_wrapper.cpp | 33 ++++++++++++++++++ .../csrc/ffmpeg/stream_reader_wrapper.h | 28 +++++++++++++++ torchaudio/csrc/ffmpeg/typedefs.h | 1 + torchaudio/io/_compat.py | 4 +-- torchaudio/io/_stream_reader.py | 18 +++++++--- 8 files changed, 128 insertions(+), 8 deletions(-) diff --git a/test/torchaudio_unittest/io/stream_reader_test.py b/test/torchaudio_unittest/io/stream_reader_test.py index 85e87a6e0a8..8d9fed1d1f2 100644 --- a/test/torchaudio_unittest/io/stream_reader_test.py +++ b/test/torchaudio_unittest/io/stream_reader_test.py @@ -89,6 +89,12 @@ def test_src_info(self): s = StreamReader(self.get_src()) assert s.num_src_streams == 6 + metadata = { + "compatible_brands": "isomiso2avc1mp41", + "encoder": "Lavf58.76.100", + "major_brand": "isom", + "minor_version": "512", + } expected = [ StreamReaderSourceVideoStream( media_type="video", @@ -98,6 +104,7 @@ def test_src_info(self): bit_rate=71925, num_frames=325, bits_per_sample=8, + metadata=metadata, width=320, height=180, frame_rate=25.0, @@ -110,6 +117,7 @@ def test_src_info(self): bit_rate=72093, num_frames=103, bits_per_sample=0, + metadata=metadata, sample_rate=8000.0, num_channels=2, ), @@ -121,6 +129,7 @@ def test_src_info(self): bit_rate=None, num_frames=None, bits_per_sample=None, + metadata=metadata, ), StreamReaderSourceVideoStream( media_type="video", @@ -130,6 +139,7 @@ def test_src_info(self): bit_rate=128783, num_frames=390, bits_per_sample=8, + metadata=metadata, width=480, height=270, frame_rate=29.97002997002997, @@ -142,6 +152,7 @@ def test_src_info(self): bit_rate=128837, num_frames=205, bits_per_sample=0, + metadata=metadata, sample_rate=16000.0, num_channels=2, ), @@ -153,11 +164,34 @@ def test_src_info(self): bit_rate=None, num_frames=None, bits_per_sample=None, + metadata=metadata, ), ] output = [s.get_src_stream_info(i) for i in range(6)] assert expected == output + def test_id3tag(self): + s = StreamReader(self.get_src("steam-train-whistle-daniel_simon.mp3")) + output = s.get_src_stream_info(s.default_audio_stream) + + expected = StreamReaderSourceAudioStream( + media_type="audio", + codec="mp3", + codec_long_name="MP3 (MPEG audio layer 3)", + format="fltp", + bit_rate=210571, + num_frames=0, + bits_per_sample=0, + metadata={ + "title": "SoundBible.com Must Credit", + "artist": "SoundBible.com Must Credit", + "date": "2017", + }, + sample_rate=44100.0, + num_channels=2, + ) + assert output == expected + def test_src_info_invalid_index(self): """`get_src_stream_info` does not segfault but raise an exception when input is invalid""" s = StreamReader(self.get_src()) diff --git a/torchaudio/csrc/ffmpeg/pybind/pybind.cpp b/torchaudio/csrc/ffmpeg/pybind/pybind.cpp index 46e633262c1..bdcb1196497 100644 --- a/torchaudio/csrc/ffmpeg/pybind/pybind.cpp +++ b/torchaudio/csrc/ffmpeg/pybind/pybind.cpp @@ -22,7 +22,9 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) { .def( "find_best_video_stream", &StreamReaderFileObj::find_best_video_stream) - .def("get_src_stream_info", &StreamReaderFileObj::get_src_stream_info) + .def( + "get_src_stream_info", + &StreamReaderFileObj::get_src_stream_info_pybind) .def("get_out_stream_info", &StreamReaderFileObj::get_out_stream_info) .def("seek", &StreamReaderFileObj::seek) .def("add_audio_stream", &StreamReaderFileObj::add_audio_stream) diff --git a/torchaudio/csrc/ffmpeg/stream_reader.cpp b/torchaudio/csrc/ffmpeg/stream_reader.cpp index fa47b5cbf07..33da16c2ffc 100644 --- a/torchaudio/csrc/ffmpeg/stream_reader.cpp +++ b/torchaudio/csrc/ffmpeg/stream_reader.cpp @@ -71,6 +71,18 @@ int64_t StreamReader::num_src_streams() const { return pFormatContext->nb_streams; } +namespace { +c10::Dict parse_metadata( + const AVDictionary* metadata) { + AVDictionaryEntry* tag = nullptr; + c10::Dict ret; + while ((tag = av_dict_get(metadata, "", tag, AV_DICT_IGNORE_SUFFIX))) { + ret.insert(std::string(tag->key), std::string(tag->value)); + } + return ret; +} +} // namespace + SrcStreamInfo StreamReader::get_src_stream_info(int i) const { validate_src_stream_index(i); AVStream* stream = pFormatContext->streams[i]; @@ -81,11 +93,13 @@ SrcStreamInfo StreamReader::get_src_stream_info(int i) const { ret.bit_rate = codecpar->bit_rate; ret.num_frames = stream->nb_frames; ret.bits_per_sample = codecpar->bits_per_raw_sample; + ret.metadata = parse_metadata(pFormatContext->metadata); const AVCodecDescriptor* desc = avcodec_descriptor_get(codecpar->codec_id); if (desc) { ret.codec_name = desc->name; ret.codec_long_name = desc->long_name; } + switch (codecpar->codec_type) { case AVMEDIA_TYPE_AUDIO: { AVSampleFormat smp_fmt = static_cast(codecpar->format); diff --git a/torchaudio/csrc/ffmpeg/stream_reader_wrapper.cpp b/torchaudio/csrc/ffmpeg/stream_reader_wrapper.cpp index 4675098f3bb..ba168dc5c01 100644 --- a/torchaudio/csrc/ffmpeg/stream_reader_wrapper.cpp +++ b/torchaudio/csrc/ffmpeg/stream_reader_wrapper.cpp @@ -4,6 +4,17 @@ namespace torchaudio { namespace ffmpeg { namespace { +// TODO: +// merge the implementation with the one from stream_reader_binding.cpp +std::map convert_map( + const c10::Dict& src) { + std::map ret; + for (const auto& it : src) { + ret.insert({it.key(), it.value()}); + } + return ret; +} + SrcInfo convert(SrcStreamInfo ssi) { return SrcInfo(std::forward_as_tuple( av_get_media_type_string(ssi.media_type), @@ -13,6 +24,24 @@ SrcInfo convert(SrcStreamInfo ssi) { ssi.bit_rate, ssi.num_frames, ssi.bits_per_sample, + ssi.metadata, + ssi.sample_rate, + ssi.num_channels, + ssi.width, + ssi.height, + ssi.frame_rate)); +} + +SrcInfoPyBind convert_pybind(SrcStreamInfo ssi) { + return SrcInfoPyBind(std::forward_as_tuple( + av_get_media_type_string(ssi.media_type), + ssi.codec_name, + ssi.codec_long_name, + ssi.fmt_name, + ssi.bit_rate, + ssi.num_frames, + ssi.bits_per_sample, + convert_map(ssi.metadata), ssi.sample_rate, ssi.num_channels, ssi.width, @@ -33,6 +62,10 @@ SrcInfo StreamReaderBinding::get_src_stream_info(int64_t i) { return convert(StreamReader::get_src_stream_info(static_cast(i))); } +SrcInfoPyBind StreamReaderBinding::get_src_stream_info_pybind(int64_t i) { + return convert_pybind(StreamReader::get_src_stream_info(static_cast(i))); +} + OutInfo StreamReaderBinding::get_out_stream_info(int64_t i) { return convert(StreamReader::get_out_stream_info(static_cast(i))); } diff --git a/torchaudio/csrc/ffmpeg/stream_reader_wrapper.h b/torchaudio/csrc/ffmpeg/stream_reader_wrapper.h index fc4e3acce4c..8ef67ec2bd3 100644 --- a/torchaudio/csrc/ffmpeg/stream_reader_wrapper.h +++ b/torchaudio/csrc/ffmpeg/stream_reader_wrapper.h @@ -5,6 +5,14 @@ namespace torchaudio { namespace ffmpeg { +// Because TorchScript requires c10::Dict type to pass dict, +// while PyBind11 requires std::map type to pass dict, +// we duplicate the return tuple. +// Even though all the PyBind-based implementations are placed +// in `pybind` directory, because std::map does not require pybind11 +// header, we define both of them here, for the sake of +// better locality/maintainability. + using SrcInfo = std::tuple< std::string, // media_type std::string, // codec name @@ -13,6 +21,25 @@ using SrcInfo = std::tuple< int64_t, // bit_rate int64_t, // num_frames int64_t, // bits_per_sample + c10::Dict, // metadata + // Audio + double, // sample_rate + int64_t, // num_channels + // Video + int64_t, // width + int64_t, // height + double // frame_rate + >; + +using SrcInfoPyBind = std::tuple< + std::string, // media_type + std::string, // codec name + std::string, // codec long name + std::string, // format name + int64_t, // bit_rate + int64_t, // num_frames + int64_t, // bits_per_sample + std::map, // metadata // Audio double, // sample_rate int64_t, // num_channels @@ -33,6 +60,7 @@ struct StreamReaderBinding : public StreamReader, public torch::CustomClassHolder { explicit StreamReaderBinding(AVFormatContextPtr&& p); SrcInfo get_src_stream_info(int64_t i); + SrcInfoPyBind get_src_stream_info_pybind(int64_t i); OutInfo get_out_stream_info(int64_t i); int64_t process_packet( diff --git a/torchaudio/csrc/ffmpeg/typedefs.h b/torchaudio/csrc/ffmpeg/typedefs.h index 4ac330eda76..a77d4a68945 100644 --- a/torchaudio/csrc/ffmpeg/typedefs.h +++ b/torchaudio/csrc/ffmpeg/typedefs.h @@ -14,6 +14,7 @@ struct SrcStreamInfo { int64_t bit_rate = 0; int64_t num_frames = 0; int bits_per_sample = 0; + c10::Dict metadata{}; // Audio double sample_rate = 0; int num_channels = 0; diff --git a/torchaudio/io/_compat.py b/torchaudio/io/_compat.py index c97d51ef3ff..78eaf97f9e3 100644 --- a/torchaudio/io/_compat.py +++ b/torchaudio/io/_compat.py @@ -12,9 +12,9 @@ def _info_audio( i = s.find_best_audio_stream() sinfo = s.get_src_stream_info(i) return AudioMetaData( - int(sinfo[7]), + int(sinfo[8]), sinfo[5], - sinfo[8], + sinfo[9], sinfo[6], sinfo[1].upper(), ) diff --git a/torchaudio/io/_stream_reader.py b/torchaudio/io/_stream_reader.py index 7e28a94d3b2..b2f65f82bdb 100644 --- a/torchaudio/io/_stream_reader.py +++ b/torchaudio/io/_stream_reader.py @@ -61,6 +61,9 @@ class StreamReaderSourceStream: """This is the number of valid bits in each output sample. For compressed format, it can be 0. """ + metadata: Dict[str, str] + """Metadata attached to the source media. + Note that metadata is common across the source streams.""" @dataclass @@ -108,13 +111,14 @@ class StreamReaderSourceVideoStream(StreamReaderSourceStream): _BIT_RATE = 4 _NUM_FRAMES = 5 _BPS = 6 +_METADATA = 7 # - AUDIO -_SAMPLE_RATE = 7 -_NUM_CHANNELS = 8 +_SAMPLE_RATE = 8 +_NUM_CHANNELS = 9 # - VIDEO -_WIDTH = 9 -_HEIGHT = 10 -_FRAME_RATE = 11 +_WIDTH = 10 +_HEIGHT = 11 +_FRAME_RATE = 12 def _parse_si(i): @@ -125,6 +129,7 @@ def _parse_si(i): bit_rate = i[_BIT_RATE] num_frames = i[_NUM_FRAMES] bps = i[_BPS] + metadata = i[_METADATA] if media_type == "audio": return StreamReaderSourceAudioStream( media_type=media_type, @@ -134,6 +139,7 @@ def _parse_si(i): bit_rate=bit_rate, num_frames=num_frames, bits_per_sample=bps, + metadata=metadata, sample_rate=i[_SAMPLE_RATE], num_channels=i[_NUM_CHANNELS], ) @@ -146,6 +152,7 @@ def _parse_si(i): bit_rate=bit_rate, num_frames=num_frames, bits_per_sample=bps, + metadata=metadata, width=i[_WIDTH], height=i[_HEIGHT], frame_rate=i[_FRAME_RATE], @@ -158,6 +165,7 @@ def _parse_si(i): bit_rate=None, num_frames=None, bits_per_sample=None, + metadata=metadata, )