Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed May 18, 2022
1 parent 4cccfc8 commit ab4394b
Show file tree
Hide file tree
Showing 13 changed files with 482 additions and 26 deletions.
14 changes: 9 additions & 5 deletions test/torchaudio_unittest/io/stream_reader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
_TEST_FILEOBJ = "src_is_fileobj"

def _class_name(cls, _, params):
return f'{cls.__name__}{"_fileobj" if params[_TEST_FILEOBJ] else ""}'
return f'{cls.__name__}{"_fileobj" if params[_TEST_FILEOBJ] else "_path"}'


_media_source = parameterized_class(
Expand All @@ -42,12 +42,16 @@ def setUp(self):
super().setUp()
self.src = None

@property
def test_fileobj(self):
return getattr(self, _TEST_FILEOBJ)

def get_video_asset(self, file="nasa_13013.mp4"):
if self.src is not None:
raise ValueError("get_video_asset can be called only once.")

path = get_asset_path(file)
if getattr(self, _TEST_FILEOBJ):
if self.test_fileobj:
self.src = open(path, "rb")
return self.src
return path
Expand Down Expand Up @@ -145,8 +149,8 @@ def test_src_info(self):
bit_rate=None,
),
]
for i, exp in enumerate(expected):
assert exp == s.get_src_stream_info(i)
output = [s.get_src_stream_info(i) for i in range(6)]
assert expected == output

def test_src_info_invalid_index(self):
"""`get_src_stream_info` does not segfault but raise an exception when input is invalid"""
Expand Down Expand Up @@ -326,7 +330,7 @@ def test_stream_smoke_test(self):
w, h = 256, 198
s = StreamReader(self.get_video_asset())
s.add_basic_audio_stream(frames_per_chunk=2000, sample_rate=8000)
s.add_basic_video_stream(frames_per_chunk=15, frame_rate=60, width=w, height=h)
s.add_basic_video_stream(frames_per_chunk=15, frame_rate=60, width=w, height=h, format="YUV" if self.test_fileobj else None)
for i, (achunk, vchunk) in enumerate(s.stream()):
assert achunk.shape == torch.Size([2000, 2])
assert vchunk.shape == torch.Size([15, 3, h, w])
Expand Down
7 changes: 6 additions & 1 deletion tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,12 @@ def get_ext_modules():
]
)
if _USE_FFMPEG:
modules.append(Extension(name="torchaudio.lib.libtorchaudio_ffmpeg", sources=[]))
modules.extend(
[
Extension(name="torchaudio.lib.libtorchaudio_ffmpeg", sources=[]),
Extension(name="torchaudio._torchaudio_ffmpeg", sources=[]),
]
)
return modules


Expand Down
20 changes: 18 additions & 2 deletions torchaudio/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,11 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development)
set(ADDITIONAL_ITEMS Python3::Python)
endif()
function(define_extension name sources libraries definitions)
function(define_extension name sources include_dirs libraries definitions)
add_library(${name} SHARED ${sources})
target_compile_definitions(${name} PRIVATE "${definitions}")
target_include_directories(
${name} PRIVATE ${PROJECT_SOURCE_DIR} ${Python_INCLUDE_DIR})
${name} PRIVATE ${PROJECT_SOURCE_DIR} ${Python_INCLUDE_DIR} ${include_dirs})
target_link_libraries(
${name}
${libraries}
Expand Down Expand Up @@ -254,6 +254,7 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
define_extension(
_torchaudio
"${EXTENSION_SOURCES}"
""
libtorchaudio
"${LIBTORCHAUDIO_COMPILE_DEFINITIONS}"
)
Expand All @@ -265,8 +266,23 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
define_extension(
_torchaudio_decoder
"${DECODER_EXTENSION_SOURCES}"
""
"libtorchaudio_decoder"
"${LIBTORCHAUDIO_DECODER_DEFINITIONS}"
)
endif()
if(USE_FFMPEG)
set(
FFMPEG_EXTENSION_SOURCES
ffmpeg/pybind/pybind.cpp
ffmpeg/pybind/stream_reader.cpp
)
define_extension(
_torchaudio_ffmpeg
"${FFMPEG_EXTENSION_SOURCES}"
"${FFMPEG_INCLUDE_DIRS}"
"libtorchaudio_ffmpeg"
"${LIBTORCHAUDIO_DECODER_DEFINITIONS}"
)
endif()
endif()
22 changes: 20 additions & 2 deletions torchaudio/csrc/ffmpeg/ffmpeg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,15 @@ std::string join(std::vector<std::string> vars) {
AVFormatContextPtr get_input_format_context(
const std::string& src,
const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option) {
AVFormatContext* pFormat = NULL;
const c10::optional<OptionDict>& option,
AVIOContext* io_ctx) {
AVFormatContext* pFormat = avformat_alloc_context();
if (!pFormat) {
throw std::runtime_error("Failed to allocate AVFormatContext.");
}
if (io_ctx) {
pFormat->pb = io_ctx;
}

AVINPUT_FORMAT_CONST AVInputFormat* pInput = [&]() -> AVInputFormat* {
if (device.has_value()) {
Expand Down Expand Up @@ -105,6 +112,17 @@ AVFormatContextPtr get_input_format_context(
AVFormatContextPtr::AVFormatContextPtr(AVFormatContext* p)
: Wrapper<AVFormatContext, AVFormatContextDeleter>(p) {}

////////////////////////////////////////////////////////////////////////////////
// AVIO
////////////////////////////////////////////////////////////////////////////////
void AVIOContextDeleter::operator()(AVIOContext* p) {
av_freep(&p->buffer);
av_freep(&p);
};

AVIOContextPtr::AVIOContextPtr(AVIOContext* p)
: Wrapper<AVIOContext, AVIOContextDeleter>(p) {}

////////////////////////////////////////////////////////////////////////////////
// AVPacket
////////////////////////////////////////////////////////////////////////////////
Expand Down
15 changes: 14 additions & 1 deletion torchaudio/csrc/ffmpeg/ffmpeg.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ extern "C" {
#include <libavfilter/buffersink.h>
#include <libavfilter/buffersrc.h>
#include <libavformat/avformat.h>
#include <libavformat/avio.h>
#include <libavutil/avutil.h>
#include <libavutil/frame.h>
#include <libavutil/imgutils.h>
Expand Down Expand Up @@ -74,7 +75,19 @@ struct AVFormatContextPtr
AVFormatContextPtr get_input_format_context(
const std::string& src,
const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option);
const c10::optional<OptionDict>& option,
AVIOContext* io_ctx = nullptr);

////////////////////////////////////////////////////////////////////////////////
// AVIO
////////////////////////////////////////////////////////////////////////////////
struct AVIOContextDeleter {
void operator()(AVIOContext* p);
};

struct AVIOContextPtr : public Wrapper<AVIOContext, AVIOContextDeleter> {
explicit AVIOContextPtr(AVIOContext* p);
};

////////////////////////////////////////////////////////////////////////////////
// AVPacket
Expand Down
10 changes: 7 additions & 3 deletions torchaudio/csrc/ffmpeg/prototype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ std::tuple<c10::optional<torch::Tensor>, int64_t> load(const std::string& src) {
int i = s.find_best_audio_stream();
auto sinfo = s.Streamer::get_src_stream_info(i);
int64_t sample_rate = static_cast<int64_t>(sinfo.sample_rate);
s.add_audio_stream(i, -1, -1, {}, {}, {});
s.add_audio_stream(i, -1, -1, {}, {}, {}, {});
s.process_all_packets();
auto tensors = s.pop_chunks();
return std::make_tuple<>(tensors[0], sample_rate);
Expand Down Expand Up @@ -66,14 +66,16 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<OptionDict>& decoder_options) {
const c10::optional<OptionDict>& decoder_options,
const c10::optional<std::string>& src_format) {
s->add_audio_stream(
i,
frames_per_chunk,
num_chunks,
filter_desc,
decoder,
decoder_options);
decoder_options,
src_format);
});
m.def(
"torchaudio::ffmpeg_streamer_add_video_stream",
Expand All @@ -84,6 +86,7 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<OptionDict>& decoder_options,
const c10::optional<std::string>& src_format,
const c10::optional<std::string>& hw_accel) {
s->add_video_stream(
i,
Expand All @@ -92,6 +95,7 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
filter_desc,
decoder,
decoder_options,
src_format,
hw_accel);
});
m.def("torchaudio::ffmpeg_streamer_remove_stream", [](S s, int64_t i) {
Expand Down
35 changes: 35 additions & 0 deletions torchaudio/csrc/ffmpeg/pybind/pybind.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_reader.h>

namespace torchaudio {
namespace ffmpeg {
namespace {

PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
py::class_<StreamReaderFileObj, c10::intrusive_ptr<StreamReaderFileObj>>(
m, "StreamReaderFileObj")
.def(py::init<py::object, py::object, py::object, int64_t>())
.def("num_src_streams", &StreamReaderFileObj::num_src_streams)
.def("num_out_streams", &StreamReaderFileObj::num_out_streams)
.def(
"find_best_audio_stream",
&StreamReaderFileObj::find_best_audio_stream)
.def(
"find_best_video_stream",
&StreamReaderFileObj::find_best_video_stream)
.def("get_src_stream_info", &StreamReaderFileObj::get_src_stream_info)
.def("get_out_stream_info", &StreamReaderFileObj::get_out_stream_info)
.def("seek", &StreamReaderFileObj::seek)
.def("add_audio_stream", &StreamReaderFileObj::add_audio_stream)
.def("add_video_stream", &StreamReaderFileObj::add_video_stream)
.def("remove_stream", &StreamReaderFileObj::remove_stream)
.def("process_packet", &StreamReaderFileObj::process_packet)
.def("process_all_packets", &StreamReaderFileObj::process_all_packets)
.def("is_buffer_ready", &StreamReaderFileObj::is_buffer_ready)
.def("pop_chunks", &StreamReaderFileObj::pop_chunks);
}

} // namespace
} // namespace ffmpeg
} // namespace torchaudio
108 changes: 108 additions & 0 deletions torchaudio/csrc/ffmpeg/pybind/stream_reader.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_reader.h>

namespace torchaudio {
namespace ffmpeg {
namespace {

static int read_function(void* opaque, uint8_t* buf, int buf_size) {
FileObj* fileobj = static_cast<FileObj*>(opaque);
buf_size = FFMIN(buf_size, fileobj->buffer_size);

int num_read = 0;
while (num_read < buf_size) {
int request = buf_size - num_read;
auto chunk = static_cast<std::string>(
static_cast<py::bytes>(fileobj->fileobj.attr("read")(request)));
auto chunk_len = chunk.length();
if (chunk_len == 0) {
break;
}
if (chunk_len > request) {
std::ostringstream message;
message
<< "Requested up to " << request << " bytes but, "
<< "received " << chunk_len << " bytes. "
<< "The given object does not confirm to read protocol of file object.";
throw std::runtime_error(message.str());
}
memcpy(buf, chunk.data(), chunk_len);
buf += chunk_len;
num_read += chunk_len;
}
return num_read == 0 ? AVERROR_EOF : num_read;
}

static int64_t seek_function(void* opaque, int64_t offset, int whence) {
// We do not know the file size.
if (whence == AVSEEK_SIZE) {
return AVERROR(EIO);
}
FileObj* fileobj = static_cast<FileObj*>(opaque);
return py::cast<int64_t>(fileobj->fileobj.attr("seek")(offset, whence));
}

AVIOContextPtr get_io_context(FileObj* opaque, int buffer_size) {
uint8_t* buffer = static_cast<uint8_t*>(av_malloc(buffer_size));
if (!buffer) {
throw std::runtime_error("Failed to allocate buffer.");
}

// If avio_alloc_context succeeds, then buffer will be cleaned up by
// AVIOContextPtr destructor.
// If avio_alloc_context fails, we need to clean up by ourselves.
AVIOContext* av_io_ctx = avio_alloc_context(
buffer,
buffer_size,
0,
static_cast<void*>(opaque),
&read_function,
nullptr,
py::hasattr(opaque->fileobj, "seek") ? &seek_function : nullptr);

if (!av_io_ctx) {
av_freep(&buffer);
throw std::runtime_error("Failed to allocate AVIO context.");
}
return AVIOContextPtr{av_io_ctx};
}

c10::optional<OptionDict> convert_dict(py::object dict) {
if (dict.is_none())
return c10::optional<OptionDict>{};

c10::Dict<std::string, std::string> out;
for (std::pair<py::handle, py::handle> item : py::cast<py::dict>(dict)) {
out.insert(item.first.cast<std::string>(), item.second.cast<std::string>());
}
return c10::optional<OptionDict>(out);
}

c10::optional<std::string> convert_str(py::object s) {
if (s.is_none())
return c10::optional<std::string>{};
return c10::optional<std::string>{
static_cast<std::string>(py::cast<py::str>(s))};
}

} // namespace

FileObj::FileObj(py::object fileobj_, int buffer_size)
: fileobj(fileobj_),
buffer_size(buffer_size),
pAVIO(get_io_context(this, buffer_size)) {}

StreamReaderFileObj::StreamReaderFileObj(
py::object fileobj_,
py::object format,
py::object option,
int64_t buffer_size)
: FileObj(fileobj_, static_cast<int>(buffer_size)),
StreamReaderBinding(get_input_format_context(
"",
convert_str(format),
convert_dict(option),
pAVIO)) {}

} // namespace ffmpeg
} // namespace torchaudio
33 changes: 33 additions & 0 deletions torchaudio/csrc/ffmpeg/pybind/stream_reader.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#pragma once
#include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/stream_reader_wrapper.h>

namespace torchaudio {
namespace ffmpeg {

// The purpose of FileObj class is so that
// FileObjStreamReader class can inherit Streamer while
// AVIOContext is initialized before AVFormat (Streamer)
struct FileObj {
py::object fileobj;
int buffer_size;
AVIOContextPtr pAVIO;
FileObj(py::object fileobj, int buffer_size);
};

struct StreamReaderFileObj : public FileObj, public StreamReaderBinding {
public:
StreamReaderFileObj(
py::object fileobj,
// Note:
// Should use `py::str` or `c10::optional<std::string>` for `format`,
// and `py::dict` or `c10::optional<OptionDict>` for `option`, but
// I could not resolve TypeError related to optionality.
// So using generic `py::object`.
py::object format,
py::object option,
int64_t buffer_size);
};

} // namespace ffmpeg
} // namespace torchaudio
Loading

0 comments on commit ab4394b

Please sign in to comment.