Skip to content

Commit

Permalink
Refactor Streamer implementation
Browse files Browse the repository at this point in the history
* Move the helper wrapping code in TorchBind layer to proper wrapper class for so that it will be re-used in PyBind11.
* Move `add_basic_[audio|video]_stream` methods from C++ to Python, as they are just string manipulation. This will make PyBind11-based binding simpler as it needs not to deal with dtype.
* Move `add_[audio|video]_stream` wrapper signature to Streamer core, so that Streamer directly deals with `c10::optional` and `c10::Dic`. This reduces the code and gets rid of intermediate `std::map` structure.†

† Related to this, there is a slight change in how the empty filter expression is stored. Originally, if an empty filter expression was given to `add_[audio|video]_stream` method, the `StreamReaderOutputStream` was showing it as empty string `""`, even though internally it was using `"anull"` or `"null"`. Now `StreamReaderOutputStream` shows the corresponding filter expression that is actually being used.
  • Loading branch information
mthrok committed May 18, 2022
1 parent c6a376c commit 08ead80
Show file tree
Hide file tree
Showing 18 changed files with 372 additions and 475 deletions.
4 changes: 2 additions & 2 deletions test/torchaudio_unittest/io/stream_reader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def test_basic_audio_stream(self):

sinfo = s.get_out_stream_info(0)
assert sinfo.source_index == s.default_audio_stream
assert sinfo.filter_description == ""
assert sinfo.filter_description == "anull"

sinfo = s.get_out_stream_info(1)
assert sinfo.source_index == s.default_audio_stream
Expand All @@ -185,7 +185,7 @@ def test_basic_video_stream(self):

sinfo = s.get_out_stream_info(0)
assert sinfo.source_index == s.default_video_stream
assert sinfo.filter_description == ""
assert sinfo.filter_description == "null"

sinfo = s.get_out_stream_info(1)
assert sinfo.source_index == s.default_video_stream
Expand Down
1 change: 1 addition & 0 deletions torchaudio/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ if(USE_FFMPEG)
ffmpeg/sink.cpp
ffmpeg/stream_processor.cpp
ffmpeg/streamer.cpp
ffmpeg/stream_reader_wrapper.cpp
)
message(STATUS "FFMPEG_ROOT=$ENV{FFMPEG_ROOT}")
find_package(FFMPEG 4.1 REQUIRED COMPONENTS avdevice avfilter avformat avcodec avutil)
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/ffmpeg/decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ namespace ffmpeg {
////////////////////////////////////////////////////////////////////////////////
Decoder::Decoder(
AVCodecParameters* pParam,
const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option,
const c10::optional<std::string>& decoder_name,
const c10::optional<OptionDict>& decoder_option,
const torch::Device& device)
: pCodecContext(get_decode_context(pParam->codec_id, decoder_name)) {
init_codec_context(
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/ffmpeg/decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class Decoder {
// Default constructable
Decoder(
AVCodecParameters* pParam,
const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option,
const c10::optional<std::string>& decoder_name,
const c10::optional<OptionDict>& decoder_option,
const torch::Device& device);
// Custom destructor to clean up the resources
~Decoder() = default;
Expand Down
46 changes: 30 additions & 16 deletions torchaudio/csrc/ffmpeg/ffmpeg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ void AVFormatContextDeleter::operator()(AVFormatContext* p) {

namespace {

AVDictionary* get_option_dict(
const std::map<std::string, std::string>& option) {
AVDictionary* get_option_dict(const c10::optional<OptionDict>& option) {
AVDictionary* opt = nullptr;
for (auto& it : option) {
av_dict_set(&opt, it.first.c_str(), it.second.c_str(), 0);
if (option) {
for (auto& it : option.value()) {
av_dict_set(&opt, it.key().c_str(), it.value().c_str(), 0);
}
}
return opt;
}
Expand Down Expand Up @@ -66,12 +67,25 @@ std::string join(std::vector<std::string> vars) {

AVFormatContextPtr get_input_format_context(
const std::string& src,
const std::string& device,
const std::map<std::string, std::string>& option) {
const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option) {
AVFormatContext* pFormat = NULL;

AVINPUT_FORMAT_CONST AVInputFormat* pInput =
device.empty() ? NULL : av_find_input_format(device.c_str());
AVINPUT_FORMAT_CONST AVInputFormat* pInput = [&]() -> AVInputFormat* {
if (device.has_value()) {
std::string device_str = device.value();
AVINPUT_FORMAT_CONST AVInputFormat* p =
av_find_input_format(device_str.c_str());
if (!p) {
std::ostringstream msg;
msg << "Unsupported device: \"" << device_str << "\"";
throw std::runtime_error(msg.str());
}
return p;
}
return nullptr;
}();

AVDictionary* opt = get_option_dict(option);
int ret = avformat_open_input(&pFormat, src.c_str(), pInput, &opt);

Expand Down Expand Up @@ -148,18 +162,18 @@ void AVCodecContextDeleter::operator()(AVCodecContext* p) {
namespace {
const AVCodec* get_decode_codec(
enum AVCodecID codec_id,
const std::string& decoder_name) {
const AVCodec* pCodec = decoder_name.empty()
const c10::optional<std::string>& decoder_name) {
const AVCodec* pCodec = !decoder_name.has_value()
? avcodec_find_decoder(codec_id)
: avcodec_find_decoder_by_name(decoder_name.c_str());
: avcodec_find_decoder_by_name(decoder_name.value().c_str());

if (!pCodec) {
std::stringstream ss;
if (decoder_name.empty()) {
if (!decoder_name.has_value()) {
ss << "Unsupported codec: \"" << avcodec_get_name(codec_id) << "\", ("
<< codec_id << ").";
} else {
ss << "Unsupported codec: \"" << decoder_name << "\".";
ss << "Unsupported codec: \"" << decoder_name.value() << "\".";
}
throw std::runtime_error(ss.str());
}
Expand All @@ -170,7 +184,7 @@ const AVCodec* get_decode_codec(

AVCodecContextPtr get_decode_context(
enum AVCodecID codec_id,
const std::string& decoder_name) {
const c10::optional<std::string>& decoder_name) {
const AVCodec* pCodec = get_decode_codec(codec_id, decoder_name);

AVCodecContext* pCodecContext = avcodec_alloc_context3(pCodec);
Expand Down Expand Up @@ -216,8 +230,8 @@ const AVCodecHWConfig* get_cuda_config(const AVCodec* pCodec) {
void init_codec_context(
AVCodecContext* pCodecContext,
AVCodecParameters* pParams,
const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option,
const c10::optional<std::string>& decoder_name,
const c10::optional<OptionDict>& decoder_option,
const torch::Device& device,
AVBufferRefPtr& pHWBufferRef) {
const AVCodec* pCodec = get_decode_codec(pParams->codec_id, decoder_name);
Expand Down
12 changes: 7 additions & 5 deletions torchaudio/csrc/ffmpeg/ffmpeg.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ extern "C" {
namespace torchaudio {
namespace ffmpeg {

using OptionDict = c10::Dict<std::string, std::string>;

// Replacement of av_err2str, which causes
// `error: taking address of temporary array`
// https://github.com/joncampbell123/composite-video-simulator/issues/5
Expand Down Expand Up @@ -71,8 +73,8 @@ struct AVFormatContextPtr
// create format context for reading media
AVFormatContextPtr get_input_format_context(
const std::string& src,
const std::string& device,
const std::map<std::string, std::string>& option);
const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option);

////////////////////////////////////////////////////////////////////////////////
// AVPacket
Expand Down Expand Up @@ -141,14 +143,14 @@ struct AVCodecContextPtr
// Allocate codec context from either decoder name or ID
AVCodecContextPtr get_decode_context(
enum AVCodecID codec_id,
const std::string& decoder);
const c10::optional<std::string>& decoder);

// Initialize codec context with the parameters
void init_codec_context(
AVCodecContext* pCodecContext,
AVCodecParameters* pParams,
const std::string& decoder_name,
const std::map<std::string, std::string>& decoder_option,
const c10::optional<std::string>& decoder_name,
const c10::optional<OptionDict>& decoder_option,
const torch::Device& device,
AVBufferRefPtr& pHWBufferRef);

Expand Down
19 changes: 8 additions & 11 deletions torchaudio/csrc/ffmpeg/filter_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ namespace ffmpeg {
FilterGraph::FilterGraph(
AVRational time_base,
AVCodecParameters* codecpar,
std::string filter_description)
const c10::optional<std::string>& filter_description)
: input_time_base(time_base),
codecpar(codecpar),
filter_description(std::move(filter_description)),
filter_description(filter_description.value_or(
codecpar->codec_type == AVMEDIA_TYPE_AUDIO ? "anull" : "null")),
media_type(codecpar->codec_type) {
init();
}
Expand Down Expand Up @@ -49,10 +50,10 @@ std::string get_video_src_args(
std::snprintf(
args,
sizeof(args),
"video_size=%dx%d:pix_fmt=%d:time_base=%d/%d:pixel_aspect=%d/%d",
"video_size=%dx%d:pix_fmt=%s:time_base=%d/%d:pixel_aspect=%d/%d",
codecpar->width,
codecpar->height,
static_cast<AVPixelFormat>(codecpar->format),
av_get_pix_fmt_name(static_cast<AVPixelFormat>(codecpar->format)),
time_base.num,
time_base.den,
codecpar->sample_aspect_ratio.num,
Expand Down Expand Up @@ -165,16 +166,12 @@ void FilterGraph::add_process() {
// If you are debugging this part of the code, you might get confused.
InOuts in{"in", buffersrc_ctx}, out{"out", buffersink_ctx};

std::string desc = filter_description.empty()
? (media_type == AVMEDIA_TYPE_AUDIO) ? "anull" : "null"
: filter_description;

int ret =
avfilter_graph_parse_ptr(pFilterGraph, desc.c_str(), out, in, nullptr);
int ret = avfilter_graph_parse_ptr(
pFilterGraph, filter_description.c_str(), out, in, nullptr);

if (ret < 0) {
throw std::runtime_error(
"Failed to create the filter from \"" + desc + "\" (" +
"Failed to create the filter from \"" + filter_description + "\" (" +
av_err2string(ret) + ".)");
}
}
Expand Down
2 changes: 1 addition & 1 deletion torchaudio/csrc/ffmpeg/filter_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class FilterGraph {
FilterGraph(
AVRational time_base,
AVCodecParameters* codecpar,
std::string filter_desc);
const c10::optional<std::string>& filter_desc);
// Custom destructor to release AVFilterGraph*
~FilterGraph() = default;
// Non-copyable
Expand Down
Loading

0 comments on commit 08ead80

Please sign in to comment.