
dynamic shape input is much slower than fixed shape input in gpu #6978

Open
Emiyassstar opened this issue Mar 11, 2021 · 6 comments
Labels
ep:CUDA issues related to the CUDA execution provider

Comments

@Emiyassstar

Describe the bug
Dynamic shape input is much slower than fixed shape input on GPU.

System information

  • OS Platform: CentOS

  • ONNX Runtime installed from: source

  • ONNX Runtime version: 1.5

  • GCC/Compiler version: 7.3

  • CUDA/cuDNN version: 10.1/7.6

  • GPU model and memory:

To Reproduce
Using dynamic shape input, such as 8 * dynamic * 100, is 10 times slower than the fixed shape 8 * 1000 * 100 on the GPU at runtime. What is the problem?

@HectorSVC HectorSVC added the ep:CUDA issues related to the CUDA execution provider label Mar 11, 2021
@HectorSVC
Contributor

Could you do some more tests? For example, with the dynamic shape model, run inference repeatedly with 8 * 1000 * 100 input data. I suspect only the first inference is slow and the following inferences should be fast, which would mean it is slow only for the first time each new shape is encountered.
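
A minimal sketch of such a timing test, assuming a CUDA build of ONNX Runtime; the model path and the input name "input" are placeholders:

    import time
    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
    x = np.zeros((8, 1000, 100), dtype=np.float32)  # one fixed shape, repeated

    for i in range(3):
        t0 = time.perf_counter()
        sess.run(None, {"input": x})  # "input" is a placeholder tensor name
        print(f"run {i}: {time.perf_counter() - t0:.3f} s")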

@hariharans29
Member

If only the first run with a specific input shape is slow and the subsequent runs with the same input shape are faster, this usually means you have a lot of Conv nodes in your model. By default, the CUDA EP queries cuDNN to find the "most optimal" Conv algo for a given filter shape/input shape and caches it for subsequent runs. This query takes up most of the time in the first run.

That is why, for most dynamic shape models, users actually set this field (

OrtCudnnConvAlgoSearch cudnn_conv_algo_search; // cudnn conv algo search option

) to DEFAULT (the default value is actually EXHAUSTIVE, which asks cuDNN to search exhaustively for the best Conv algo). If DEFAULT is used, cuDNN is no longer queried and the default Conv algo is used. The trade-off is that the first run will now be much faster, but subsequent runs might be a bit slower, simply because the default algo may not be the best one for the given filter/input shapes.

You can get access to this API by upgrading to 1.6 or the latest 1.7.
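
For reference, a minimal sketch of setting this option from Python, assuming a release where CUDA EP provider options can be passed as a dict (in the 1.6/1.7 era the field was set via the C API's OrtCUDAProviderOptions struct); "model.onnx" is a placeholder:

    import onnxruntime as ort

    providers = [
        # "DEFAULT" skips the cuDNN algo query; "EXHAUSTIVE" is the built-in default
        ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}),
        "CPUExecutionProvider",
    ]
    sess = ort.InferenceSession("model.onnx", providers=providers)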

@tianleiwu
Contributor

tianleiwu commented Mar 12, 2021

Another possible reason is that graph optimization needs shape information to do some fusions; in some cases a static input enables fusions that a dynamic shape does not.

You can save the optimized model with a session option, as mentioned in https://www.onnxruntime.ai/docs/resources/graph-optimizations.html, to compare.
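
A minimal sketch of saving the optimized graph via session options; the file paths are placeholders:

    import onnxruntime as ort

    so = ort.SessionOptions()
    # Writes the graph as it looks after ORT's optimizations to disk,
    # so the static-shape and dynamic-shape versions can be compared.
    so.optimized_model_filepath = "model_optimized.onnx"
    ort.InferenceSession("model.onnx", so, providers=["CUDAExecutionProvider"])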

If this is the reason, you can use a static input to get the optimized model, then apply dynamic axes to that optimized model to get a "dynamic" version of it and use it for inference. That could work around the issue.
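
A sketch of that second step using the onnx Python package; the axis index and the symbolic name "seq_len" are assumptions about this particular model, and intermediate shape info may also need clearing:

    import onnx

    m = onnx.load("model_optimized.onnx")  # optimized with a static shape
    # Make the sequence axis symbolic again (hypothetical axis/name).
    m.graph.input[0].type.tensor_type.shape.dim[1].dim_param = "seq_len"
    onnx.save(m, "model_optimized_dynamic.onnx")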

@Emiyassstar
Author

Yes, I found exactly this: the first run with a specific input shape is slow, and the subsequent runs are faster. My model has many depthwise conv and pointwise conv nodes. When I set cudnn_conv_algo_search to DEFAULT, the first run became faster, but all sequences ran much slower. So to solve this problem, I had to pad the sequences to several fixed lengths. Is there any better solution?

Here is my test:
1. static shape: run time 230s
2. dynamic shape, padding sequences to three fixed lengths: run time 400s
3. dynamic shape with the DEFAULT conv algo: run time 300s
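
A minimal sketch of the pad-to-fixed-lengths approach described above, with hypothetical bucket sizes for a (batch, seq, feat) input:

    import numpy as np

    BUCKETS = (250, 500, 1000)  # hypothetical fixed sequence lengths

    def pad_to_bucket(x: np.ndarray) -> np.ndarray:
        """Zero-pad axis 1 up to the smallest bucket that fits, so only
        len(BUCKETS) distinct shapes ever reach the CUDA EP."""
        target = next(b for b in BUCKETS if b >= x.shape[1])
        return np.pad(x, ((0, 0), (0, target - x.shape[1]), (0, 0)))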

@askhade
Contributor

askhade commented Mar 15, 2021

@Emiyassstar: Now that you have narrowed down the reason, as Hari mentioned you have 2 options:

  1. Spend more time in the first run to find the optimal algo, and have the subsequent runs run much faster.
  2. Set the algo to DEFAULT and have all the runs run at the same speed, albeit the perf may not be the optimal you can get for this model.

Is there any reason neither of these 2 solutions works for you? Can you elaborate why?

@xiaomaxiao

xiaomaxiao commented Aug 29, 2024

The simplest way is to warm up, e.g. re-run shape_1, shape_2, ... every shape you may need:

    # assumes numpy is imported as np and the session/IO names are set up elsewhere
    for pad_w in [960, 896, 832, 768, 704, 640, 576, 512, 448,
                  384, 320, 256, 192, 128, 64]:
        self.ort_rec_session.run(
            [self.output_rec_name],
            {self.input_rec_name: np.zeros((1, 3, 48, pad_w), dtype=np.float32)},
        )

This is what I use for OCR.
