Skip to content

Commit

Permalink
Flush and reset internal state after seek (pytorch#2264)
Browse files Browse the repository at this point in the history
Summary:
This commit adds the following behavior to `seek` so that `seek`
works after a frame is decoded.

1. Flush the decoder buffer.
2. Recreate filter graphs (so that internal state is re-initialized)
3. Discard the buffered tensor. (decoded chunks)

Also it disallows negative values for seek timestamp.

Pull Request resolved: pytorch#2264

Reviewed By: carolineechen

Differential Revision: D34497826

Pulled By: mthrok

fbshipit-source-id: 8b9a5bf160dfeb15f5cced3eed2288c33e2eb35d
  • Loading branch information
mthrok authored and xiaohui-zhang committed May 4, 2022
1 parent 275d0a1 commit 09caf6e
Show file tree
Hide file tree
Showing 14 changed files with 113 additions and 15 deletions.
31 changes: 31 additions & 0 deletions test/torchaudio_unittest/prototype/io_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,22 @@ def test_stream_smoke_test(self):
if i >= 40:
break

def test_seek(self):
"""Calling `seek` multiple times should not segfault"""
s = Streamer(get_video_asset())
for i in range(10):
s.seek(i)
for _ in range(0):
s.seek(0)
for i in range(10, 0, -1):
s.seek(i)

def test_seek_negative(self):
"""Calling `seek` with negative value should raise an exception"""
s = Streamer(get_video_asset())
with self.assertRaises(ValueError):
s.seek(-1.0)


@skipIfNoFFmpeg
class StreamerAudioTest(TempDirMixin, TorchaudioTestCase):
Expand Down Expand Up @@ -363,6 +379,21 @@ def test_audio_seek(self, dtype, num_channels):
(output,) = s.pop_chunks()
self.assertEqual(expected, output)

def test_audio_seek_multiple(self):
"""Calling `seek` after streaming is started should change the position properly"""
path, original = self._get_reference_wav(1, dtype="int16", num_channels=2, num_frames=30)

s = Streamer(path)
s.add_audio_stream(frames_per_chunk=-1)

ts = list(range(20)) + list(range(20, 0, -1)) + list(range(20))
for t in ts:
s.seek(float(t))
s.process_all_packets()
(output,) = s.pop_chunks()
expected = original[t:, :]
self.assertEqual(expected, output)

@nested_params(
[
(18, 6, 3), # num_frames is divisible by frames_per_chunk
Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/ffmpeg/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,5 +285,9 @@ torch::Tensor Buffer::pop_all() {
return torch::cat(ret, 0);
}

void Buffer::flush() {
chunks.clear();
}

} // namespace ffmpeg
} // namespace torchaudio
2 changes: 2 additions & 0 deletions torchaudio/csrc/ffmpeg/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class Buffer {

c10::optional<torch::Tensor> pop_chunk();

void flush();

private:
virtual torch::Tensor pop_one_chunk() = 0;
torch::Tensor pop_all();
Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/ffmpeg/decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,9 @@ int Decoder::get_frame(AVFrame* pFrame) {
return avcodec_receive_frame(pCodecContext, pFrame);
}

void Decoder::flush_buffer() {
avcodec_flush_buffers(pCodecContext);
}

} // namespace ffmpeg
} // namespace torchaudio
2 changes: 2 additions & 0 deletions torchaudio/csrc/ffmpeg/decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class Decoder {
int process_packet(AVPacket* pPacket);
// Fetch a decoded frame
int get_frame(AVFrame* pFrame);
// Flush buffer (for seek)
void flush_buffer();
};

} // namespace ffmpeg
Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/ffmpeg/ffmpeg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,5 +196,9 @@ AVFilterGraph* get_filter_graph() {
} // namespace
AVFilterGraphPtr::AVFilterGraphPtr()
: Wrapper<AVFilterGraph, AVFilterGraphDeleter>(get_filter_graph()) {}

void AVFilterGraphPtr::reset() {
ptr.reset(get_filter_graph());
}
} // namespace ffmpeg
} // namespace torchaudio
1 change: 1 addition & 0 deletions torchaudio/csrc/ffmpeg/ffmpeg.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ struct AVFilterGraphDeleter {
};
struct AVFilterGraphPtr : public Wrapper<AVFilterGraph, AVFilterGraphDeleter> {
AVFilterGraphPtr();
void reset();
};
} // namespace ffmpeg
} // namespace torchaudio
35 changes: 23 additions & 12 deletions torchaudio/csrc/ffmpeg/filter_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ FilterGraph::FilterGraph(
AVRational time_base,
AVCodecParameters* codecpar,
std::string filter_description)
: filter_description(filter_description) {
add_src(time_base, codecpar);
add_sink();
add_process();
create_filter();
: input_time_base(time_base),
codecpar(codecpar),
filter_description(std::move(filter_description)),
media_type(codecpar->codec_type) {
init();
}

////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -62,18 +62,29 @@ std::string get_video_src_args(

} // namespace

void FilterGraph::add_src(AVRational time_base, AVCodecParameters* codecpar) {
if (media_type != AVMEDIA_TYPE_UNKNOWN) {
throw std::runtime_error("Source buffer is already allocated.");
}
media_type = codecpar->codec_type;
void FilterGraph::init() {
add_src();
add_sink();
add_process();
create_filter();
}

void FilterGraph::reset() {
pFilterGraph.reset();
buffersrc_ctx = nullptr;
buffersink_ctx = nullptr;

init();
}

void FilterGraph::add_src() {
std::string args;
switch (media_type) {
case AVMEDIA_TYPE_AUDIO:
args = get_audio_src_args(time_base, codecpar);
args = get_audio_src_args(input_time_base, codecpar);
break;
case AVMEDIA_TYPE_VIDEO:
args = get_video_src_args(time_base, codecpar);
args = get_video_src_args(input_time_base, codecpar);
break;
default:
throw std::runtime_error("Only audio/video are supported.");
Expand Down
17 changes: 14 additions & 3 deletions torchaudio/csrc/ffmpeg/filter_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,20 @@ namespace torchaudio {
namespace ffmpeg {

class FilterGraph {
AVMediaType media_type = AVMEDIA_TYPE_UNKNOWN;
// Parameters required for `reset`
// Recreats the underlying filter_graph struct
AVRational input_time_base;
AVCodecParameters* codecpar;
std::string filter_description;

// Constant just for convenient access.
AVMediaType media_type;

AVFilterGraphPtr pFilterGraph;
// AVFilterContext is freed as a part of AVFilterGraph
// so we do not manage the resource.
AVFilterContext* buffersrc_ctx = nullptr;
AVFilterContext* buffersink_ctx = nullptr;
const std::string filter_description;

public:
FilterGraph(
Expand All @@ -35,8 +42,12 @@ class FilterGraph {
//////////////////////////////////////////////////////////////////////////////
// Configuration methods
//////////////////////////////////////////////////////////////////////////////
void init();

void reset();

private:
void add_src(AVRational time_base, AVCodecParameters* codecpar);
void add_src();

void add_sink();

Expand Down
6 changes: 6 additions & 0 deletions torchaudio/csrc/ffmpeg/sink.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,11 @@ int Sink::process_frame(AVFrame* pFrame) {
bool Sink::is_buffer_ready() const {
return buffer->is_ready();
}

void Sink::flush() {
filter.reset();
buffer->flush();
}

} // namespace ffmpeg
} // namespace torchaudio
2 changes: 2 additions & 0 deletions torchaudio/csrc/ffmpeg/sink.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class Sink {

int process_frame(AVFrame* frame);
bool is_buffer_ready() const;

void flush();
};

} // namespace ffmpeg
Expand Down
7 changes: 7 additions & 0 deletions torchaudio/csrc/ffmpeg/stream_processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ int StreamProcessor::process_packet(AVPacket* packet) {
return ret;
}

void StreamProcessor::flush() {
decoder.flush_buffer();
for (auto& ite : sinks) {
ite.second.flush();
}
}

// 0: some kind of success
// <0: Some error happened
int StreamProcessor::send_frame(AVFrame* pFrame) {
Expand Down
4 changes: 4 additions & 0 deletions torchaudio/csrc/ffmpeg/stream_processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ class StreamProcessor {
// - Sending NULL will drain (flush) the internal
int process_packet(AVPacket* packet);

// flush the internal buffer of decoder.
// To be use when seeking
void flush();

private:
int send_frame(AVFrame* pFrame);

Expand Down
9 changes: 9 additions & 0 deletions torchaudio/csrc/ffmpeg/streamer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,20 @@ bool Streamer::is_buffer_ready() const {
// Configure methods
////////////////////////////////////////////////////////////////////////////////
void Streamer::seek(double timestamp) {
if (timestamp < 0) {
throw std::invalid_argument("timestamp must be non-negative.");
}

int64_t ts = static_cast<int64_t>(timestamp * AV_TIME_BASE);
int ret = avformat_seek_file(pFormatContext, -1, INT64_MIN, ts, INT64_MAX, 0);
if (ret < 0) {
throw std::runtime_error("Failed to seek. (" + av_err2string(ret) + ".)");
}
for (const auto& it : processors) {
if (it) {
it->flush();
}
}
}

void Streamer::add_audio_stream(
Expand Down

0 comments on commit 09caf6e

Please sign in to comment.