Video support in torchvision #4392

Closed · 4 tasks done
prabhat00155 opened this issue Sep 10, 2021 · 23 comments

prabhat00155 (Contributor) commented Sep 10, 2021

This is to keep track of the major work items for video support in TorchVision this half.

  • Building torchvision with ffmpeg.
  • Updating the documentation to clearly state the steps to use the video API.
  • Benchmarking performance.
  • Video GPU decoding.
vadimkantorov commented Sep 14, 2021

It would be nice if the ffmpeg support shipped with two README items/examples:

  1. How to install a working torchvision + ffmpeg combination from conda
  2. How to compile torchvision from source with ffmpeg (with two ffmpeg options: one installable from conda; the other ffmpeg itself built from source in the most basic variant, which should be enough for decoding)

Same for libpng/libjpeg: it should be clear whether pip-installed / conda-installed torchvision is already compiled with libpng/libjpeg-turbo. If not, will it pick up dynamic libraries? If so, how should they be installed properly with conda (if they are not available on the system itself)? It should also be checkable at runtime whether they are supported, and it may be worth shipping a simple smoke test:
python -m torchvision.io.test to check what backends are available
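
For instance, a minimal runtime probe could look like the sketch below. It only uses the public get_video_backend()/set_video_backend() entry points; the torchvision.io.test module name above is hypothetical, and the exact failure mode for a missing backend (exception vs. warning) varies across torchvision versions, so both are caught.

import warnings

import torchvision

def available_video_backends():
    """Return the video backends this torchvision build can actually use."""
    found = []
    for backend in ("pyav", "video_reader"):
        try:
            with warnings.catch_warnings():
                # some torchvision versions warn instead of raising when a
                # backend is unavailable, so promote warnings to errors
                warnings.simplefilter("error")
                torchvision.set_video_backend(backend)
            found.append(backend)
        except (RuntimeError, ImportError, Warning):
            pass
    return found

if __name__ == "__main__":
    print("available video backends:", available_video_backends())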

bjuncek (Contributor) commented Sep 30, 2021

Regarding benchmarking:
I've taken the first steps towards it in this repo.

The main takeaways for me are that:

  1. In straight-up video reading, it is not obvious that GPU decoding is actually faster than CPU decoding [1]. In chats with people, it seems like there is a benefit to GPU decoding in end-to-end pipelines where decoded frames can be directly consumed by the GPU. But that raises the question of whether we can actually support this, because our transforms are, as far as I know, done on CPUs, so you would not gain much. Similar things have been observed by Mike from PyAV in their docs [2].


  2. As with our VideoReader and older read_video APIs, each reader comes with its own set of advantages and trade-offs. You can simply check the demo.ipynb in the repo above to see that the different readers take different approaches to how the API should look. It is also worth noting that some libraries (namely decord) actually count on approximations in order to implement the VideoReader API that allows fetching the k-th frame. Many files (for example stock HMDB51 videos) do not contain extradata metadata; when approximating the number of frames in these situations, we found that up to 4% of frames can be missed or duplicated.

[1] Note: these comparisons were done on a machine with an above-average CPU and what used to be a quite competitive GPU. Running them on different hardware would probably give different results.

[2] https://pyav.org/docs/develop/overview/about.html : see section "Unsupported Features"
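
To make the frame-count caveat in point 2 concrete, here is a small PyAV snippet (path is a placeholder) that compares the frame count a container claims in its metadata against the number of frames that actually decode; on files without proper metadata, stream.frames is simply 0 and any index-based reader has to approximate.

import av  # PyAV

with av.open("videos/sample_from_hmdb51.avi") as container:
    stream = container.streams.video[0]
    claimed = stream.frames  # 0 if the container carries no frame count
    decoded = sum(1 for _ in container.decode(stream))

print(f"metadata claims {claimed} frames, decoder produced {decoded}")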

@vadimkantorov

(I think there is also a push for more and more transforms supporting GPU and relying on just PyTorch ops)

@vadimkantorov

  • From my experience, it is a very bad idea to have anything rely on getting exactly the k-th frame, because this depends heavily on the decoder, its queue implementation, and its handling of missing frames / pts (some decoders may choose to fill in missing frames, others may not). I was bitten by this exactly when dealing with HMDB. From what I understood, the only quasi-reliable addressing when dealing with real-world video is time in seconds, which ties into the more reliable pts addressing in frame structures. See the sketch below.
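
As a concrete illustration, time-based addressing with the torchvision VideoReader API looks roughly like this (a sketch; the file path is a placeholder): you seek to a time in seconds and read frames until the pts passes the window you want, rather than asking for frame k.

from torchvision.io import VideoReader

reader = VideoReader("video.mp4", "video")
reader.seek(2.0)  # address by time in seconds, not by frame index

frames = []
for frame in reader:
    if frame["pts"] > 3.0:  # stop after roughly one second of video
        break
    frames.append(frame["data"])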

@prabhat00155 (Contributor, Author)

CPU vs GPU decoding using the ffmpeg command line; CPU seems to be faster. This is a 201 MB video (duration: 11:11).

Command                                                                                 Time     Device
time ffmpeg -y -vsync 0 -c:v h264 -i eFcpy2RClJQ.mp4 output11.yuv                       23.027s  cpu
time ffmpeg -y -vsync 0 -c:v h264 -i eFcpy2RClJQ.mp4 output12.yuv                       20.442s  cpu
time ffmpeg -y -vsync 0 -c:v h264 -i eFcpy2RClJQ.mp4 output13.yuv                       21.597s  cpu
time ffmpeg -y -vsync 0 -c:v h264 -i eFcpy2RClJQ.mp4 output14.yuv                       22.682s  cpu
time ffmpeg -y -vsync 0 -c:v h264 -i eFcpy2RClJQ.mp4 output15.yuv                       20.872s  cpu

time ffmpeg -hwaccel cuda -y -vsync 0 -c:v h264_cuvid -i eFcpy2RClJQ.mp4 output21.yuv   31.217s  gpu
time ffmpeg -hwaccel cuda -y -vsync 0 -c:v h264_cuvid -i eFcpy2RClJQ.mp4 output22.yuv   30.595s  gpu
time ffmpeg -hwaccel cuda -y -vsync 0 -c:v h264_cuvid -i eFcpy2RClJQ.mp4 output23.yuv   31.505s  gpu
time ffmpeg -hwaccel cuda -y -vsync 0 -c:v h264_cuvid -i eFcpy2RClJQ.mp4 output24.yuv   30.294s  gpu
time ffmpeg -hwaccel cuda -y -vsync 0 -c:v h264_cuvid -i eFcpy2RClJQ.mp4 output25.yuv   30.790s  gpu

Excerpt from ffmpeg's man page: "Note that most acceleration methods are intended for playback and will not be faster than software decoding on modern CPUs. Additionally, ffmpeg will usually need to copy the decoded frames from the GPU memory into the system memory, resulting in further performance loss. This option is thus mainly useful for testing."

nairbv (Contributor) commented Oct 13, 2021

ffmpeg will usually need to copy the decoded frames from the GPU memory into the system memory

For the ffmpeg command with the output yuv file, I think we're comparing:
read encoded file from disk -> decode on CPU -> write to disk
to:
read encoded file from disk -> copy to GPU -> decode -> copy (larger, full frames) to CPU -> write to disk

so in this case I think it makes sense that we'd expect the second path to be slower: it's using accelerated hardware, but doing far more total IO.

it seems like there is a benefit to GPU decoding in end-to-end pipelines where decoded frames can be directly consumed by the GPU

Concretely, an actual ML pipeline for some classification model might look like:

read encoded file -> decode on CPU -> copy full frames to GPU -> perform image transforms on frames -> run through GPU model -> copy a "class" back to CPU
vs:
read encoded file -> copy (smaller encoded file) to GPU -> decode on GPU -> perform image transforms on frames -> run through GPU model -> copy a "class" back to CPU

In this latter example I'd guess that the second path (GPU) might be faster: it's using accelerated hardware for the decoding but also doing less total IO than the CPU-decoding path.

But that raises the question of whether we can actually support this, because our transforms are, as far as I know, done on CPUs, so you would not gain much.

Ah, I think I'm assuming our GPU image transforms can be applied to video frames, but maybe there's some additional work we need to do to make this happen, or additional video-specific transforms we'll need?
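
To make the comparison concrete, the two pipelines might look like the sketch below. decode_video_cpu/decode_video_gpu are hypothetical placeholders for whichever decoder backs each path, and the transforms are ordinary tensor-based torchvision transforms, which should apply to video frames as well since those are just batched image tensors.

import torch
import torchvision.transforms as T

# tensor-based transforms run on whatever device the input tensor lives on
transform = T.Compose([T.Resize(128), T.CenterCrop(112)])

def classify_cpu_decode(path, model):
    frames = decode_video_cpu(path)      # hypothetical: (T, C, H, W) uint8 on CPU
    frames = frames.cuda()               # copy full decoded frames to the GPU
    frames = transform(frames.float() / 255.0)
    return model(frames).argmax(-1).cpu()

def classify_gpu_decode(path, model):
    frames = decode_video_gpu(path)      # hypothetical: decoded straight into GPU memory
    frames = transform(frames.float() / 255.0)  # no host<->device frame copy
    return model(frames).argmax(-1).cpu()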

@prabhat00155 (Contributor, Author)

For the ffmpeg command with the output yuv file, I think we're comparing: read encoded file from disk -> decode on CPU -> write to disk, to: read encoded file from disk -> copy to GPU -> decode -> copy (larger, full frames) to CPU -> write to disk

Yes, that's correct. This includes the time taken to copy frames between GPU and system memory.

Concretely, an actual ML pipeline for some classification model might look like:
read encoded file -> decode on CPU -> copy full frames to GPU -> perform image transforms on frames -> run through GPU model -> copy a "class" back to CPU, vs: read encoded file -> copy (smaller encoded file) to GPU -> decode on GPU -> perform image transforms on frames -> run through GPU model -> copy a "class" back to CPU

In this latter example I'd guess that the second path (GPU) might be faster: it's using accelerated hardware for the decoding but also doing less total IO than the CPU-decoding path.

To test this, we would need to add a prototype to torchvision. PyAV doesn't support hardware-accelerated decoding. I am looking at https://github.com/NVIDIA/VideoProcessingFramework, but haven't managed to build it successfully yet.
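
For reference, based on my reading of the VideoProcessingFramework samples, decoding with its PyNvCodec bindings looks roughly like the sketch below once the library builds (names and signatures may differ between versions):

import PyNvCodec as nvc  # VideoProcessingFramework bindings

gpu_id = 0
nv_dec = nvc.PyNvDecoder("eFcpy2RClJQ.mp4", gpu_id)

num_frames = 0
while True:
    surface = nv_dec.DecodeSingleSurface()  # decoded frame stays in GPU memory
    if surface.Empty():                     # empty surface signals end of stream
        break
    num_frames += 1

print(f"decoded {num_frames} frames on GPU {gpu_id}")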

bjuncek (Contributor) commented Oct 19, 2021

I've been looking into re-running this with decord, which has a (somewhat opaque) bridge directly to PyTorch tensors. Note that we have no control over the under-the-hood implementation, but in theory it shouldn't copy the tensors unnecessarily. The code reads and decodes entire videos and returns them as tensors; I've added a blocked_autorange call to serve as a warm-up, and am relying on torch.utils.benchmark to get the results. Note that I think decord does something weird with multi-threading, because I see utilisation on all my cores when running the benchmark. Repro code is pushed, or will be pushed soon, to the repo mentioned above.

CPU decoding

Code
# setup
from decord import VideoReader, cpu
import decord
decord.bridge.set_bridge('torch')

# main 
def read_video_cpu():
    path_to_video = "./videos/WUzgd7C1pWA.mp4"
    vr = VideoReader(path_to_video, ctx=cpu(0))
    fr_tensor = vr.get_batch(list(range(len(vr))))
    return fr_tensor

Although torch.utils.benchmark specifies one thread, empirically in htop I can see that my code spins up all 48 threads to 11-18% utilisation. This might be due to some multithreading happening within benchmark itself. Overall, the results are:

<torch.utils.benchmark.utils.common.Measurement object at 0x7fc265aaa790>
read_video_cpu()
setup:
  from decord import VideoReader, cpu
  import decord
  decord.bridge.set_bridge('torch')
  from __main__ import read_video_cpu

  84.18 ms
  1 measurement, 1000 runs , 1 thread
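
Presumably the measurement above comes from a torch.utils.benchmark invocation along these lines (a sketch; the run count is inferred from the "1000 runs" in the output):

import torch.utils.benchmark as benchmark

timer = benchmark.Timer(
    stmt="read_video_cpu()",
    setup="""
from decord import VideoReader, cpu
import decord
decord.bridge.set_bridge('torch')
from __main__ import read_video_cpu
""",
    num_threads=1,
)
print(timer.timeit(1000))  # one Measurement over 1000 runs, as printed above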

GPU decoding

Code
# setup
from decord import VideoReader, gpu
import decord
decord.bridge.set_bridge('torch')

# main
def read_video_gpu():
    path_to_video = "./videos/WUzgd7C1pWA.mp4"
    vr = VideoReader(path_to_video, ctx=gpu(0))
    fr_tensor = vr.get_batch(list(range(len(vr))))
    return fr_tensor

In nvidia-smi, I can see a steady 5% utilisation (Quadro RTX 8000) with about 800 MiB allocated (note: there is some torch overhead as well).

Results are as follows:

<torch.utils.benchmark.utils.common.Measurement object at 0x7fc2b86d7fd0>
read_video_gpu()
setup:
  from decord import VideoReader, gpu
  import decord
  decord.bridge.set_bridge('torch')
  from __main__ import read_video_gpu

  197.61 ms
  1 measurement, 1000 runs , 1 thread

Over multiple videos and multiple threads

Code
import os
from pathlib import Path

import torch.utils.benchmark as benchmark

# stfuc / stfug hold the setup code for the CPU and GPU readers; reconstructed
# here (the originals live in the repo mentioned above), with the readers
# parametrised by path
stfuc = """
from decord import VideoReader, cpu
import decord
decord.bridge.set_bridge('torch')

def read_video_cpu(path):
    vr = VideoReader(path, ctx=cpu(0))
    return vr.get_batch(list(range(len(vr))))
"""
stfug = """
from decord import VideoReader, gpu
import decord
decord.bridge.set_bridge('torch')

def read_video_gpu(path):
    vr = VideoReader(path, ctx=gpu(0))
    return vr.get_batch(list(range(len(vr))))
"""

results = []
for file in os.listdir("./videos"):
    label = 'Whole video read'
    sub_label = Path(file).name
    if file in ["README", ".ipynb_checkpoints", "avadl.py"]:
        print(f"Skipping {file}")
        continue
    print(file)
    file = os.path.join("./videos/", file)
    for num_threads in [1, 4, 16, 32]:
        results.append(benchmark.Timer(
            stmt='read_video_gpu(x)',
            setup=stfug,
            globals={'x': file},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='gpu',
        ).blocked_autorange(min_run_time=10))
        results.append(benchmark.Timer(
            stmt='read_video_cpu(x)',
            setup=stfuc,
            globals={'x': file},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='cpu',
        ).blocked_autorange(min_run_time=10))

compare = benchmark.Compare(results)
compare.print()
[----------------------------- Whole video read ----------------------------]
                                                           |   gpu   |   cpu 
1 threads: ------------------------------------------------------------------
      v_SoccerJuggling_g23_c01.avi                         |   95.0  |   47.2
      SOX5yA1l24A.mp4                                      |  151.4  |   78.6
      R6llTwEh07w.mp4                                      |  151.3  |  139.7
      RATRACE_wave_f_nm_np1_fr_goo_37.avi                  |   46.4  |   57.0
      TrumanShow_wave_f_nm_np1_fr_med_26.avi               |   59.0  |   35.8
      SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi  |   62.2  |   40.8
      v_SoccerJuggling_g24_c01.avi                         |   99.2  |   49.8
      WUzgd7C1pWA.mp4                                      |  157.4  |   86.6
4 threads: ------------------------------------------------------------------
      v_SoccerJuggling_g23_c01.avi                         |   95.3  |   47.0
      SOX5yA1l24A.mp4                                      |  156.7  |   83.2
      R6llTwEh07w.mp4                                      |  169.6  |  142.0
      RATRACE_wave_f_nm_np1_fr_goo_37.avi                  |   72.5  |   56.7
      TrumanShow_wave_f_nm_np1_fr_med_26.avi               |   59.0  |   36.2
      SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi  |   62.7  |   41.2
      v_SoccerJuggling_g24_c01.avi                         |  123.4  |   49.8
      WUzgd7C1pWA.mp4                                      |  148.1  |   83.8
16 threads: -----------------------------------------------------------------
      v_SoccerJuggling_g23_c01.avi                         |   95.8  |   46.5
      SOX5yA1l24A.mp4                                      |  157.7  |   85.6
      R6llTwEh07w.mp4                                      |  167.5  |   76.4
      RATRACE_wave_f_nm_np1_fr_goo_37.avi                  |   71.9  |   57.4
      TrumanShow_wave_f_nm_np1_fr_med_26.avi               |   59.8  |   36.2
      SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi  |   42.4  |   41.3
      v_SoccerJuggling_g24_c01.avi                         |  101.9  |   49.7
      WUzgd7C1pWA.mp4                                      |  148.8  |   81.7
32 threads: -----------------------------------------------------------------
      v_SoccerJuggling_g23_c01.avi                         |   94.7  |   46.1
      SOX5yA1l24A.mp4                                      |  167.5  |   85.5
      R6llTwEh07w.mp4                                      |  143.0  |   82.5
      RATRACE_wave_f_nm_np1_fr_goo_37.avi                  |   70.9  |   57.1
      TrumanShow_wave_f_nm_np1_fr_med_26.avi               |   60.4  |   35.6
      SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi  |   64.3  |   41.2
      v_SoccerJuggling_g24_c01.avi                         |  123.0  |   49.8
      WUzgd7C1pWA.mp4                                      |  140.2  |   81.0

Times are in milliseconds (ms).

prabhat00155 (Contributor, Author) commented Nov 1, 2021

Video decoding benchmark results using ffmpeg APIs (for CPU decoding) and NVDECODE APIs (for GPU decoding). Please note that the total time (averaged across 5 runs) includes the time to save the output, which adds extra copy time on the GPU side from CUDA memory to system memory.

Videos                        | Source         | Size | Duration (s) | CPU (s) | GPU (s) | CPU/GPU
------------------------------|----------------|------|--------------|---------|---------|---------
8hFYAVhnAg4_000001_000011.mp4 | kinetics       | 5.6M | 10           | 1.9966  | 1.4662  | 1.361751
5bljAEyjYvQ_000000_000010.mp4 | kinetics       | 131K | 2            | 0.0488  | 0.902   | 0.054102
4nWUEQXQ0xA_000000_000010.mp4 | kinetics       | 969K | 10           | 0.3842  | 1.039   | 0.369779
ARb0rBemhFE_000025_000035.mp4 | kinetics       | 783K | 10           | 0.2814  | 1.2534  | 0.224509
m6YMDxAWigs_000024_000034.mp4 | kinetics       | 4.8M | 10           | 1.8334  | 1.469   | 1.24806
7GOB1u0scWw_000426_000436.avi | kinetics       | 39M  | 10           | 0.9252  | 1.33    | 0.695639
LzHGPWvaJQc_000300_000310.avi | kinetics       | 12M  | 10           | 0.5734  | 1.2822  | 0.4472
SOX5yA1l24A.mp4               | torchvision    | 548K | 11           | 0.2108  | 1.2126  | 0.173841
R6llTwEh07w.mp4               | torchvision    | 844K | 10           | 0.2416  | 1.0272  | 0.235202
eFcpy2RClJQ.mp4               | activityNet200 | 201M | 671          | 85.5172 | 33.8556 | 2.52594
YtgiDWEY_1A.mp4               | activityNet200 | 185M | 755          | 96.5216 | 38.3324 | 2.518016
6I1aP4O04R8.mp4               | activityNet200 | 159M | 235          | 38.6964 | 13.6624 | 2.832328
b993qWuMRBA.mp4               | activityNet200 | 136M | 490          | 59.068  | 20.7052 | 2.85281
aOzMA2rpWEw.mp4               | activityNet200 | 131M | 537          | 73.4712 | 30.7392 | 2.390147
BYLxSOPFOuc.mp4               | activityNet200 | 109M | 489          | 66.029  | 28.2638 | 2.336169
H-5nHSHwFOk.mp4               | activityNet200 | 103M | 745          | 67.9444 | 34.2386 | 1.984439
16725zS5kVM.mp4               | activityNet200 | 90M  | 236          | 34.13   | 13.5688 | 2.515329
lPCl1ZYH2xI.mp4               | activityNet200 | 90M  | 360          | 40.8122 | 20.4978 | 1.991053
QWRGRAod0no.mp4               | activityNet200 | 81M  | 227          | 29.1128 | 11.613  | 2.506915

prabhat00155 (Contributor, Author) commented Nov 2, 2021

The following table does not include the time to save the decoded output.

Videos                        | Source         | Size | Duration (s) | CPU (s) | GPU (s) | CPU/GPU
------------------------------|----------------|------|--------------|---------|---------|------------
8hFYAVhnAg4_000001_000011.mp4 | kinetics       | 5.6M | 10           | 1.5692  | 1.2692  | 1.236369367
5bljAEyjYvQ_000000_000010.mp4 | kinetics       | 131K | 2            | 0.0412  | 0.873   | 0.047193585
4nWUEQXQ0xA_000000_000010.mp4 | kinetics       | 969K | 10           | 0.2952  | 0.9698  | 0.304392658
ARb0rBemhFE_000025_000035.mp4 | kinetics       | 783K | 10           | 0.2076  | 0.9226  | 0.225016258
m6YMDxAWigs_000024_000034.mp4 | kinetics       | 4.8M | 10           | 1.382   | 1.1076  | 1.247742867
7GOB1u0scWw_000426_000436.avi | kinetics       | 39M  | 10           | 0.715   | 1.0718  | 0.667102071
LzHGPWvaJQc_000300_000310.avi | kinetics       | 12M  | 10           | 0.3814  | 1.0092  | 0.377923107
SOX5yA1l24A.mp4               | torchvision    | 548K | 11           | 0.1534  | 0.9325  | 0.164504021
R6llTwEh07w.mp4               | torchvision    | 844K | 10           | 0.2022  | 0.9728  | 0.207853618
eFcpy2RClJQ.mp4               | activityNet200 | 201M | 671          | 63.504  | 14.4204 | 4.403761338
YtgiDWEY_1A.mp4               | activityNet200 | 185M | 755          | 71.7632 | 23.8412 | 3.01004983
6I1aP4O04R8.mp4               | activityNet200 | 159M | 235          | 29.9866 | 6.1586  | 4.86906115
b993qWuMRBA.mp4               | activityNet200 | 136M | 490          | 44.984  | 11.729  | 3.835280075
aOzMA2rpWEw.mp4               | activityNet200 | 131M | 537          | 56.4512 | 18.4192 | 3.064801946
BYLxSOPFOuc.mp4               | activityNet200 | 109M | 489          | 50.603  | 15.8836 | 3.185864665
H-5nHSHwFOk.mp4               | activityNet200 | 103M | 745          | 45.6914 | 15.9778 | 2.859680306
16725zS5kVM.mp4               | activityNet200 | 90M  | 236          | 26.4814 | 5.6624  | 4.676709522
lPCl1ZYH2xI.mp4               | activityNet200 | 90M  | 360          | 27.526  | 8.1726  | 3.368083596
QWRGRAod0no.mp4               | activityNet200 | 81M  | 227          | 20.776  | 4.7256  | 4.396478754

The above two tables show clear benefits of GPU decoding over CPU decoding for longer videos (a 3x-5x speedup), while CPU decoding remains faster on the short clips.

prabhat00155 (Contributor, Author) commented Nov 12, 2021

The following only measures the time taken to decode different numbers of frames from a given video, on both CPU and GPU. GPU decoding is much faster across the board.
[screenshot: decoding times for different numbers of frames, CPU vs GPU]

prabhat00155 (Contributor, Author) commented Nov 12, 2021

The following measures only the time spent in the decoding operation for whole videos. GPU decoding again performs much better than CPU decoding.
[screenshot: whole-video decoding times, CPU vs GPU]

Note: Colour coding for the last column:

  • Green: GPU decoding faster
  • Red: CPU decoding faster

bjuncek (Contributor) commented Nov 12, 2021

These look great; just for clarification, how come the Kinetics numbers are so different now from the tables 10 days ago? What has changed?
(To clarify, I still get numbers similar to the tables above, i.e. around a 20% improvement over single-threaded CPU decoding on h264x-high with decord.)


prabhat00155 (Contributor, Author) commented Nov 15, 2021

These look great; just for clarification, how come the Kinetics numbers are so different now from the tables 10 days ago? What has changed? (To clarify, I still get numbers similar to the tables above, i.e. around a 20% improvement over single-threaded CPU decoding on h264x-high with decord.)

@bjuncek The last two tables only measure the time taken by the decoding operation. The tables I reported previously measured the time taken by the C++ program to execute, which includes the time taken to read the input file, CUDA initialisation, copies from system memory to CUDA memory, etc.

yuzhms commented Mar 11, 2022

Which version of ffmpeg should be used? I tried ffmpeg 4.2.3 from conda-forge, but there is no bsf.h. Using ffmpeg >= 4.3.0 (4.3.0, 4.3.1, 4.3.2, 4.4.0, 4.4.1) from conda-forge, I encounter a segmentation fault. #3367 (comment)
@prabhat00155

@prabhat00155 (Contributor, Author)

Are you using GPU decoding? If not, you don't really need bsf.h.
Also, the last time I tried, I was able to build with the latest ffmpeg from conda-forge (conda install -c conda-forge ffmpeg) with Python 3.8, without any issues.

yuzhms commented Mar 14, 2022

Thanks for your reply. I want to use GPU decoding, so I need bsf.h. I tried the same conda install -c conda-forge ffmpeg with Python 3.8, but it hits a segmentation fault.
Could you share the full build script or a Docker image that already has the GPU-decoding version of torchvision installed?

Many thanks!

yuzhms commented Mar 14, 2022

It works well when using a single worker, and the segmentation fault only happens when the number of workers > 0. Have you done such testing?

@prabhat00155 (Contributor, Author)

@yuzhms GPU decoding currently only works with workers = 0. By the way, I am curious to know how you are testing this.

yuzhms commented Mar 14, 2022

So my build is fine. I got it.

But workers = 0 with GPU decoding is slower than workers = 16 with CPU decoding. What is the point of this new feature? Do you have plans to extend GPU decoding to workers > 0?

I did not test this separately. Actually, I am using pyslowfast (https://github.com/facebookresearch/SlowFast) for some action recognition research, and I wrote a decoder using the new VideoReader API. I tested the training speed based on that.

@prabhat00155 (Contributor, Author)

Yes, you are right. I had compared CPU decoding with workers = 10 vs GPU decoding with workers = 0, and CPU decoding was faster. GPU decoding would need to support multiple workers to improve the speed. You are welcome to contribute.
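
For anyone hitting this, a hedged sketch of the constraint: with GPU decoding the reader holds CUDA state, so the dataset has to run in the main process (num_workers=0). The device argument to VideoReader reflects the prototype GPU-decoding API and may differ across versions; paths and clip length here are placeholders.

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.io import VideoReader

class ClipDataset(Dataset):
    def __init__(self, paths, num_frames=16):
        self.paths = paths
        self.num_frames = num_frames

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # prototype GPU decoding: frames come back as CUDA tensors
        reader = VideoReader(self.paths[idx], "video", device="cuda")
        frames = [next(reader)["data"] for _ in range(self.num_frames)]
        return torch.stack(frames)

# num_workers must stay 0: worker processes cannot share the CUDA decoder state
loader = DataLoader(ClipDataset(["a.mp4", "b.mp4"]), batch_size=1, num_workers=0)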

yuzhms commented Mar 14, 2022

Thanks anyway. Perhaps you should mention this point in the documentation; otherwise people will expect it to work with multiple workers.

prabhat00155 (Contributor, Author) commented Mar 14, 2022

We haven't yet added a reference script that uses the VideoReader API. Once we do, we should definitely add this information to the documentation for the reference script.
FYI @bjuncek
