ONNX Runtime much slower than PyTorch (2-3x slower) #12880
Comments
I don't think the device used by the PyTorch export has any relevance to how ORT runs the model, so I wouldn't worry about that. If you set the log severity to VERBOSE, the node assignments will be printed out. Look for 'VerifyEachNodeIsAssignedToAnEp' in the output.

```python
so = onnxruntime.SessionOptions()
so.log_severity_level = 0
ort_session = onnxruntime.InferenceSession('model.onnx', so)
```

Should produce output like this:

ORT will perform various optimizations when loading the model. If you want to view how that modifies the model, you can use this script to save the model with the optimizations applied. Optimization levels are described here: https://onnxruntime.ai/docs/performance/graph-optimizations.html

Are you binding the model inputs and outputs so they stay on GPU? Otherwise there's a device copy between CPU and GPU that can significantly affect performance.
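As a concrete illustration, here is a minimal IOBinding sketch (the input/output names and the shape are placeholders, not taken from this model):

```python
import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession(
    'model.onnx', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

# Copy the input to GPU once, then bind it so run() doesn't copy it again.
x = np.random.rand(1, 1, 64, 64, 64).astype(np.float16)   # hypothetical input shape
x_gpu = onnxruntime.OrtValue.ortvalue_from_numpy(x, 'cuda', 0)

binding = sess.io_binding()
binding.bind_ortvalue_input('input', x_gpu)   # 'input' is a placeholder name
binding.bind_output('output', 'cuda')         # let ORT allocate the output on GPU

sess.run_with_iobinding(binding)
result = binding.copy_outputs_to_cpu()[0]     # only copy back when you actually need it
```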
Thank you for your answer.

I looked into the node assignments and got this output:

So it does seem like some operations run on the CPU, though I don't know if this is the cause of the slow-down?

I get the attached graph as output when running the optimisations. The weird thing is that the optimised model is even slower: I go from 350ms to 690ms per inference.

I think that I'm doing it correctly, see the code below:

Overall, I'm a bit stuck on finding where the slow-down comes from. Without resolving this we'll have to go with a solution other than ONNX.
Have you tried using nvprof and checking which kernel takes up the most time? That is the best way to move forward with this.
Since you have a Conv-heavy fp16 model and a card that supports tensor core operations, can you try this simple one-line update to your script? This is why I expect this to help your use-case:
Sorry - overlooked that that script doesn't have a way to enable the CUDA EP when running. I'm guessing that results in CPU EP specific custom ops being inserted, which leads to more of the model running on the CPU EP. The script is very simple though, and you can do the equivalent (set the optimization level and output filename in the session options) manually if you want to see the optimized model for a session with the CUDA EP enabled.

onnxruntime/tools/python/util/onnx_model_utils.py, lines 103 to 105 in 0b235b2
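For reference, a minimal sketch of doing that manually via the standard Python session options (file names are placeholders):

```python
import onnxruntime

so = onnxruntime.SessionOptions()
# Apply the same optimization level you use in production and write the result to disk.
so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
so.optimized_model_filepath = 'model.optimized.onnx'

# Creating the session with the CUDA EP enabled saves the graph as it is
# optimized for that EP, which is what you want to inspect here.
onnxruntime.InferenceSession(
    'model.onnx', so, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
```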
I was already using this actually (I used the same setup as here).
I was already using
Any difference is dependent on the model and the EPs that are enabled. If there are no internal ORT operators with CUDA implementations that apply to nodes the CUDA EP is taking, there won't be a difference between 'basic' and 'extended'/'all'.
I'm seeing a similar issue. My onnxruntime model's performance is very close to the pytorch model's (slightly below the pytorch model).
My bad, I didn't notice that. Could you run nvprof against your script and just give us the results of that? Also, if shareable, can you please give us the model?
Can you try running with ORTModule? You can just wrap your nn.Module model:

```python
from onnxruntime.training.ortmodule import ORTModule

new_model = ORTModule(model)
# ORTModule is also an nn.Module, so just use it with the original inputs
output = new_model(inputs)
```

ORTModule has some optimization to reduce overhead in the use of IOBinding. If you observe ORTModule bringing some speedup, then it's really IOBinding's problem.
Why do you have so many Casts in this reply's figure? ORT recently added support for "strided" tensors, so I expect those Casts are no-ops. If I am wrong and ORT does real computation on those Casts, it could slow down significantly. To verify if
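One quick way to see how many Cast nodes actually ended up in the exported graph (not necessarily the verification the commenter had in mind) is to count op types with the onnx package:

```python
from collections import Counter

import onnx

model = onnx.load('model.onnx')   # placeholder path
op_counts = Counter(node.op_type for node in model.graph.node)
print(op_counts.most_common(10))
print('Cast nodes:', op_counts['Cast'])
```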
@thomas-beznik, sure thing. If #9754 is the blocker, you probably need to build ORT from source. Note that you need a clean machine to avoid dependency interference for a clean build. By the way, you don't need ORTModule to do the float32 comparison. If you have time, please do the float32 comparison first. Thank you.
Got it, is there an easy way to undo the installation of ORTModule? Right now I cannot run normal ONNX inference either :/
First, `pip uninstall` it.

How many inputs/outputs/operators do you have? If the number of inputs/outputs is on the same scale as the number of operators, IOBinding is super slow.
I've attached the result from the profiling. I have a hard time understanding it unfortunately (this is without ORTModule, by the way).
I ran into a similar issue where an ONNX model was much slower than its PyTorch counterpart on the GPU. I tried all the suggestions here, including io_binding, but nothing worked. To solve the issue, profiling was enabled via the following code:

```python
import onnxruntime as ort

opts = ort.SessionOptions()
opts.enable_profiling = True
session = ort.InferenceSession(path, opts,
                               providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
session.run(None, inputs)
```

Once the program exited, a profiling JSON file was generated. I took a look at that to find the longest-running nodes.
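As an aside, one way to pull the longest-running entries out of that profile file (assuming the chrome-trace layout ORT writes; the exact filename is whatever session.end_profiling() returns):

```python
import json

with open('onnxruntime_profile.json') as f:   # placeholder filename
    events = json.load(f)

# Each event carries a 'name' and a duration 'dur' in microseconds.
slowest = sorted((e for e in events if 'dur' in e), key=lambda e: e['dur'], reverse=True)
for e in slowest[:10]:
    print(e['dur'], e['name'])
```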
Skipping the nodes for the full model run and session initialization, I was seeing nodes like this:

The program was re-run with that setting (cudnn_conv_algo_search) changed to either HEURISTIC or DEFAULT:

```python
import onnxruntime as ort

opts = ort.SessionOptions()
opts.enable_profiling = True
session = ort.InferenceSession(path, opts,
                               providers=[("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}),
                                          "CPUExecutionProvider"])
session.run(None, inputs)
```
session.run(None, inputs) This time the performance was equal to or even slightly better than the PyTorch model on the GPU. I'm not sure why ONNX defaults to an EXHAUSTIVE search. In reading similar code in PyTorch, it doesn't appear PyTorch does (looks like it defaults to what ONNX calls HEURISTIC) and that is the performance difference in my case. Hope this helps anyone running into performance issues one way or another. Looking at the original post, there were a lot of Conv operations, so it's worth a try. |
For a UNet (diffusion) model, try the following setting for best performance:
@davidmezzetti, cuDNN convolution algo search only happens once. Even though it is slow in the first inference run, the following runs might be faster. You can use some warm-up queries before serving user queries.
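A minimal warm-up sketch along those lines, with `session` being an existing InferenceSession and the input name/shape purely illustrative (cuDNN's search results are cached per input shape, so warm up with the shapes you will actually serve):

```python
import numpy as np

# Hypothetical fixed serving shape; replace with your model's real input name and shape.
warmup_input = {'input': np.zeros((1, 1, 64, 64, 64), dtype=np.float16)}

# Run a few inferences up front so the cuDNN algo search cost is paid before serving.
for _ in range(3):
    session.run(None, warmup_input)
```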
My understanding is that cuDNN only caches the results when the input shape is static. I was able to confirm this same behavior with a Torch model having dynamic input shapes exported to ONNX. Benchmark mode in PyTorch is what ONNX Runtime calls EXHAUSTIVE, and EXHAUSTIVE is the default ONNX Runtime setting per the documentation. PyTorch defaults to using a heuristic search.

I wrote an article with detailed steps on this comparison. This link also has a related discussion.
Describe the bug
We built a UNet3D image segmentation model in PyTorch (based on this repo) and want to start distributing it. ONNX seemed like a good option as it allows us to compress our models and the dependencies needed to run them. As our models are large and slow, we need to run them on GPU.

We were able to convert these models to ONNX, but noticed a significant slow-down of the inference (2-3x). The issue is that the timing is quite critical and our models are already relatively slow, so we can't afford further slow-downs.
I'm running my comparison tests following what was done in this issue.
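For reference, a minimal sketch of that kind of timing comparison (not the exact script used here; it assumes `model` is the already-loaded fp16 UNet3D module, a single input named 'input' with a fixed placeholder shape, and it synchronizes so the GPU timings are meaningful):

```python
import time

import torch
import onnxruntime

x = torch.randn(1, 1, 64, 64, 64, dtype=torch.float16, device='cuda')  # hypothetical shape

# PyTorch timing
model = model.half().cuda().eval()
with torch.no_grad():
    model(x)                        # warm-up
    torch.cuda.synchronize()
    start = time.perf_counter()
    model(x)
    torch.cuda.synchronize()
print('pytorch:', time.perf_counter() - start)

# ONNX Runtime timing (note: feeding numpy inputs implies host<->device copies,
# which IOBinding avoids)
sess = onnxruntime.InferenceSession(
    'model.onnx', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
ort_inputs = {'input': x.cpu().numpy()}
sess.run(None, ort_inputs)          # warm-up
start = time.perf_counter()
sess.run(None, ort_inputs)
print('onnxruntime:', time.perf_counter() - start)
```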
I could use your help to better understand where the issue is coming from and whether it is resolvable at all. What tests, settings, etc. can I try to see where the issue might be?
Urgency
This is quite an urgent issue; we need to deliver our models to our clients in the coming month and will need to resort to other solutions if we can't fix the ONNX path soon.
System information
To Reproduce
The code for the model will be quite hard to extract, so I'll first try to describe the issue and what I've tested. I'm currently generating my model using:
The model that we are using contains the following operations:
When converting to ONNX, I could see some weird things in the graph (see the first screenshot):

- Some operations say `cpu` instead of `cuda:0` like all the other operations; what does this mean? Will ONNX Runtime run these operations on the CPU? See below for the partial output of the conversion with `verbose = True` (this concerns the `Mul`, `Add`, `Reshape`, etc. operations). Could this be the reason for the slowdown?
- I saw that group normalisation wasn't directly supported by ONNX and thus thought that this might be the cause of the slow-down. I therefore tried an alternative model where I remove the group norm, which led to a nicer graph (see the 2nd screenshot) and to less slow-down (from 3x slower to 2x slower). The slow-down is still significant though, and the `Slice`, `Concat`, etc. operations still say that they occur on the cpu; are these then the issue?

Overall it would be great to get some guidance on where the problem could be located: should we adapt our model architecture, the way of exporting to ONNX, etc.? Is it even possible at all with a model like UNet3D?
Thanks for the help!
Screenshots