Cannot build TensorRT engine for classification models #115

Open · CecileGiang opened this issue Jul 12, 2022 · 2 comments

@CecileGiang

Hello,

I tried to reproduce the optimization process described in your Accelerating GPT-2 model notebook demo (optimizing a model for ONNX Runtime and Triton server), but with a classification model, namely facebook/bart-large-mnli from HuggingFace's hub.

However, I run into a problem when trying to build a TensorRT engine from the corresponding ONNX file: the build fails and the resulting engine is None. I also tried calling the TensorRT builder's build_engine method directly, but it returns None as well.

TypeError: deserialize_cuda_engine(): incompatible function arguments. The following argument types are supported:
    1. (self: tensorrt.tensorrt.Runtime, serialized_engine: buffer) -> tensorrt.tensorrt.ICudaEngine

Invoked with: <tensorrt.tensorrt.Runtime object at 0x7ff47c047270>, None
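
For context, here is a simplified sketch of the build step I am attempting (placeholder path, no optimization profile, not the exact notebook code); the parse/build result comes back as None, and that None is what ends up in deserialize_cuda_engine:

import tensorrt as trt

# Simplified sketch: parse the ONNX file and build a serialized engine,
# surfacing parser/builder errors instead of passing None on to
# deserialize_cuda_engine.
trt_logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(trt_logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, trt_logger)

with open("model.onnx", "rb") as f:  # placeholder path
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError("ONNX parsing failed")

config = builder.create_builder_config()
# A real build would also need an optimization profile for dynamic input shapes.
serialized_engine = builder.build_serialized_network(network, config)
if serialized_engine is None:
    raise RuntimeError("engine build failed")

runtime = trt.Runtime(trt_logger)
engine = runtime.deserialize_cuda_engine(serialized_engine)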

After investigating a bit, I found out that this problem arises when initializing facebook/bart-large-mnli with the BartForSequenceClassification class, but not when initializing it with the BartForConditionalGeneration class, even though the model was specifically fine-tuned for the MNLI classification task.
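
For reference, here is a minimal sketch of the kind of export that shows the difference between the two classes, assuming a plain torch.onnx.export with a single input named input (the attached notebook may differ):

import torch
from transformers import AutoTokenizer, BartForSequenceClassification

# Minimal sketch, assuming a plain torch.onnx.export; the attached notebook may differ.
model_name = "facebook/bart-large-mnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BartForSequenceClassification.from_pretrained(model_name).eval()

dummy = tokenizer("A premise.", "A hypothesis.", return_tensors="pt")
torch.onnx.export(
    model,
    (dummy["input_ids"],),
    "model.onnx",
    input_names=["input"],
    output_names=["logits"],
    dynamic_axes={"input": {0: "batch", 1: "sequence"}},
    opset_version=13,
)
# Loading the same checkpoint with BartForConditionalGeneration instead and
# exporting it the same way does not trigger the engine-build failure.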

The whole code is attached as an ipynb file; I ran it on Google Colab.

Could you please help me resolve it? Thanks in advance for your help!

Versions:

  • Python: 3.7.13
  • transformer-deploy: 0.5.0
  • TensorRT: 8.4.1.5
  • onnxruntime (GPU): 1.11.1
  • transformers: 4.20.1

To reproduce:

model_optimization.zip

@pommedeterresautee (Member)

Sorry for the delay, I was on vacation.

Can you run the conversion with trtexec and share the logs?
It provides a meaningful error message, unlike the Python wrapper.

@CecileGiang (Author) commented Jul 26, 2022

Hello, thank you very much for your reply. I'm sorry for the late answer, I was on vacation as well!
I ran the conversion with trtexec as you requested, and I got the following log:

&&&& RUNNING TensorRT.trtexec [TensorRT v8401] # ./trtexec --onnx=/content/drive/MyDrive/model.onnx --minShapes=input:1x128 --optShapes=input:1x128 --maxShapes=input:1x128 --shapes=input:1x128 --best --workspace=10000 --saveEngine=test.plan
[07/26/2022-04:30:36] [W] --workspace flag has been deprecated by --memPoolSize flag.
[07/26/2022-04:30:36] [I] === Model Options ===
[07/26/2022-04:30:36] [I] Format: ONNX
[07/26/2022-04:30:36] [I] Model: /content/drive/MyDrive/model.onnx
[07/26/2022-04:30:36] [I] Output:
[07/26/2022-04:30:36] [I] === Build Options ===
[07/26/2022-04:30:36] [I] Max batch: explicit batch
[07/26/2022-04:30:36] [I] Memory Pools: workspace: 10000 MiB, dlaSRAM: default, dlaLocalDRAM: default, dlaGlobalDRAM: default
[07/26/2022-04:30:36] [I] minTiming: 1
[07/26/2022-04:30:36] [I] avgTiming: 8
[07/26/2022-04:30:36] [I] Precision: FP32+FP16+INT8
[07/26/2022-04:30:36] [I] LayerPrecisions: 
[07/26/2022-04:30:36] [I] Calibration: Dynamic
[07/26/2022-04:30:36] [I] Refit: Disabled
[07/26/2022-04:30:36] [I] Sparsity: Disabled
[07/26/2022-04:30:36] [I] Safe mode: Disabled
[07/26/2022-04:30:36] [I] DirectIO mode: Disabled
[07/26/2022-04:30:36] [I] Restricted mode: Disabled
[07/26/2022-04:30:36] [I] Build only: Disabled
[07/26/2022-04:30:36] [I] Save engine: test.plan
[07/26/2022-04:30:36] [I] Load engine: 
[07/26/2022-04:30:36] [I] Profiling verbosity: 0
[07/26/2022-04:30:36] [I] Tactic sources: Using default tactic sources
[07/26/2022-04:30:36] [I] timingCacheMode: local
[07/26/2022-04:30:36] [I] timingCacheFile: 
[07/26/2022-04:30:36] [I] Input(s)s format: fp32:CHW
[07/26/2022-04:30:36] [I] Output(s)s format: fp32:CHW
[07/26/2022-04:30:36] [I] Input build shape: input=1x128+1x128+1x128
[07/26/2022-04:30:36] [I] Input calibration shapes: model
[07/26/2022-04:30:36] [I] === System Options ===
[07/26/2022-04:30:36] [I] Device: 0
[07/26/2022-04:30:36] [I] DLACore: 
[07/26/2022-04:30:36] [I] Plugins:
[07/26/2022-04:30:36] [I] === Inference Options ===
[07/26/2022-04:30:36] [I] Batch: Explicit
[07/26/2022-04:30:36] [I] Input inference shape: input=1x128
[07/26/2022-04:30:36] [I] Iterations: 10
[07/26/2022-04:30:36] [I] Duration: 3s (+ 200ms warm up)
[07/26/2022-04:30:36] [I] Sleep time: 0ms
[07/26/2022-04:30:36] [I] Idle time: 0ms
[07/26/2022-04:30:36] [I] Streams: 1
[07/26/2022-04:30:36] [I] ExposeDMA: Disabled
[07/26/2022-04:30:36] [I] Data transfers: Enabled
[07/26/2022-04:30:36] [I] Spin-wait: Disabled
[07/26/2022-04:30:36] [I] Multithreading: Disabled
[07/26/2022-04:30:36] [I] CUDA Graph: Disabled
[07/26/2022-04:30:36] [I] Separate profiling: Disabled
[07/26/2022-04:30:36] [I] Time Deserialize: Disabled
[07/26/2022-04:30:36] [I] Time Refit: Disabled
[07/26/2022-04:30:36] [I] Inputs:
[07/26/2022-04:30:36] [I] === Reporting Options ===
[07/26/2022-04:30:36] [I] Verbose: Disabled
[07/26/2022-04:30:36] [I] Averages: 10 inferences
[07/26/2022-04:30:36] [I] Percentile: 99
[07/26/2022-04:30:36] [I] Dump refittable layers:Disabled
[07/26/2022-04:30:36] [I] Dump output: Disabled
[07/26/2022-04:30:36] [I] Profile: Disabled
[07/26/2022-04:30:36] [I] Export timing to JSON file: 
[07/26/2022-04:30:36] [I] Export output to JSON file: 
[07/26/2022-04:30:36] [I] Export profile to JSON file: 
[07/26/2022-04:30:36] [I] 
[07/26/2022-04:30:36] [I] === Device Information ===
[07/26/2022-04:30:36] [I] Selected Device: Tesla P100-PCIE-16GB
[07/26/2022-04:30:36] [I] Compute Capability: 6.0
[07/26/2022-04:30:36] [I] SMs: 56
[07/26/2022-04:30:36] [I] Compute Clock Rate: 1.3285 GHz
[07/26/2022-04:30:36] [I] Device Global Memory: 16280 MiB
[07/26/2022-04:30:36] [I] Shared Memory per SM: 64 KiB
[07/26/2022-04:30:36] [I] Memory Bus Width: 4096 bits (ECC enabled)
[07/26/2022-04:30:36] [I] Memory Clock Rate: 0.715 GHz
[07/26/2022-04:30:36] [I] 
[07/26/2022-04:30:36] [I] TensorRT version: 8.4.1
[07/26/2022-04:30:37] [I] [TRT] [MemUsageChange] Init CUDA: CPU +0, GPU +0, now: CPU 0, GPU 367 (MiB)
[07/26/2022-04:30:38] [I] [TRT] [MemUsageChange] Init builder kernel library: CPU +0, GPU +2, now: CPU 0, GPU 369 (MiB)
[07/26/2022-04:30:38] [I] Start parsing network model
[libprotobuf WARNING google/protobuf/io/coded_stream.cc:604] Reading dangerously large protocol message.  If the message turns out to be larger than 2147483647 bytes, parsing will be halted for security reasons.  To increase the limit (or to disable these warnings), see CodedInputStream::SetTotalBytesLimit() in google/protobuf/io/coded_stream.h.
[libprotobuf WARNING google/protobuf/io/coded_stream.cc:604] Reading dangerously large protocol message.  If the message turns out to be larger than 2147483647 bytes, parsing will be halted for security reasons.  To increase the limit (or to disable these warnings), see CodedInputStream::SetTotalBytesLimit() in google/protobuf/io/coded_stream.h.
[libprotobuf WARNING google/protobuf/io/coded_stream.cc:81] The total number of bytes read was 1629838081
[07/26/2022-04:30:55] [W] [TRT] onnx2trt_utils.cpp:369: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[07/26/2022-04:30:55] [W] [TRT] onnx2trt_utils.cpp:395: One or more weights outside the range of INT32 was clamped
[07/26/2022-04:30:55] [W] [TRT] onnx2trt_utils.cpp:395: One or more weights outside the range of INT32 was clamped
[07/26/2022-04:30:55] [W] [TRT] onnx2trt_utils.cpp:395: One or more weights outside the range of INT32 was clamped
[07/26/2022-04:30:55] [W] [TRT] onnx2trt_utils.cpp:395: One or more weights outside the range of INT32 was clamped
[07/26/2022-04:30:57] [W] [TRT] onnx2trt_utils.cpp:395: One or more weights outside the range of INT32 was clamped
[07/26/2022-04:30:57] [W] [TRT] onnx2trt_utils.cpp:395: One or more weights outside the range of INT32 was clamped
[07/26/2022-04:31:07] [I] [TRT] No importer registered for op: NonZero. Attempting to import as plugin.
[07/26/2022-04:31:07] [I] [TRT] Searching for plugin: NonZero, plugin_version: 1, plugin_namespace: 
[07/26/2022-04:31:07] [E] [TRT] ModelImporter.cpp:773: While parsing node number 3818 [NonZero -> "onnx::Transpose_6418"]:
[07/26/2022-04:31:07] [E] [TRT] ModelImporter.cpp:774: --- Begin node ---
[07/26/2022-04:31:07] [E] [TRT] ModelImporter.cpp:775: input: "onnx::NonZero_6417"
output: "onnx::Transpose_6418"
name: "NonZero_4444"
op_type: "NonZero"

[07/26/2022-04:31:07] [E] [TRT] ModelImporter.cpp:776: --- End node ---
[07/26/2022-04:31:07] [E] [TRT] ModelImporter.cpp:779: ERROR: builtin_op_importers.cpp:4890 In function importFallbackPluginImporter:
[8] Assertion failed: creator && "Plugin not found, are the plugin name, version, and namespace correct?"
[07/26/2022-04:31:07] [E] Failed to parse onnx file
[07/26/2022-04:31:07] [I] Finish parsing network model
[07/26/2022-04:31:07] [E] Parsing model failed
[07/26/2022-04:31:07] [E] Failed to create engine from model or file.
[07/26/2022-04:31:07] [E] Engine set up failed
&&&& FAILED TensorRT.trtexec [TensorRT v8401] # ./trtexec --onnx=/content/drive/MyDrive/model.onnx --minShapes=input:1x128 --optShapes=input:1x128 --maxShapes=input:1x128 --shapes=input:1x128 --best --workspace=10000 --saveEngine=test.plan

It seems that when the model is loaded with the BartForSequenceClassification class, a NonZero operation, which is not supported by TensorRT, ends up somewhere in my ONNX graph.
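
For what it's worth, this can be checked directly on the exported file with something like the following (sketch using the onnx package):

import onnx

# Quick check: list the NonZero nodes present in the exported graph.
onnx_model = onnx.load("/content/drive/MyDrive/model.onnx", load_external_data=False)
nonzero_nodes = [node.name for node in onnx_model.graph.node if node.op_type == "NonZero"]
print(len(nonzero_nodes), "NonZero node(s):", nonzero_nodes)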

Do you have any workaround for this so that I can use BartForSequenceClassification to load the facebook/bart-large-mnli model?

Thanks in advance for taking the time to answer!
