diff --git a/test/torchaudio_unittest/io/stream_reader_test.py b/test/torchaudio_unittest/io/stream_reader_test.py index 148911def4a..f14c003c56d 100644 --- a/test/torchaudio_unittest/io/stream_reader_test.py +++ b/test/torchaudio_unittest/io/stream_reader_test.py @@ -165,7 +165,7 @@ def test_basic_audio_stream(self): sinfo = s.get_out_stream_info(0) assert sinfo.source_index == s.default_audio_stream - assert sinfo.filter_description == "" + assert sinfo.filter_description == "anull" sinfo = s.get_out_stream_info(1) assert sinfo.source_index == s.default_audio_stream @@ -185,7 +185,7 @@ def test_basic_video_stream(self): sinfo = s.get_out_stream_info(0) assert sinfo.source_index == s.default_video_stream - assert sinfo.filter_description == "" + assert sinfo.filter_description == "null" sinfo = s.get_out_stream_info(1) assert sinfo.source_index == s.default_video_stream diff --git a/torchaudio/csrc/CMakeLists.txt b/torchaudio/csrc/CMakeLists.txt index 615820665af..0b2a0ad33c6 100644 --- a/torchaudio/csrc/CMakeLists.txt +++ b/torchaudio/csrc/CMakeLists.txt @@ -181,6 +181,7 @@ if(USE_FFMPEG) ffmpeg/sink.cpp ffmpeg/stream_processor.cpp ffmpeg/streamer.cpp + ffmpeg/stream_reader_wrapper.cpp ) message(STATUS "FFMPEG_ROOT=$ENV{FFMPEG_ROOT}") find_package(FFMPEG 4.1 REQUIRED COMPONENTS avdevice avfilter avformat avcodec avutil) diff --git a/torchaudio/csrc/ffmpeg/decoder.cpp b/torchaudio/csrc/ffmpeg/decoder.cpp index dcc6c89bca5..34e79492c91 100644 --- a/torchaudio/csrc/ffmpeg/decoder.cpp +++ b/torchaudio/csrc/ffmpeg/decoder.cpp @@ -8,8 +8,8 @@ namespace ffmpeg { //////////////////////////////////////////////////////////////////////////////// Decoder::Decoder( AVCodecParameters* pParam, - const std::string& decoder_name, - const std::map& decoder_option, + const c10::optional& decoder_name, + const c10::optional& decoder_option, const torch::Device& device) : pCodecContext(get_decode_context(pParam->codec_id, decoder_name)) { init_codec_context( diff --git a/torchaudio/csrc/ffmpeg/decoder.h b/torchaudio/csrc/ffmpeg/decoder.h index b292277c190..820060fa303 100644 --- a/torchaudio/csrc/ffmpeg/decoder.h +++ b/torchaudio/csrc/ffmpeg/decoder.h @@ -13,8 +13,8 @@ class Decoder { // Default constructable Decoder( AVCodecParameters* pParam, - const std::string& decoder_name, - const std::map& decoder_option, + const c10::optional& decoder_name, + const c10::optional& decoder_option, const torch::Device& device); // Custom destructor to clean up the resources ~Decoder() = default; diff --git a/torchaudio/csrc/ffmpeg/ffmpeg.cpp b/torchaudio/csrc/ffmpeg/ffmpeg.cpp index 1118ebcfd0d..f2a98112c89 100644 --- a/torchaudio/csrc/ffmpeg/ffmpeg.cpp +++ b/torchaudio/csrc/ffmpeg/ffmpeg.cpp @@ -17,11 +17,12 @@ void AVFormatContextDeleter::operator()(AVFormatContext* p) { namespace { -AVDictionary* get_option_dict( - const std::map& option) { +AVDictionary* get_option_dict(const c10::optional& option) { AVDictionary* opt = nullptr; - for (auto& it : option) { - av_dict_set(&opt, it.first.c_str(), it.second.c_str(), 0); + if (option) { + for (auto& it : option.value()) { + av_dict_set(&opt, it.key().c_str(), it.value().c_str(), 0); + } } return opt; } @@ -66,12 +67,25 @@ std::string join(std::vector vars) { AVFormatContextPtr get_input_format_context( const std::string& src, - const std::string& device, - const std::map& option) { + const c10::optional& device, + const c10::optional& option) { AVFormatContext* pFormat = NULL; - AVINPUT_FORMAT_CONST AVInputFormat* pInput = - device.empty() ? NULL : av_find_input_format(device.c_str()); + AVINPUT_FORMAT_CONST AVInputFormat* pInput = [&]() -> AVInputFormat* { + if (device.has_value()) { + std::string device_str = device.value(); + AVINPUT_FORMAT_CONST AVInputFormat* p = + av_find_input_format(device_str.c_str()); + if (!p) { + std::ostringstream msg; + msg << "Unsupported device: \"" << device_str << "\""; + throw std::runtime_error(msg.str()); + } + return p; + } + return nullptr; + }(); + AVDictionary* opt = get_option_dict(option); int ret = avformat_open_input(&pFormat, src.c_str(), pInput, &opt); @@ -148,18 +162,18 @@ void AVCodecContextDeleter::operator()(AVCodecContext* p) { namespace { const AVCodec* get_decode_codec( enum AVCodecID codec_id, - const std::string& decoder_name) { - const AVCodec* pCodec = decoder_name.empty() + const c10::optional& decoder_name) { + const AVCodec* pCodec = !decoder_name.has_value() ? avcodec_find_decoder(codec_id) - : avcodec_find_decoder_by_name(decoder_name.c_str()); + : avcodec_find_decoder_by_name(decoder_name.value().c_str()); if (!pCodec) { std::stringstream ss; - if (decoder_name.empty()) { + if (!decoder_name.has_value()) { ss << "Unsupported codec: \"" << avcodec_get_name(codec_id) << "\", (" << codec_id << ")."; } else { - ss << "Unsupported codec: \"" << decoder_name << "\"."; + ss << "Unsupported codec: \"" << decoder_name.value() << "\"."; } throw std::runtime_error(ss.str()); } @@ -170,7 +184,7 @@ const AVCodec* get_decode_codec( AVCodecContextPtr get_decode_context( enum AVCodecID codec_id, - const std::string& decoder_name) { + const c10::optional& decoder_name) { const AVCodec* pCodec = get_decode_codec(codec_id, decoder_name); AVCodecContext* pCodecContext = avcodec_alloc_context3(pCodec); @@ -216,8 +230,8 @@ const AVCodecHWConfig* get_cuda_config(const AVCodec* pCodec) { void init_codec_context( AVCodecContext* pCodecContext, AVCodecParameters* pParams, - const std::string& decoder_name, - const std::map& decoder_option, + const c10::optional& decoder_name, + const c10::optional& decoder_option, const torch::Device& device, AVBufferRefPtr& pHWBufferRef) { const AVCodec* pCodec = get_decode_codec(pParams->codec_id, decoder_name); diff --git a/torchaudio/csrc/ffmpeg/ffmpeg.h b/torchaudio/csrc/ffmpeg/ffmpeg.h index 9f6b61077b7..1e77b0bc04f 100644 --- a/torchaudio/csrc/ffmpeg/ffmpeg.h +++ b/torchaudio/csrc/ffmpeg/ffmpeg.h @@ -23,6 +23,8 @@ extern "C" { namespace torchaudio { namespace ffmpeg { +using OptionDict = c10::Dict; + // Replacement of av_err2str, which causes // `error: taking address of temporary array` // https://github.com/joncampbell123/composite-video-simulator/issues/5 @@ -71,8 +73,8 @@ struct AVFormatContextPtr // create format context for reading media AVFormatContextPtr get_input_format_context( const std::string& src, - const std::string& device, - const std::map& option); + const c10::optional& device, + const c10::optional& option); //////////////////////////////////////////////////////////////////////////////// // AVPacket @@ -141,14 +143,14 @@ struct AVCodecContextPtr // Allocate codec context from either decoder name or ID AVCodecContextPtr get_decode_context( enum AVCodecID codec_id, - const std::string& decoder); + const c10::optional& decoder); // Initialize codec context with the parameters void init_codec_context( AVCodecContext* pCodecContext, AVCodecParameters* pParams, - const std::string& decoder_name, - const std::map& decoder_option, + const c10::optional& decoder_name, + const c10::optional& decoder_option, const torch::Device& device, AVBufferRefPtr& pHWBufferRef); diff --git a/torchaudio/csrc/ffmpeg/filter_graph.cpp b/torchaudio/csrc/ffmpeg/filter_graph.cpp index 5ec59f6d5ab..453cd907639 100644 --- a/torchaudio/csrc/ffmpeg/filter_graph.cpp +++ b/torchaudio/csrc/ffmpeg/filter_graph.cpp @@ -7,10 +7,11 @@ namespace ffmpeg { FilterGraph::FilterGraph( AVRational time_base, AVCodecParameters* codecpar, - std::string filter_description) + const c10::optional& filter_description) : input_time_base(time_base), codecpar(codecpar), - filter_description(std::move(filter_description)), + filter_description(filter_description.value_or( + codecpar->codec_type == AVMEDIA_TYPE_AUDIO ? "anull" : "null")), media_type(codecpar->codec_type) { init(); } @@ -49,10 +50,10 @@ std::string get_video_src_args( std::snprintf( args, sizeof(args), - "video_size=%dx%d:pix_fmt=%d:time_base=%d/%d:pixel_aspect=%d/%d", + "video_size=%dx%d:pix_fmt=%s:time_base=%d/%d:pixel_aspect=%d/%d", codecpar->width, codecpar->height, - static_cast(codecpar->format), + av_get_pix_fmt_name(static_cast(codecpar->format)), time_base.num, time_base.den, codecpar->sample_aspect_ratio.num, @@ -165,16 +166,12 @@ void FilterGraph::add_process() { // If you are debugging this part of the code, you might get confused. InOuts in{"in", buffersrc_ctx}, out{"out", buffersink_ctx}; - std::string desc = filter_description.empty() - ? (media_type == AVMEDIA_TYPE_AUDIO) ? "anull" : "null" - : filter_description; - - int ret = - avfilter_graph_parse_ptr(pFilterGraph, desc.c_str(), out, in, nullptr); + int ret = avfilter_graph_parse_ptr( + pFilterGraph, filter_description.c_str(), out, in, nullptr); if (ret < 0) { throw std::runtime_error( - "Failed to create the filter from \"" + desc + "\" (" + + "Failed to create the filter from \"" + filter_description + "\" (" + av_err2string(ret) + ".)"); } } diff --git a/torchaudio/csrc/ffmpeg/filter_graph.h b/torchaudio/csrc/ffmpeg/filter_graph.h index 6202bd7371e..081ee7a9c7a 100644 --- a/torchaudio/csrc/ffmpeg/filter_graph.h +++ b/torchaudio/csrc/ffmpeg/filter_graph.h @@ -24,7 +24,7 @@ class FilterGraph { FilterGraph( AVRational time_base, AVCodecParameters* codecpar, - std::string filter_desc); + const c10::optional& filter_desc); // Custom destructor to release AVFilterGraph* ~FilterGraph() = default; // Non-copyable diff --git a/torchaudio/csrc/ffmpeg/prototype.cpp b/torchaudio/csrc/ffmpeg/prototype.cpp index 661e904692b..22fddd0992e 100644 --- a/torchaudio/csrc/ffmpeg/prototype.cpp +++ b/torchaudio/csrc/ffmpeg/prototype.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include namespace torchaudio { @@ -7,357 +7,27 @@ namespace ffmpeg { namespace { -using OptionDict = c10::Dict; - -std::map convert_dict( - const c10::optional& option) { - std::map opts; - if (option) { - for (auto& it : option.value()) { - opts[it.key()] = it.value(); - } - } - return opts; -} - -struct StreamerHolder : torch::CustomClassHolder { - Streamer s; - StreamerHolder( - const std::string& src, - const c10::optional& device, - const c10::optional& option) - : s(src, device.value_or(""), convert_dict(option)) {} -}; - -using S = c10::intrusive_ptr; - -S init( +c10::intrusive_ptr init( const std::string& src, const c10::optional& device, const c10::optional& option) { - return c10::make_intrusive(src, device, option); -} - -using SrcInfo = std::tuple< - std::string, // media_type - std::string, // codec name - std::string, // codec long name - std::string, // format name - int64_t, // bit_rate - // Audio - double, // sample_rate - int64_t, // num_channels - // Video - int64_t, // width - int64_t, // height - double // frame_rate - >; - -SrcInfo convert(SrcStreamInfo ssi) { - return SrcInfo(std::forward_as_tuple( - av_get_media_type_string(ssi.media_type), - ssi.codec_name, - ssi.codec_long_name, - ssi.fmt_name, - ssi.bit_rate, - ssi.sample_rate, - ssi.num_channels, - ssi.width, - ssi.height, - ssi.frame_rate)); -} - -SrcInfo get_src_stream_info(S s, int64_t i) { - return convert(s->s.get_src_stream_info(i)); -} - -using OutInfo = std::tuple< - int64_t, // source index - std::string // filter description - >; - -OutInfo convert(OutputStreamInfo osi) { - return OutInfo( - std::forward_as_tuple(osi.source_index, osi.filter_description)); -} - -OutInfo get_out_stream_info(S s, int64_t i) { - return convert(s->s.get_out_stream_info(i)); -} - -int64_t num_src_streams(S s) { - return s->s.num_src_streams(); -} - -int64_t num_out_streams(S s) { - return s->s.num_out_streams(); -} - -int64_t find_best_audio_stream(S s) { - return s->s.find_best_audio_stream(); -} - -int64_t find_best_video_stream(S s) { - return s->s.find_best_video_stream(); -} - -void seek(S s, double timestamp) { - s->s.seek(timestamp); -} - -template -std::string string_format(const std::string& format, Args... args) { - char buffer[512]; - std::snprintf(buffer, sizeof(buffer), format.c_str(), args...); - return std::string(buffer); -} - -std::string join( - const std::vector& components, - const std::string& delim) { - std::ostringstream s; - for (int i = 0; i < components.size(); ++i) { - if (i) - s << delim; - s << components[i]; - } - return s.str(); -} -std::string get_afilter_desc( - const c10::optional& sample_rate, - const c10::optional& dtype) { - std::vector components; - if (sample_rate) { - // TODO: test float sample rate - components.emplace_back( - string_format("aresample=%d", static_cast(sample_rate.value()))); - } - if (dtype) { - AVSampleFormat fmt = [&]() { - switch (dtype.value()) { - case c10::ScalarType::Byte: - return AV_SAMPLE_FMT_U8P; - case c10::ScalarType::Short: - return AV_SAMPLE_FMT_S16P; - case c10::ScalarType::Int: - return AV_SAMPLE_FMT_S32P; - case c10::ScalarType::Long: - return AV_SAMPLE_FMT_S64P; - case c10::ScalarType::Float: - return AV_SAMPLE_FMT_FLTP; - case c10::ScalarType::Double: - return AV_SAMPLE_FMT_DBLP; - default: - throw std::runtime_error("Unexpected dtype."); - } - }(); - components.emplace_back( - string_format("aformat=sample_fmts=%s", av_get_sample_fmt_name(fmt))); - } - return join(components, ","); -} -std::string get_vfilter_desc( - const c10::optional& frame_rate, - const c10::optional& width, - const c10::optional& height, - const c10::optional& format) { - // TODO: - // - Add `flags` for different scale algorithm - // https://ffmpeg.org/ffmpeg-filters.html#scale - // - Consider `framerate` as well - // https://ffmpeg.org/ffmpeg-filters.html#framerate - - // - scale - // https://ffmpeg.org/ffmpeg-filters.html#scale-1 - // https://ffmpeg.org/ffmpeg-scaler.html#toc-Scaler-Options - // - framerate - // https://ffmpeg.org/ffmpeg-filters.html#framerate - - // TODO: - // - format - // https://ffmpeg.org/ffmpeg-filters.html#toc-format-1 - // - fps - // https://ffmpeg.org/ffmpeg-filters.html#fps-1 - std::vector components; - if (frame_rate) - components.emplace_back(string_format("fps=%lf", frame_rate.value())); - - std::vector scale_components; - if (width) - scale_components.emplace_back(string_format("width=%d", width.value())); - if (height) - scale_components.emplace_back(string_format("height=%d", height.value())); - if (scale_components.size()) - components.emplace_back( - string_format("scale=%s", join(scale_components, ":").c_str())); - if (format) { - // TODO: - // Check other useful formats - // https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes - AVPixelFormat fmt = [&]() { - const std::map valid_choices { - {"RGB", AV_PIX_FMT_RGB24}, - {"BGR", AV_PIX_FMT_BGR24}, - {"YUV", AV_PIX_FMT_YUV420P}, - {"GRAY", AV_PIX_FMT_GRAY8}, - }; - - const std::string val = format.value(); - if (valid_choices.find(val) == valid_choices.end()) { - std::stringstream ss; - ss << "Unexpected output video format: \"" << val << "\"." - << "Valid choices are; "; - int i = 0; - for (const auto& p : valid_choices) { - if (i == 0) { - ss << "\"" << p.first << "\""; - } else { - ss << ", \"" << p.first << "\""; - } - } - throw std::runtime_error(ss.str()); - } - return valid_choices.at(val); - }(); - components.emplace_back( - string_format("format=pix_fmts=%s", av_get_pix_fmt_name(fmt))); - } - return join(components, ","); -}; - -void add_basic_audio_stream( - S s, - int64_t i, - int64_t frames_per_chunk, - int64_t num_chunks, - const c10::optional& sample_rate, - const c10::optional& dtype) { - std::string filter_desc = get_afilter_desc(sample_rate, dtype); - s->s.add_audio_stream(i, frames_per_chunk, num_chunks, filter_desc, "", {}); -} - -void add_basic_video_stream( - S s, - int64_t i, - int64_t frames_per_chunk, - int64_t num_chunks, - const c10::optional& frame_rate, - const c10::optional& width, - const c10::optional& height, - const c10::optional& format) { - std::string filter_desc = get_vfilter_desc(frame_rate, width, height, format); - s->s.add_video_stream( - static_cast(i), - static_cast(frames_per_chunk), - static_cast(num_chunks), - std::move(filter_desc), - "", - {}, - torch::Device(c10::DeviceType::CPU)); -} - -void add_audio_stream( - S s, - int64_t i, - int64_t frames_per_chunk, - int64_t num_chunks, - const c10::optional& filter_desc, - const c10::optional& decoder, - const c10::optional& decoder_options) { - s->s.add_audio_stream( - i, - frames_per_chunk, - num_chunks, - filter_desc.value_or(""), - decoder.value_or(""), - convert_dict(decoder_options)); -} - -void add_video_stream( - S s, - int64_t i, - int64_t frames_per_chunk, - int64_t num_chunks, - const c10::optional& filter_desc, - const c10::optional& decoder, - const c10::optional& decoder_options, - const c10::optional& hw_accel) { - const torch::Device device = [&]() { - if (!hw_accel) { - return torch::Device{c10::DeviceType::CPU}; - } -#ifdef USE_CUDA - torch::Device d{hw_accel.value()}; - if (d.type() != c10::DeviceType::CUDA) { - std::stringstream ss; - ss << "Only CUDA is supported for hardware acceleration. Found: " - << device.str(); - throw std::runtime_error(ss.str()); - } - return d; -#else - throw std::runtime_error( - "torchaudio is not compiled with CUDA support. Hardware acceleration is not available."); -#endif - }(); - - s->s.add_video_stream( - i, - frames_per_chunk, - num_chunks, - filter_desc.value_or(""), - decoder.value_or(""), - convert_dict(decoder_options), - device); -} - -void remove_stream(S s, int64_t i) { - s->s.remove_stream(i); -} - -int64_t process_packet( - Streamer& s, - const c10::optional& timeout = c10::optional(), - const double backoff = 10.) { - int64_t code = [&]() { - if (timeout.has_value()) { - return s.process_packet_block(timeout.value(), backoff); - } - return s.process_packet(); - }(); - if (code < 0) { - throw std::runtime_error( - "Failed to process a packet. (" + av_err2string(code) + "). "); - } - return code; -} - -void process_all_packets(Streamer& s) { - int ret = 0; - do { - ret = process_packet(s); - } while (!ret); -} - -bool is_buffer_ready(S s) { - return s->s.is_buffer_ready(); -} - -std::vector> pop_chunks(S s) { - return s->s.pop_chunks(); + return c10::make_intrusive( + get_input_format_context(src, device, option)); } std::tuple, int64_t> load(const std::string& src) { - Streamer s{src, "", {}}; + StreamReaderBinding s{get_input_format_context(src, {}, {})}; int i = s.find_best_audio_stream(); - auto sinfo = s.get_src_stream_info(i); + auto sinfo = s.Streamer::get_src_stream_info(i); int64_t sample_rate = static_cast(sinfo.sample_rate); - s.add_audio_stream(i, -1, -1, "", "", {}); - process_all_packets(s); + s.add_audio_stream(i, -1, -1, {}, {}, {}); + s.process_all_packets(); auto tensors = s.pop_chunks(); return std::make_tuple<>(tensors[0], sample_rate); } +using S = const c10::intrusive_ptr&; + TORCH_LIBRARY_FRAGMENT(torchaudio, m) { m.def("torchaudio::ffmpeg_init", []() { avdevice_register_all(); @@ -365,38 +35,82 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { av_log_set_level(AV_LOG_ERROR); }); m.def("torchaudio::ffmpeg_load", load); - m.class_("ffmpeg_Streamer"); + m.class_("ffmpeg_Streamer"); m.def("torchaudio::ffmpeg_streamer_init", init); - m.def("torchaudio::ffmpeg_streamer_num_src_streams", num_src_streams); - m.def("torchaudio::ffmpeg_streamer_num_out_streams", num_out_streams); - m.def("torchaudio::ffmpeg_streamer_get_src_stream_info", get_src_stream_info); - m.def("torchaudio::ffmpeg_streamer_get_out_stream_info", get_out_stream_info); - m.def( - "torchaudio::ffmpeg_streamer_find_best_audio_stream", - find_best_audio_stream); - m.def( - "torchaudio::ffmpeg_streamer_find_best_video_stream", - find_best_video_stream); - m.def("torchaudio::ffmpeg_streamer_seek", seek); + m.def("torchaudio::ffmpeg_streamer_num_src_streams", [](S s) { + return s->num_src_streams(); + }); + m.def("torchaudio::ffmpeg_streamer_num_out_streams", [](S s) { + return s->num_out_streams(); + }); + m.def("torchaudio::ffmpeg_streamer_get_src_stream_info", [](S s, int64_t i) { + return s->get_src_stream_info(i); + }); + m.def("torchaudio::ffmpeg_streamer_get_out_stream_info", [](S s, int64_t i) { + return s->get_out_stream_info(i); + }); + m.def("torchaudio::ffmpeg_streamer_find_best_audio_stream", [](S s) { + return s->find_best_audio_stream(); + }); + m.def("torchaudio::ffmpeg_streamer_find_best_video_stream", [](S s) { + return s->find_best_video_stream(); + }); + m.def("torchaudio::ffmpeg_streamer_seek", [](S s, double t) { + return s->seek(t); + }); m.def( - "torchaudio::ffmpeg_streamer_add_basic_audio_stream", - add_basic_audio_stream); + "torchaudio::ffmpeg_streamer_add_audio_stream", + [](S s, + int64_t i, + int64_t frames_per_chunk, + int64_t num_chunks, + const c10::optional& filter_desc, + const c10::optional& decoder, + const c10::optional& decoder_options) { + s->add_audio_stream( + i, + frames_per_chunk, + num_chunks, + filter_desc, + decoder, + decoder_options); + }); m.def( - "torchaudio::ffmpeg_streamer_add_basic_video_stream", - add_basic_video_stream); - m.def("torchaudio::ffmpeg_streamer_add_audio_stream", add_audio_stream); - m.def("torchaudio::ffmpeg_streamer_add_video_stream", add_video_stream); - m.def("torchaudio::ffmpeg_streamer_remove_stream", remove_stream); + "torchaudio::ffmpeg_streamer_add_video_stream", + [](S s, + int64_t i, + int64_t frames_per_chunk, + int64_t num_chunks, + const c10::optional& filter_desc, + const c10::optional& decoder, + const c10::optional& decoder_options, + const c10::optional& hw_accel) { + s->add_video_stream( + i, + frames_per_chunk, + num_chunks, + filter_desc, + decoder, + decoder_options, + hw_accel); + }); + m.def("torchaudio::ffmpeg_streamer_remove_stream", [](S s, int64_t i) { + s->remove_stream(i); + }); m.def( "torchaudio::ffmpeg_streamer_process_packet", - [](S s, const c10::optional& timeout, double backoff) { - return process_packet(s->s, timeout, backoff); + [](S s, const c10::optional& timeout, const double backoff) { + return s->process_packet(timeout, backoff); }); m.def("torchaudio::ffmpeg_streamer_process_all_packets", [](S s) { - return process_all_packets(s->s); + s->process_all_packets(); + }); + m.def("torchaudio::ffmpeg_streamer_is_buffer_ready", [](S s) { + return s->is_buffer_ready(); + }); + m.def("torchaudio::ffmpeg_streamer_pop_chunks", [](S s) { + return s->pop_chunks(); }); - m.def("torchaudio::ffmpeg_streamer_is_buffer_ready", is_buffer_ready); - m.def("torchaudio::ffmpeg_streamer_pop_chunks", pop_chunks); } } // namespace diff --git a/torchaudio/csrc/ffmpeg/sink.cpp b/torchaudio/csrc/ffmpeg/sink.cpp index ea35ffd7f32..4f7c35e277b 100644 --- a/torchaudio/csrc/ffmpeg/sink.cpp +++ b/torchaudio/csrc/ffmpeg/sink.cpp @@ -30,9 +30,9 @@ Sink::Sink( AVCodecParameters* codecpar, int frames_per_chunk, int num_chunks, - std::string filter_description, + const c10::optional& filter_description, const torch::Device& device) - : filter(input_time_base, codecpar, std::move(filter_description)), + : filter(input_time_base, codecpar, filter_description), buffer(get_buffer( codecpar->codec_type, frames_per_chunk, diff --git a/torchaudio/csrc/ffmpeg/sink.h b/torchaudio/csrc/ffmpeg/sink.h index 77e582afc72..b329b20b87c 100644 --- a/torchaudio/csrc/ffmpeg/sink.h +++ b/torchaudio/csrc/ffmpeg/sink.h @@ -18,7 +18,7 @@ class Sink { AVCodecParameters* codecpar, int frames_per_chunk, int num_chunks, - std::string filter_description, + const c10::optional& filter_description, const torch::Device& device); int process_frame(AVFrame* frame); diff --git a/torchaudio/csrc/ffmpeg/stream_processor.cpp b/torchaudio/csrc/ffmpeg/stream_processor.cpp index ca1a3b85233..f5c7d904d90 100644 --- a/torchaudio/csrc/ffmpeg/stream_processor.cpp +++ b/torchaudio/csrc/ffmpeg/stream_processor.cpp @@ -8,8 +8,8 @@ using KeyType = StreamProcessor::KeyType; StreamProcessor::StreamProcessor( AVCodecParameters* codecpar, - const std::string& decoder_name, - const std::map& decoder_option, + const c10::optional& decoder_name, + const c10::optional& decoder_option, const torch::Device& device) : decoder(codecpar, decoder_name, decoder_option, device) {} @@ -21,7 +21,7 @@ KeyType StreamProcessor::add_stream( AVCodecParameters* codecpar, int frames_per_chunk, int num_chunks, - std::string filter_description, + const c10::optional& filter_description, const torch::Device& device) { switch (codecpar->codec_type) { case AVMEDIA_TYPE_AUDIO: @@ -39,7 +39,7 @@ KeyType StreamProcessor::add_stream( codecpar, frames_per_chunk, num_chunks, - std::move(filter_description), + filter_description, device)); decoder_time_base = av_q2d(input_time_base); return key; diff --git a/torchaudio/csrc/ffmpeg/stream_processor.h b/torchaudio/csrc/ffmpeg/stream_processor.h index 5f555d875f1..c710696bbdd 100644 --- a/torchaudio/csrc/ffmpeg/stream_processor.h +++ b/torchaudio/csrc/ffmpeg/stream_processor.h @@ -27,8 +27,8 @@ class StreamProcessor { public: StreamProcessor( AVCodecParameters* codecpar, - const std::string& decoder_name, - const std::map& decoder_option, + const c10::optional& decoder_name, + const c10::optional& decoder_option, const torch::Device& device); ~StreamProcessor() = default; // Non-copyable @@ -52,7 +52,7 @@ class StreamProcessor { AVCodecParameters* codecpar, int frames_per_chunk, int num_chunks, - std::string filter_description, + const c10::optional& filter_description, const torch::Device& device); // 1. Remove the stream diff --git a/torchaudio/csrc/ffmpeg/stream_reader_wrapper.cpp b/torchaudio/csrc/ffmpeg/stream_reader_wrapper.cpp new file mode 100644 index 00000000000..4b3bd4d2b76 --- /dev/null +++ b/torchaudio/csrc/ffmpeg/stream_reader_wrapper.cpp @@ -0,0 +1,62 @@ +#include + +namespace torchaudio { +namespace ffmpeg { +namespace { + +SrcInfo convert(SrcStreamInfo ssi) { + return SrcInfo(std::forward_as_tuple( + av_get_media_type_string(ssi.media_type), + ssi.codec_name, + ssi.codec_long_name, + ssi.fmt_name, + ssi.bit_rate, + ssi.sample_rate, + ssi.num_channels, + ssi.width, + ssi.height, + ssi.frame_rate)); +} + +OutInfo convert(OutputStreamInfo osi) { + return OutInfo( + std::forward_as_tuple(osi.source_index, osi.filter_description)); +} +} // namespace + +StreamReaderBinding::StreamReaderBinding(AVFormatContextPtr&& p) + : Streamer(std::move(p)) {} + +SrcInfo StreamReaderBinding::get_src_stream_info(int64_t i) { + return convert(Streamer::get_src_stream_info(i)); +} + +OutInfo StreamReaderBinding::get_out_stream_info(int64_t i) { + return convert(Streamer::get_out_stream_info(i)); +} + +int64_t StreamReaderBinding::process_packet( + const c10::optional& timeout, + const double backoff) { + int64_t code = [&]() { + if (timeout.has_value()) { + return Streamer::process_packet_block(timeout.value(), backoff); + } + return Streamer::process_packet(); + }(); + if (code < 0) { + throw std::runtime_error( + "Failed to process a packet. (" + av_err2string(code) + "). "); + } + return code; +} + +void StreamReaderBinding::process_all_packets() { + int ret = 0; + do { + ret = process_packet(); + } while (!ret); +} + +} // namespace ffmpeg +} // namespace torchaudio diff --git a/torchaudio/csrc/ffmpeg/stream_reader_wrapper.h b/torchaudio/csrc/ffmpeg/stream_reader_wrapper.h new file mode 100644 index 00000000000..dbd6b75720f --- /dev/null +++ b/torchaudio/csrc/ffmpeg/stream_reader_wrapper.h @@ -0,0 +1,45 @@ +#pragma once +#include +#include + +namespace torchaudio { +namespace ffmpeg { + +using SrcInfo = std::tuple< + std::string, // media_type + std::string, // codec name + std::string, // codec long name + std::string, // format name + int64_t, // bit_rate + // Audio + double, // sample_rate + int64_t, // num_channels + // Video + int64_t, // width + int64_t, // height + double // frame_rate + >; + +using OutInfo = std::tuple< + int64_t, // source index + std::string // filter description + >; + +using OptionDict = c10::Dict; + +// Structure to implement wrapper API around Streamer, which is more suitable +// for Binding the code (i.e. it receives/returns pritimitves) +struct StreamReaderBinding : public Streamer, public torch::CustomClassHolder { + StreamReaderBinding(AVFormatContextPtr&& p); + SrcInfo get_src_stream_info(int64_t i); + OutInfo get_out_stream_info(int64_t i); + + int64_t process_packet( + const c10::optional& timeout = c10::optional(), + const double backoff = 10.); + + void process_all_packets(); +}; + +} // namespace ffmpeg +} // namespace torchaudio diff --git a/torchaudio/csrc/ffmpeg/streamer.cpp b/torchaudio/csrc/ffmpeg/streamer.cpp index 77d36170e09..59fd946dc83 100644 --- a/torchaudio/csrc/ffmpeg/streamer.cpp +++ b/torchaudio/csrc/ffmpeg/streamer.cpp @@ -42,11 +42,7 @@ void Streamer::validate_src_stream_type(int i, AVMediaType type) { ////////////////////////////////////////////////////////////////////////////// // Initialization / resource allocations ////////////////////////////////////////////////////////////////////////////// -Streamer::Streamer( - const std::string& src, - const std::string& device, - const std::map& option) - : pFormatContext(get_input_format_context(src, device, option)) { +Streamer::Streamer(AVFormatContextPtr&& p) : pFormatContext(std::move(p)) { if (avformat_find_stream_info(pFormatContext, nullptr) < 0) { throw std::runtime_error("Failed to find stream information."); } @@ -67,7 +63,7 @@ Streamer::Streamer( //////////////////////////////////////////////////////////////////////////////// // Query methods //////////////////////////////////////////////////////////////////////////////// -int Streamer::num_src_streams() const { +int64_t Streamer::num_src_streams() const { return pFormatContext->nb_streams; } @@ -103,7 +99,7 @@ SrcStreamInfo Streamer::get_src_stream_info(int i) const { return ret; } -int Streamer::num_out_streams() const { +int64_t Streamer::num_out_streams() const { return stream_indices.size(); } @@ -117,12 +113,12 @@ OutputStreamInfo Streamer::get_out_stream_info(int i) const { return ret; } -int Streamer::find_best_audio_stream() const { +int64_t Streamer::find_best_audio_stream() const { return av_find_best_stream( pFormatContext, AVMEDIA_TYPE_AUDIO, -1, -1, NULL, 0); } -int Streamer::find_best_video_stream() const { +int64_t Streamer::find_best_video_stream() const { return av_find_best_stream( pFormatContext, AVMEDIA_TYPE_VIDEO, -1, -1, NULL, 0); } @@ -157,37 +153,56 @@ void Streamer::seek(double timestamp) { } void Streamer::add_audio_stream( - int i, - int frames_per_chunk, - int num_chunks, - std::string filter_desc, - const std::string& decoder, - const std::map& decoder_option) { + int64_t i, + int64_t frames_per_chunk, + int64_t num_chunks, + const c10::optional& filter_desc, + const c10::optional& decoder, + const c10::optional& decoder_option) { add_stream( i, AVMEDIA_TYPE_AUDIO, frames_per_chunk, num_chunks, - std::move(filter_desc), + filter_desc, decoder, decoder_option, torch::Device(torch::DeviceType::CPU)); } void Streamer::add_video_stream( - int i, - int frames_per_chunk, - int num_chunks, - std::string filter_desc, - const std::string& decoder, - const std::map& decoder_option, - const torch::Device& device) { + int64_t i, + int64_t frames_per_chunk, + int64_t num_chunks, + const c10::optional& filter_desc, + const c10::optional& decoder, + const c10::optional& decoder_option, + const c10::optional& hw_accel) { + const torch::Device device = [&]() { + if (!hw_accel) { + return torch::Device{c10::DeviceType::CPU}; + } +#ifdef USE_CUDA + torch::Device d{hw_accel.value()}; + if (d.type() != c10::DeviceType::CUDA) { + std::stringstream ss; + ss << "Only CUDA is supported for hardware acceleration. Found: " + << device.str(); + throw std::runtime_error(ss.str()); + } + return d; +#else + throw std::runtime_error( + "torchaudio is not compiled with CUDA support. Hardware acceleration is not available."); +#endif + }(); + add_stream( i, AVMEDIA_TYPE_VIDEO, frames_per_chunk, num_chunks, - std::move(filter_desc), + filter_desc, decoder, decoder_option, device); @@ -198,9 +213,9 @@ void Streamer::add_stream( AVMediaType media_type, int frames_per_chunk, int num_chunks, - std::string filter_desc, - const std::string& decoder, - const std::map& decoder_option, + const c10::optional& filter_desc, + const c10::optional& decoder, + const c10::optional& decoder_option, const torch::Device& device) { validate_src_stream_type(i, media_type); @@ -214,12 +229,12 @@ void Streamer::add_stream( stream->codecpar, frames_per_chunk, num_chunks, - std::move(filter_desc), + filter_desc, device); stream_indices.push_back(std::make_pair<>(i, key)); } -void Streamer::remove_stream(int i) { +void Streamer::remove_stream(int64_t i) { validate_output_stream_index(i); auto it = stream_indices.begin() + i; int iP = it->first; diff --git a/torchaudio/csrc/ffmpeg/streamer.h b/torchaudio/csrc/ffmpeg/streamer.h index 8e47c5463e4..939ebf34112 100644 --- a/torchaudio/csrc/ffmpeg/streamer.h +++ b/torchaudio/csrc/ffmpeg/streamer.h @@ -19,11 +19,7 @@ class Streamer { std::vector> stream_indices; public: - // Open the input and allocate the resource - Streamer( - const std::string& src, - const std::string& device, - const std::map& option); + explicit Streamer(AVFormatContextPtr&& p); ~Streamer() = default; // Non-copyable Streamer(const Streamer&) = delete; @@ -46,13 +42,13 @@ class Streamer { ////////////////////////////////////////////////////////////////////////////// public: // Find a suitable audio/video streams using heuristics from ffmpeg - int find_best_audio_stream() const; - int find_best_video_stream() const; + int64_t find_best_audio_stream() const; + int64_t find_best_video_stream() const; // Fetch information about source streams - int num_src_streams() const; + int64_t num_src_streams() const; SrcStreamInfo get_src_stream_info(int i) const; // Fetch information about output streams - int num_out_streams() const; + int64_t num_out_streams() const; OutputStreamInfo get_out_stream_info(int i) const; // Check if all the buffers of the output streams are ready. bool is_buffer_ready() const; @@ -63,21 +59,21 @@ class Streamer { void seek(double timestamp); void add_audio_stream( - int i, - int frames_per_chunk, - int num_chunks, - std::string filter_desc, - const std::string& decoder, - const std::map& decoder_option); + int64_t i, + int64_t frames_per_chunk, + int64_t num_chunks, + const c10::optional& filter_desc, + const c10::optional& decoder, + const c10::optional& decoder_option); void add_video_stream( - int i, - int frames_per_chunk, - int num_chunks, - std::string filter_desc, - const std::string& decoder, - const std::map& decoder_option, - const torch::Device& device); - void remove_stream(int i); + int64_t i, + int64_t frames_per_chunk, + int64_t num_chunks, + const c10::optional& filter_desc, + const c10::optional& decoder, + const c10::optional& decoder_option, + const c10::optional& hw_accel); + void remove_stream(int64_t i); private: void add_stream( @@ -85,9 +81,9 @@ class Streamer { AVMediaType media_type, int frames_per_chunk, int num_chunks, - std::string filter_desc, - const std::string& decoder, - const std::map& decoder_option, + const c10::optional& filter_desc, + const c10::optional& decoder, + const c10::optional& decoder_option, const torch::Device& device); public: diff --git a/torchaudio/io/_stream_reader.py b/torchaudio/io/_stream_reader.py index 927fd8dcdae..942492979eb 100644 --- a/torchaudio/io/_stream_reader.py +++ b/torchaudio/io/_stream_reader.py @@ -154,6 +154,51 @@ def _parse_oi(i): return StreamReaderOutputStream(i[0], i[1]) +def _get_afilter_desc( + sample_rate: Optional[int], + dtype: torch.dtype): + descs = [] + if sample_rate is not None: + descs.append(f"aresample={sample_rate}") + if dtype is not None: + fmt = { + torch.uint8: "u8p", + torch.int16: "s16p", + torch.int32: "s32p", + torch.long: "s64p", + torch.float32: "fltp", + torch.float64: "dblp", + }[dtype] + descs.append(f"aformat=sample_fmts={fmt}") + return ",".join(descs) if descs else None + + +def _get_vfilter_desc( + frame_rate: Optional[float], + width: Optional[int], + height: Optional[int], + format: Optional[str]): + descs = [] + if frame_rate is not None: + descs.append(f"fps={frame_rate}") + scales = [] + if width is not None: + scales.append(f"width={width}") + if height is not None: + scales.append(f"height={height}") + if scales: + descs.append(f"scale={':'.join(scales)}") + if format is not None: + fmt = { + "RGB": "rgb24", + "BGR": "bgr24", + "YUV": "yuv420p", + "GRAY": "gray", + }[format] + descs.append(f"format=pix_fmts={fmt}") + return ",".join(descs) if descs else None + + class StreamReader: """Fetch and decode audio/video streams chunk by chunk. @@ -297,8 +342,14 @@ def add_basic_audio_stream( `[-1, 1]`. """ i = self.default_audio_stream if stream_index is None else stream_index - torch.ops.torchaudio.ffmpeg_streamer_add_basic_audio_stream( - self._s, i, frames_per_chunk, buffer_chunk_size, sample_rate, dtype + torch.ops.torchaudio.ffmpeg_streamer_add_audio_stream( + self._s, + i, + frames_per_chunk, + buffer_chunk_size, + _get_afilter_desc(sample_rate, dtype), + None, + None, ) def add_basic_video_stream( @@ -338,15 +389,15 @@ def add_basic_video_stream( - `GRAY`: 8 bits * 1 channels """ i = self.default_video_stream if stream_index is None else stream_index - torch.ops.torchaudio.ffmpeg_streamer_add_basic_video_stream( + torch.ops.torchaudio.ffmpeg_streamer_add_video_stream( self._s, i, frames_per_chunk, buffer_chunk_size, - frame_rate, - width, - height, - format, + _get_vfilter_desc(frame_rate, width, height, format), + None, + None, + None, ) def add_audio_stream(