Skip to content

Commit

Permalink
Add metadata to source stream info
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jun 7, 2022
1 parent 4c19e2c commit 09cfdac
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 8 deletions.
34 changes: 34 additions & 0 deletions test/torchaudio_unittest/io/stream_reader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ def test_src_info(self):
s = StreamReader(self.get_src())
assert s.num_src_streams == 6

metadata = {
'compatible_brands': 'isomiso2avc1mp41',
'encoder': 'Lavf58.76.100',
'major_brand': 'isom',
'minor_version': '512',
}
expected = [
StreamReaderSourceVideoStream(
media_type="video",
Expand All @@ -98,6 +104,7 @@ def test_src_info(self):
bit_rate=71925,
num_frames=325,
bits_per_sample=8,
metadata=metadata,
width=320,
height=180,
frame_rate=25.0,
Expand All @@ -110,6 +117,7 @@ def test_src_info(self):
bit_rate=72093,
num_frames=103,
bits_per_sample=0,
metadata=metadata,
sample_rate=8000.0,
num_channels=2,
),
Expand All @@ -121,6 +129,7 @@ def test_src_info(self):
bit_rate=None,
num_frames=None,
bits_per_sample=None,
metadata=metadata,
),
StreamReaderSourceVideoStream(
media_type="video",
Expand All @@ -130,6 +139,7 @@ def test_src_info(self):
bit_rate=128783,
num_frames=390,
bits_per_sample=8,
metadata=metadata,
width=480,
height=270,
frame_rate=29.97002997002997,
Expand All @@ -142,6 +152,7 @@ def test_src_info(self):
bit_rate=128837,
num_frames=205,
bits_per_sample=0,
metadata=metadata,
sample_rate=16000.0,
num_channels=2,
),
Expand All @@ -153,11 +164,34 @@ def test_src_info(self):
bit_rate=None,
num_frames=None,
bits_per_sample=None,
metadata=metadata,
),
]
output = [s.get_src_stream_info(i) for i in range(6)]
assert expected == output

# Checks that the tag entries of an MP3 source (title / artist / date) are
# reported through the `metadata` field of the source stream info.
def test_id3tag(self):
s = StreamReader(self.get_src("steam-train-whistle-daniel_simon.mp3"))
# Query the source info of the stream reported as the default audio stream.
output = s.get_src_stream_info(s.default_audio_stream)

# Expected stream description; `metadata` is a plain str -> str dict.
expected = StreamReaderSourceAudioStream(
media_type="audio",
codec="mp3",
codec_long_name="MP3 (MPEG audio layer 3)",
format="fltp",
bit_rate=210571,
num_frames=0,
bits_per_sample=0,
metadata={
"title": "SoundBible.com Must Credit",
"artist": "SoundBible.com Must Credit",
"date": "2017",
},
sample_rate=44100.0,
num_channels=2,
)
# The comparison covers every field, including the metadata dict.
assert output == expected

def test_src_info_invalid_index(self):
"""`get_src_stream_info` does not segfault but raise an exception when input is invalid"""
s = StreamReader(self.get_src())
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/csrc/ffmpeg/pybind/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
.def(
"find_best_video_stream",
&StreamReaderFileObj::find_best_video_stream)
.def("get_src_stream_info", &StreamReaderFileObj::get_src_stream_info)
.def("get_src_stream_info", &StreamReaderFileObj::get_src_stream_info_pybind)
.def("get_out_stream_info", &StreamReaderFileObj::get_out_stream_info)
.def("seek", &StreamReaderFileObj::seek)
.def("add_audio_stream", &StreamReaderFileObj::add_audio_stream)
Expand Down
13 changes: 13 additions & 0 deletions torchaudio/csrc/ffmpeg/stream_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,17 @@ int64_t StreamReader::num_src_streams() const {
return pFormatContext->nb_streams;
}

namespace {
// Copies every key/value entry of an FFmpeg dictionary into a
// c10::Dict<std::string, std::string> and returns it.
c10::Dict<std::string, std::string> parse_metadata(
    const AVDictionary* metadata) {
  c10::Dict<std::string, std::string> parsed;
  // Per FFmpeg's av_dict_get documentation, an empty key together with
  // AV_DICT_IGNORE_SUFFIX iterates over all entries; handing the previous
  // entry back in continues the scan from that position.
  for (AVDictionaryEntry* entry =
           av_dict_get(metadata, "", nullptr, AV_DICT_IGNORE_SUFFIX);
       entry != nullptr;
       entry = av_dict_get(metadata, "", entry, AV_DICT_IGNORE_SUFFIX)) {
    parsed.insert(std::string(entry->key), std::string(entry->value));
  }
  return parsed;
}
} // namespace

SrcStreamInfo StreamReader::get_src_stream_info(int i) const {
validate_src_stream_index(i);
AVStream* stream = pFormatContext->streams[i];
Expand All @@ -81,11 +92,13 @@ SrcStreamInfo StreamReader::get_src_stream_info(int i) const {
ret.bit_rate = codecpar->bit_rate;
ret.num_frames = stream->nb_frames;
ret.bits_per_sample = codecpar->bits_per_raw_sample;
ret.metadata = parse_metadata(pFormatContext->metadata);
const AVCodecDescriptor* desc = avcodec_descriptor_get(codecpar->codec_id);
if (desc) {
ret.codec_name = desc->name;
ret.codec_long_name = desc->long_name;
}

switch (codecpar->codec_type) {
case AVMEDIA_TYPE_AUDIO: {
AVSampleFormat smp_fmt = static_cast<AVSampleFormat>(codecpar->format);
Expand Down
32 changes: 32 additions & 0 deletions torchaudio/csrc/ffmpeg/stream_reader_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@ namespace torchaudio {
namespace ffmpeg {
namespace {

// Copies the entries of a c10::Dict into a std::map with the same
// key/value types.
// TODO:
// merge the implementation with the one from stream_reader_binding.cpp
std::map<std::string, std::string> convert_map(
    const c10::Dict<std::string, std::string>& src) {
  std::map<std::string, std::string> converted;
  for (const auto& entry : src) {
    converted.emplace(entry.key(), entry.value());
  }
  return converted;
}

SrcInfo convert(SrcStreamInfo ssi) {
return SrcInfo(std::forward_as_tuple(
av_get_media_type_string(ssi.media_type),
Expand All @@ -13,6 +23,24 @@ SrcInfo convert(SrcStreamInfo ssi) {
ssi.bit_rate,
ssi.num_frames,
ssi.bits_per_sample,
ssi.metadata,
ssi.sample_rate,
ssi.num_channels,
ssi.width,
ssi.height,
ssi.frame_rate));
}

SrcInfoPyBind convert_pybind(SrcStreamInfo ssi) {
return SrcInfoPyBind(std::forward_as_tuple(
av_get_media_type_string(ssi.media_type),
ssi.codec_name,
ssi.codec_long_name,
ssi.fmt_name,
ssi.bit_rate,
ssi.num_frames,
ssi.bits_per_sample,
convert_map(ssi.metadata),
ssi.sample_rate,
ssi.num_channels,
ssi.width,
Expand All @@ -33,6 +61,10 @@ SrcInfo StreamReaderBinding::get_src_stream_info(int64_t i) {
return convert(StreamReader::get_src_stream_info(static_cast<int>(i)));
}

// Returns the source stream info of stream `i` as the PyBind-facing tuple
// (SrcInfoPyBind), produced by convert_pybind.
SrcInfoPyBind StreamReaderBinding::get_src_stream_info_pybind(int64_t i) {
  // The base-class API takes a plain int index.
  const auto index = static_cast<int>(i);
  return convert_pybind(StreamReader::get_src_stream_info(index));
}

// Returns the output stream info of stream `i`, converted to the OutInfo
// representation used by the binding.
OutInfo StreamReaderBinding::get_out_stream_info(int64_t i) {
  // The base-class API takes a plain int index.
  const auto index = static_cast<int>(i);
  return convert(StreamReader::get_out_stream_info(index));
}
Expand Down
28 changes: 28 additions & 0 deletions torchaudio/csrc/ffmpeg/stream_reader_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@
namespace torchaudio {
namespace ffmpeg {

// TorchScript requires the c10::Dict type to pass a dict, while PyBind11
// requires the std::map type, so the return tuple is defined twice,
// once for each binding.
// Although all the other PyBind-based implementations are placed in the
// `pybind` directory, std::map does not require the pybind11 header,
// so both tuples are defined here for the sake of better
// locality/maintainability.

using SrcInfo = std::tuple<
std::string, // media_type
std::string, // codec name
Expand All @@ -13,6 +21,25 @@ using SrcInfo = std::tuple<
int64_t, // bit_rate
int64_t, // num_frames
int64_t, // bits_per_sample
c10::Dict<std::string, std::string>, // metadata
// Audio
double, // sample_rate
int64_t, // num_channels
// Video
int64_t, // width
int64_t, // height
double // frame_rate
>;

using SrcInfoPyBind = std::tuple<
std::string, // media_type
std::string, // codec name
std::string, // codec long name
std::string, // format name
int64_t, // bit_rate
int64_t, // num_frames
int64_t, // bits_per_sample
std::map<std::string, std::string>, // metadata
// Audio
double, // sample_rate
int64_t, // num_channels
Expand All @@ -33,6 +60,7 @@ struct StreamReaderBinding : public StreamReader,
public torch::CustomClassHolder {
explicit StreamReaderBinding(AVFormatContextPtr&& p);
SrcInfo get_src_stream_info(int64_t i);
SrcInfoPyBind get_src_stream_info_pybind(int64_t i);
OutInfo get_out_stream_info(int64_t i);

int64_t process_packet(
Expand Down
1 change: 1 addition & 0 deletions torchaudio/csrc/ffmpeg/typedefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ struct SrcStreamInfo {
int64_t bit_rate = 0;
int64_t num_frames = 0;
int bits_per_sample = 0;
c10::Dict<std::string, std::string> metadata{};
// Audio
double sample_rate = 0;
int num_channels = 0;
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/io/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ def _info_audio(
i = s.find_best_audio_stream()
sinfo = s.get_src_stream_info(i)
return AudioMetaData(
int(sinfo[7]),
int(sinfo[8]),
sinfo[5],
sinfo[8],
sinfo[9],
sinfo[6],
sinfo[1].upper(),
)
Expand Down
18 changes: 13 additions & 5 deletions torchaudio/io/_stream_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ class StreamReaderSourceStream:
"""This is the number of valid bits in each output sample.
For compressed format, it can be 0.
"""
metadata: Dict[str, str]
"""Metadata attached to the source media.
Note that metadata is common across the source streams."""


@dataclass
Expand Down Expand Up @@ -108,13 +111,14 @@ class StreamReaderSourceVideoStream(StreamReaderSourceStream):
_BIT_RATE = 4
_NUM_FRAMES = 5
_BPS = 6
_METADATA = 7
# - AUDIO
_SAMPLE_RATE = 7
_NUM_CHANNELS = 8
_SAMPLE_RATE = 8
_NUM_CHANNELS = 9
# - VIDEO
_WIDTH = 9
_HEIGHT = 10
_FRAME_RATE = 11
_WIDTH = 10
_HEIGHT = 11
_FRAME_RATE = 12


def _parse_si(i):
Expand All @@ -125,6 +129,7 @@ def _parse_si(i):
bit_rate = i[_BIT_RATE]
num_frames = i[_NUM_FRAMES]
bps = i[_BPS]
metadata = i[_METADATA]
if media_type == "audio":
return StreamReaderSourceAudioStream(
media_type=media_type,
Expand All @@ -134,6 +139,7 @@ def _parse_si(i):
bit_rate=bit_rate,
num_frames=num_frames,
bits_per_sample=bps,
metadata=metadata,
sample_rate=i[_SAMPLE_RATE],
num_channels=i[_NUM_CHANNELS],
)
Expand All @@ -146,6 +152,7 @@ def _parse_si(i):
bit_rate=bit_rate,
num_frames=num_frames,
bits_per_sample=bps,
metadata=metadata,
width=i[_WIDTH],
height=i[_HEIGHT],
frame_rate=i[_FRAME_RATE],
Expand All @@ -158,6 +165,7 @@ def _parse_si(i):
bit_rate=None,
num_frames=None,
bits_per_sample=None,
metadata=metadata,
)


Expand Down

0 comments on commit 09cfdac

Please sign in to comment.