Video support in torchvision #4392
It would be nice if ffmpeg support shipped with two README items/examples:
Same for libpng/libjpeg: it should be clear whether pip-installed / conda-installed torchvision already comes compiled with libpng/libjpeg-turbo. If not, will it pick up dynamic libraries? If so, how should they be installed properly with conda (if they are not available on the system itself)? It should also be checkable at runtime whether they are supported, and it may be worth shipping a simple smoke test:
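Something along these lines could serve as that smoke test (a minimal sketch using the public torchvision.io codec ops; not an official check):

```python
# Minimal smoke test sketch: if torchvision was built without libjpeg/libpng,
# the encode/decode ops raise a RuntimeError at call time.
import torch
from torchvision import io

def codec_smoke_test():
    img = torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8)  # tiny CHW image
    for name, encode, decode in [
        ("libjpeg", io.encode_jpeg, io.decode_jpeg),
        ("libpng", io.encode_png, io.decode_png),
    ]:
        try:
            decode(encode(img))
            print(f"{name}: OK")
        except RuntimeError as err:
            print(f"{name}: unavailable ({err})")

codec_smoke_test()
```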
Regarding benchmarking: the main points for me are that:
[1] Note, these comparisons were done on a machine with an above-average CPU and what used to be a quite competitive GPU. Running them on different hardware would likely give different results.
[2] https://pyav.org/docs/develop/overview/about.html : see the section "Unsupported features"
(I think there is also a push for more and more transforms to support GPU inputs by relying on just PyTorch ops.)
CPU vs GPU decoding using the ffmpeg command line. CPU seems to be faster. This is a 201 MB video (duration: 11:11).
Excerpt from ffmpeg's man page:
For the ffmpeg command with the output yuv file, I think we're comparing: and so in this case I think it makes sense that we'd expect the second (GPU) path to be slower: it's using accelerated hardware, but doing far more total IO.
Concretely, an actual ML pipeline for some classification model might look like:
In this latter example I'd guess that the second (GPU) path might be faster: it uses accelerated hardware for the decoding while also doing less total IO than the CPU-decoding path.
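As a rough sketch of the comparison (dummy tensors stand in for the decoder output, and the model, resolutions, and transform are made up for illustration):

```python
# Sketch comparing the two paths; fake decoded frames stand in for a real decoder.
import torch
import torchvision.transforms.functional as F

model = torch.nn.Identity()  # stand-in for a classification model

# Path A - CPU decoding: full-resolution raw frames cross the PCIe bus.
cpu_frames = torch.randint(0, 256, (16, 3, 1080, 1920), dtype=torch.uint8)
batch_a = F.resize(cpu_frames.cuda().float() / 255, [224, 224])  # big host->device copy first
out_a = model(batch_a)

# Path B - GPU decoding: only the compressed bitstream crosses the bus;
# frames appear directly in CUDA memory and the transforms run there.
gpu_frames = torch.randint(0, 256, (16, 3, 1080, 1920), dtype=torch.uint8, device="cuda")
batch_b = F.resize(gpu_frames.float() / 255, [224, 224])
out_b = model(batch_b)
```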
Ah, I think I'm assuming our GPU image transforms can be applied to video frames, but maybe there's some additional work we need to do to make this happen, or additional video-specific transforms we'll need?
Yes, that's correct. This includes the time taken to copy frames between GPU and system memory.
To test this, we would need to add a prototype to torchvision. pyAV doesn't support hardware-accelerated decoding. I am looking at https://github.com/NVIDIA/VideoProcessingFramework, but haven't managed to build it successfully yet.
I've been looking into re-running this with decord, which has a (somewhat opaque) bridge directly to PyTorch tensors. Note that we have no control over the under-the-hood implementation, but in theory it shouldn't copy the tensors unnecessarily. The code reads and decodes entire videos and returns them as tensors; I've used blocked_autorange (which serves as a warm-up) and am relying on CPU decoding. Code:
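(The collapsed snippet isn't reproduced here; below is a minimal sketch of the setup described, assuming decord's torch bridge and torch.utils.benchmark, with "video.mp4" as a placeholder path.)

```python
# Sketch of the CPU-decoding benchmark described above.
import decord
from torch.utils import benchmark

decord.bridge.set_bridge("torch")  # get_batch returns torch tensors directly

def decode_all(path):
    vr = decord.VideoReader(path, ctx=decord.cpu(0))
    return vr.get_batch(range(len(vr)))  # (T, H, W, C) uint8 tensor

timer = benchmark.Timer(
    stmt="decode_all(path)",
    globals={"decode_all": decode_all, "path": "video.mp4"},
    num_threads=1,
)
print(timer.blocked_autorange())  # the first measured run doubles as a warm-up
```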
Although torch.utils.benchmark is set to one thread, empirically (in htop) I can see that my code spins all 48 threads up to 11-18%. This might be due to some weird multithreading happening within benchmark itself. Overall, the results are:
GPU decoding. Code:
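(Again the snippet is collapsed; presumably only the decord context changes, assuming a decord build with NVDEC support.)

```python
# GPU-decoding variant: frames are decoded by NVDEC and land in CUDA memory.
vr = decord.VideoReader("video.mp4", ctx=decord.gpu(0))
frames = vr.get_batch(range(len(vr)))
```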
In nvidia-smi, I can see a steady 5% utilization (Quadro RTX 8000) with about 800 MiB allocated (note there is some torch overhead as well). The results are as follows:
Over multiple videos and multiple threads. Code:
Video decoding benchmarking results using ffmpeg APIs (for CPU decoding) and NVDECODE APIs (for GPU decoding). Please note that the total time (averaged across 5 runs) includes the time to save the output. This adds extra copy time on the GPU path, from cuda memory to system memory.
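To make that copy cost concrete, a rough illustration (clip dimensions are made up):

```python
# The "save output" step forces a device->host transfer for the GPU path;
# at 1080p uint8 this is ~6 MB per frame, ~1.9 GB for a 300-frame clip.
import torch

frames_gpu = torch.empty(300, 3, 1080, 1920, dtype=torch.uint8, device="cuda")
frames_cpu = frames_gpu.cpu()  # this copy is included in the "total time" table
```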
The following table does not include the time to save the decoding output.
The above two tables show clear benefits of GPU decoding over CPU decoding, especially for longer videos (a 3x-5x improvement in speed).
These look great; just for clarification, how come the numbers on Kinetics are so different now from the tables from 10 days ago? What has changed?
@bjuncek The last two tables only measure the time taken by the decoding operation. The tables I reported previously measured the time taken by the C++ program to execute, which includes the time taken to read the input file, cuda initialisation, the copy from system memory to cuda memory, etc.
Which version of ffmpeg should be used? I tried ffmpeg 4.2.3 from conda-forge, but there is no bsf.h. Using ffmpeg >= 4.3.0 (4.3.0, 4.3.1, 4.3.2, 4.4.0, 4.4.1) from conda-forge, I encounter a segmentation fault. #3367 (comment)
Are you using GPU decoding? If not, you don't really need bsf.h.
Thanks for your reply. I want to use GPU decoding, so I do need bsf.h. Many thanks!
It works well when using a single worker, and the segmentation fault only happens when the number of workers is > 0. Have you done such testing?
@yuzhms GPU decoding would only work with worker = 0. By the way, I am curious to know how you are testing this.
I got it, so my build is fine. But worker=0 with GPU decoding is slower than worker=16 with CPU decoding. What is the point of this new feature? Do you have plans to extend GPU decoding to worker>0? I did not test this separately; actually, I am using pyslowfast (https://github.com/facebookresearch/SlowFast) for some action-recognition research, and I wrote a decoder using the new VideoReader API. I tested the training speed based on that.
Yes, you are right. I had compared CPU decoding with worker=10 vs GPU decoding with worker=0, and CPU decoding was faster. GPU decoding would need to support multiple workers to improve the speed. You are welcome to contribute.
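In DataLoader terms, the constraint being discussed looks like this (a sketch with a dummy dataset, not the torchvision reference code):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class DummyVideoDataset(Dataset):  # stand-in for a VideoReader-backed dataset
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return torch.zeros(3, 224, 224)  # stand-in for a decoded frame

# CPU decoding parallelises across worker processes:
cpu_loader = DataLoader(DummyVideoDataset(), batch_size=8, num_workers=16)

# GPU decoding currently only works in the main process:
gpu_loader = DataLoader(DummyVideoDataset(), batch_size=8, num_workers=0)
```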
Thanks anyway. Perhaps you should mention this point in the documentation; otherwise people will expect it to work with multiple workers.
We haven't yet added a reference script that uses the VideoReader API. Once we do, we should definitely add this information to the documentation for the reference script.
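Until then, basic frame iteration with the (beta) VideoReader API looks roughly like this ("video.mp4" is a placeholder path):

```python
# Iterate decoded frames with the fine-grained VideoReader API.
from torchvision.io import VideoReader

reader = VideoReader("video.mp4", "video")
for frame in reader:
    data = frame["data"]  # CHW uint8 tensor for this frame
    pts = frame["pts"]    # presentation timestamp (seconds)
```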
This is to keep track of the major work items for video support in TorchVision in this half.