Skip to content

Commit

Permalink
Restructure the video/video_reader C++ codebase (#3311)
Browse files Browse the repository at this point in the history
* Moving registration of video methods in Video.cpp and removing unnecessary includes.

* Rename files according to cpp styles.

* Adding namespaces and moving private methods to anonymous namespaces.

* Syncing method names.

* Fixing minor issues.
  • Loading branch information
datumbox authored Jan 28, 2021
1 parent 691ec6d commit e95a3d2
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 149 deletions.
14 changes: 0 additions & 14 deletions torchvision/csrc/io/video/register.cpp

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
#include "Video.h"
#include <c10/util/Logging.h>
#include <torch/script.h>
#include "defs.h"
#include "memory_buffer.h"
#include "sync_decoder.h"
#include "video.h"

using namespace std;
using namespace ffmpeg;
#include <regex>

namespace vision {
namespace video {

namespace {

// Timeout value in milliseconds, per the name (600000 ms = 10 minutes).
// NOTE(review): the consumer of this constant is not visible in this hunk —
// presumably it is copied into the decoder parameters; confirm.
const size_t decoderTimeoutMs = 600000;
// Default pixel format for video in this file: AV_PIX_FMT_RGB24
// (FFmpeg's packed 8-bit-per-channel RGB).
const AVPixelFormat defaultVideoPixelFormat = AV_PIX_FMT_RGB24;
Expand Down Expand Up @@ -93,6 +92,8 @@ std::tuple<std::string, long> _parseStream(const std::string& streamString) {
return std::make_tuple(type_, index_);
}

} // namespace

void Video::_getDecoderParams(
double videoStartS,
int64_t getPtsOnly,
Expand Down Expand Up @@ -159,7 +160,7 @@ Video::Video(std::string videoPath, std::string stream) {
Video::_getDecoderParams(
0, // video start
0, // headerOnly
get<0>(current_stream), // stream info - remove that
std::get<0>(current_stream), // stream info - remove that
long(-1), // stream_id parsed from info above change to -2
true // read all streams
);
Expand Down Expand Up @@ -209,9 +210,9 @@ Video::Video(std::string videoPath, std::string stream) {

succeeded = Video::setCurrentStream(stream);
LOG(INFO) << "\nDecoder inited with: " << succeeded << "\n";
if (get<1>(current_stream) != -1) {
if (std::get<1>(current_stream) != -1) {
LOG(INFO)
<< "Stream index set to " << get<1>(current_stream)
<< "Stream index set to " << std::get<1>(current_stream)
<< ". If you encounter trouble, consider switching it to automatic stream discovery. \n";
}
} // video
Expand All @@ -229,8 +230,8 @@ bool Video::setCurrentStream(std::string stream = "video") {
_getDecoderParams(
ts, // video start
0, // headerOnly
get<0>(current_stream), // stream
long(get<1>(
std::get<0>(current_stream), // stream
long(std::get<1>(
current_stream)), // stream_id parsed from info above change to -2
false // read all streams
);
Expand All @@ -253,8 +254,8 @@ void Video::Seek(double ts) {
_getDecoderParams(
ts, // video start
0, // headerOnly
get<0>(current_stream), // stream
long(get<1>(
std::get<0>(current_stream), // stream
long(std::get<1>(
current_stream)), // stream_id parsed from info above change to -2
false // read all streams
);
Expand Down Expand Up @@ -319,3 +320,15 @@ std::tuple<torch::Tensor, double> Video::Next() {

return std::make_tuple(outFrame, frame_pts_s);
}

// Registers Video as a custom class through torch::class_ under the
// qualified name ("torchvision", "Video"). The registration happens as a
// side effect of initializing this file-static object.
//
// Exposed surface:
//   - constructor taking two std::string arguments
//   - get_current_stream -> Video::getCurrentStream
//   - set_current_stream -> Video::setCurrentStream
//   - get_metadata       -> Video::getStreamMetadata
//   - seek               -> Video::Seek
//   - next               -> Video::Next
static auto registerVideo =
torch::class_<Video>("torchvision", "Video")
.def(torch::init<std::string, std::string>())
.def("get_current_stream", &Video::getCurrentStream)
.def("set_current_stream", &Video::setCurrentStream)
.def("get_metadata", &Video::getStreamMetadata)
.def("seek", &Video::Seek)
.def("next", &Video::Next);

} // namespace video
} // namespace vision
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
#pragma once

#include <map>
#include <regex>
#include <string>
#include <vector>
#include <torch/types.h>

#include <ATen/ATen.h>
#include <c10/util/Logging.h>
#include <torch/script.h>

#include <exception>
#include "defs.h"
#include "memory_buffer.h"
#include "sync_decoder.h"
#include "../decoder/defs.h"
#include "../decoder/memory_buffer.h"
#include "../decoder/sync_decoder.h"

using namespace ffmpeg;

namespace vision {
namespace video {

struct Video : torch::CustomClassHolder {
std::tuple<std::string, long> current_stream; // stream type, id
// global video metadata
Expand Down Expand Up @@ -58,3 +53,6 @@ struct Video : torch::CustomClassHolder {
DecoderParameters params;

}; // struct Video

} // namespace video
} // namespace vision
3 changes: 0 additions & 3 deletions torchvision/csrc/io/video_reader/VideoReader.h

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
#include "VideoReader.h"
#include <ATen/ATen.h>
#include "video_reader.h"

#include <Python.h>
#include <c10/util/Logging.h>
#include <exception>
#include "memory_buffer.h"
#include "sync_decoder.h"

using namespace std;
using namespace ffmpeg;
#include "../decoder/memory_buffer.h"
#include "../decoder/sync_decoder.h"

// If we are in a Windows environment, we need to define
// initialization functions for the _custom_ops extension
Expand All @@ -18,8 +14,13 @@ PyMODINIT_FUNC PyInit_video_reader(void) {
}
#endif

using namespace ffmpeg;

namespace vision {
namespace video_reader {

namespace {

// Default pixel format for video in this file: AV_PIX_FMT_RGB24
// (FFmpeg's packed 8-bit-per-channel RGB).
const AVPixelFormat defaultVideoPixelFormat = AV_PIX_FMT_RGB24;
// Default audio sample format: AV_SAMPLE_FMT_FLT (FFmpeg's float samples).
const AVSampleFormat defaultAudioSampleFormat = AV_SAMPLE_FMT_FLT;
// The rational 1 / AV_TIME_BASE, i.e. FFmpeg's AV_TIME_BASE expressed as an
// AVRational.
const AVRational timeBaseQ = AVRational{1, AV_TIME_BASE};
Expand Down Expand Up @@ -417,95 +418,6 @@ torch::List<torch::Tensor> readVideo(
return result;
}

// Reads a video supplied as a tensor.
//
// Thin wrapper over readVideo(): the leading flag is fixed to false, the
// video path is the empty string, and every remaining argument is forwarded
// unchanged and in the same order. Returns whatever readVideo() returns.
torch::List<torch::Tensor> readVideoFromMemory(
    torch::Tensor input_video,
    double seekFrameMargin,
    int64_t getPtsOnly,
    int64_t readVideoStream,
    int64_t width,
    int64_t height,
    int64_t minDimension,
    int64_t maxDimension,
    int64_t videoStartPts,
    int64_t videoEndPts,
    int64_t videoTimeBaseNum,
    int64_t videoTimeBaseDen,
    int64_t readAudioStream,
    int64_t audioSamples,
    int64_t audioChannels,
    int64_t audioStartPts,
    int64_t audioEndPts,
    int64_t audioTimeBaseNum,
    int64_t audioTimeBaseDen) {
  // This entry point takes its input from the tensor, not from a file.
  const bool readFromFile = false;

  auto decoded = readVideo(
      readFromFile,
      input_video,
      "", // videoPath
      // generic / video options
      seekFrameMargin,
      getPtsOnly,
      readVideoStream,
      width,
      height,
      minDimension,
      maxDimension,
      videoStartPts,
      videoEndPts,
      videoTimeBaseNum,
      videoTimeBaseDen,
      // audio options
      readAudioStream,
      audioSamples,
      audioChannels,
      audioStartPts,
      audioEndPts,
      audioTimeBaseNum,
      audioTimeBaseDen);
  return decoded;
}

// Reads a video identified by a file path.
//
// Thin wrapper over readVideo(): the leading flag is fixed to true, a
// zero-length placeholder tensor fills the tensor-input slot, and every
// remaining argument is forwarded unchanged and in the same order. Returns
// whatever readVideo() returns.
torch::List<torch::Tensor> readVideoFromFile(
    std::string videoPath,
    double seekFrameMargin,
    int64_t getPtsOnly,
    int64_t readVideoStream,
    int64_t width,
    int64_t height,
    int64_t minDimension,
    int64_t maxDimension,
    int64_t videoStartPts,
    int64_t videoEndPts,
    int64_t videoTimeBaseNum,
    int64_t videoTimeBaseDen,
    int64_t readAudioStream,
    int64_t audioSamples,
    int64_t audioChannels,
    int64_t audioStartPts,
    int64_t audioEndPts,
    int64_t audioTimeBaseNum,
    int64_t audioTimeBaseDen) {
  // This entry point takes its input from the path, not from a tensor.
  const bool readFromFile = true;
  // Zero-element tensor standing in for the tensor input.
  torch::Tensor placeholderVideo = torch::ones({0});

  auto decoded = readVideo(
      readFromFile,
      placeholderVideo,
      videoPath,
      // generic / video options
      seekFrameMargin,
      getPtsOnly,
      readVideoStream,
      width,
      height,
      minDimension,
      maxDimension,
      videoStartPts,
      videoEndPts,
      videoTimeBaseNum,
      videoTimeBaseDen,
      // audio options
      readAudioStream,
      audioSamples,
      audioChannels,
      audioStartPts,
      audioEndPts,
      audioTimeBaseNum,
      audioTimeBaseDen);
  return decoded;
}

torch::List<torch::Tensor> probeVideo(
bool isReadFile,
const torch::Tensor& input_video,
Expand Down Expand Up @@ -650,20 +562,112 @@ torch::List<torch::Tensor> probeVideo(
return result;
}

torch::List<torch::Tensor> probeVideoFromMemory(torch::Tensor input_video) {
} // namespace

// Reads a video supplied as a tensor.
//
// Thin wrapper over readVideo(): the leading flag is fixed to false, the
// video path is the empty string, and every remaining argument is forwarded
// unchanged and in the same order. Returns whatever readVideo() returns.
torch::List<torch::Tensor> read_video_from_memory(
    torch::Tensor input_video,
    double seekFrameMargin,
    int64_t getPtsOnly,
    int64_t readVideoStream,
    int64_t width,
    int64_t height,
    int64_t minDimension,
    int64_t maxDimension,
    int64_t videoStartPts,
    int64_t videoEndPts,
    int64_t videoTimeBaseNum,
    int64_t videoTimeBaseDen,
    int64_t readAudioStream,
    int64_t audioSamples,
    int64_t audioChannels,
    int64_t audioStartPts,
    int64_t audioEndPts,
    int64_t audioTimeBaseNum,
    int64_t audioTimeBaseDen) {
  // This entry point takes its input from the tensor, not from a file.
  const bool readFromFile = false;

  auto decoded = readVideo(
      readFromFile,
      input_video,
      "", // videoPath
      // generic / video options
      seekFrameMargin,
      getPtsOnly,
      readVideoStream,
      width,
      height,
      minDimension,
      maxDimension,
      videoStartPts,
      videoEndPts,
      videoTimeBaseNum,
      videoTimeBaseDen,
      // audio options
      readAudioStream,
      audioSamples,
      audioChannels,
      audioStartPts,
      audioEndPts,
      audioTimeBaseNum,
      audioTimeBaseDen);
  return decoded;
}

// Reads a video identified by a file path.
//
// Thin wrapper over readVideo(): the leading flag is fixed to true, a
// zero-length placeholder tensor fills the tensor-input slot, and every
// remaining argument is forwarded unchanged and in the same order. Returns
// whatever readVideo() returns.
torch::List<torch::Tensor> read_video_from_file(
    std::string videoPath,
    double seekFrameMargin,
    int64_t getPtsOnly,
    int64_t readVideoStream,
    int64_t width,
    int64_t height,
    int64_t minDimension,
    int64_t maxDimension,
    int64_t videoStartPts,
    int64_t videoEndPts,
    int64_t videoTimeBaseNum,
    int64_t videoTimeBaseDen,
    int64_t readAudioStream,
    int64_t audioSamples,
    int64_t audioChannels,
    int64_t audioStartPts,
    int64_t audioEndPts,
    int64_t audioTimeBaseNum,
    int64_t audioTimeBaseDen) {
  // This entry point takes its input from the path, not from a tensor.
  const bool readFromFile = true;
  // Zero-element tensor standing in for the tensor input.
  torch::Tensor placeholderVideo = torch::ones({0});

  auto decoded = readVideo(
      readFromFile,
      placeholderVideo,
      videoPath,
      // generic / video options
      seekFrameMargin,
      getPtsOnly,
      readVideoStream,
      width,
      height,
      minDimension,
      maxDimension,
      videoStartPts,
      videoEndPts,
      videoTimeBaseNum,
      videoTimeBaseDen,
      // audio options
      readAudioStream,
      audioSamples,
      audioChannels,
      audioStartPts,
      audioEndPts,
      audioTimeBaseNum,
      audioTimeBaseDen);
  return decoded;
}

// Probes a video supplied as a tensor.
//
// Delegates to probeVideo() with isReadFile = false and an empty video path;
// returns whatever probeVideo() returns.
torch::List<torch::Tensor> probe_video_from_memory(torch::Tensor input_video) {
  const bool isReadFile = false;
  auto probed = probeVideo(isReadFile, input_video, "");
  return probed;
}

torch::List<torch::Tensor> probeVideoFromFile(std::string videoPath) {
// Probes a video identified by a file path.
//
// Delegates to probeVideo() with isReadFile = true and a zero-element
// placeholder tensor in the tensor-input slot; returns whatever probeVideo()
// returns.
torch::List<torch::Tensor> probe_video_from_file(std::string videoPath) {
  const bool isReadFile = true;
  torch::Tensor placeholderVideo = torch::ones({0});
  auto probed = probeVideo(isReadFile, placeholderVideo, videoPath);
  return probed;
}

} // namespace video_reader

TORCH_LIBRARY_FRAGMENT(video_reader, m) {
m.def("read_video_from_memory", video_reader::readVideoFromMemory);
m.def("read_video_from_file", video_reader::readVideoFromFile);
m.def("probe_video_from_memory", video_reader::probeVideoFromMemory);
m.def("probe_video_from_file", video_reader::probeVideoFromFile);
m.def("read_video_from_memory", read_video_from_memory);
m.def("read_video_from_file", read_video_from_file);
m.def("probe_video_from_memory", probe_video_from_memory);
m.def("probe_video_from_file", probe_video_from_file);
}

} // namespace video_reader
} // namespace vision
Loading

0 comments on commit e95a3d2

Please sign in to comment.