[Performance] first run 10x slower than the following runs with CUDAExecutionProvider #17443

Closed
chiaitian opened this issue Sep 7, 2023 · 2 comments
Labels
ep:CUDA issues related to the CUDA execution provider

@chiaitian

Describe the issue

My model runs very slowly on the first run with CUDAExecutionProvider; the following runs are normal. If I run with another input shape, the first run with the new shape is also very slow. CPUExecutionProvider does not have this issue. My model has a Loop node; maybe that causes the issue.

To reproduce

import onnxruntime
import time
import numpy as np

# Two inputs with different shapes, to show the first-run cost per shape.
input1 = np.random.rand(30, 1, 118, 504).astype(np.float32)
input2 = np.random.rand(30, 1, 160, 504).astype(np.float32)

bm_onnx_path = 'test.onnx'

onnx_session = onnxruntime.InferenceSession(
    bm_onnx_path,
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
print(onnx_session.get_providers())

input_dict1 = {'images': input1}
input_dict2 = {'images': input2}

# First run with shape 1: very slow on CUDA.
t1 = time.time()
onnx_pred1 = onnx_session.run(["word_probs", "word_ids"], input_dict1)[1]
st = time.time()
print("first run time ", (st - t1))

# Subsequent runs with the same shape: fast.
for i in range(10):
    onnx_pred1 = onnx_session.run(["word_probs", "word_ids"], input_dict1)[1]
et = time.time()
print("average", (et - st) / 10)

# First run with a new shape: slow again.
t2 = time.time()
onnx_pred2 = onnx_session.run(["word_probs", "word_ids"], input_dict2)[1]
st = time.time()
print("input2 first run time ", (st - t2))

# Subsequent runs with the new shape: fast.
for i in range(10):
    onnx_pred2 = onnx_session.run(["word_probs", "word_ids"], input_dict2)[1]
et = time.time()
print("average", (et - st) / 10)

Urgency

No response

Platform

Linux

OS Version

Ubuntu 18.04.6 LTS

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

onnxruntime-gpu 1.15.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU, CUDA

Execution Provider Library Version

CUDA 11.7

Model File

https://drive.google.com/drive/folders/10FhnkqPc6FLVCF6wbvQEZkpg6ptjUyJc?usp=sharing

Is this a quantized model?

No

github-actions bot added the ep:CUDA label Sep 7, 2023
@skottmckay
Contributor

It's normal and expected. On the first run with a given combination of input shapes, the allocations required to execute the model are traced. On the second run, a single block of memory large enough to cover all of these allocations is created, and offsets into that block are used during model execution. This avoids allocation/free calls.

This happens for each set of input shapes, as the allocations required to execute the model depend on them. e.g. if the batch size is 5 in one call and 10 in the second, the allocations will be twice as large.

Both CPU and CUDA do the same thing on the first run, but allocation and free on CUDA are slower, so it's more noticeable.
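
One practical consequence: the per-shape first-run cost can be paid up front with warm-up runs before real traffic arrives. A minimal sketch (not from this thread; it assumes the model path, tensor names, and input shapes from the repro script above):

import numpy as np
import onnxruntime

# Warm the session up once per expected input shape so the allocation
# tracing described above happens before serving starts.
session = onnxruntime.InferenceSession(
    'test.onnx',  # model path from the repro script (assumption)
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

# Shapes taken from the repro script; replace with the shapes you
# actually expect at inference time.
for shape in [(30, 1, 118, 504), (30, 1, 160, 504)]:
    dummy = np.zeros(shape, dtype=np.float32)
    session.run(["word_probs", "word_ids"], {'images': dummy})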

@hariharans29
Member

hariharans29 commented Sep 7, 2023

Also, as to why the first run for a new input shape is slow, please take a look at the following for an explanation and for options to mitigate it. Use the options listed only if you expect the model input shapes to keep changing across runs; if they are expected to be fixed, the default setting is best suited for that.

#6978 (comment)

#12955

https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#cudnn_conv_algo_search
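
For reference, cudnn_conv_algo_search is set through the CUDA execution provider options. A minimal sketch (the choice of 'HEURISTIC' here is just an example; see the linked docs for the trade-offs between 'EXHAUSTIVE', 'HEURISTIC', and 'DEFAULT'):

import onnxruntime

# Pass cudnn_conv_algo_search as a CUDA EP provider option. 'HEURISTIC'
# skips the exhaustive cuDNN convolution algorithm benchmarking that the
# default 'EXHAUSTIVE' setting performs on the first run for each new
# input shape.
providers = [
    ('CUDAExecutionProvider', {'cudnn_conv_algo_search': 'HEURISTIC'}),
    'CPUExecutionProvider',
]
session = onnxruntime.InferenceSession('test.onnx', providers=providers)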
