diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
index d8182b0e8378f..578853efcb070 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -163,10 +163,19 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
     auto profile = builder_->createOptimizationProfile();
     for (int i = 0; i < network_->getNbInputs(); ++i) {
       auto name = network_->getInput(i)->getName();
-      auto dims = network_->getInput(i)->getDimensions();
-      profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN, dims);
+      const int entry_id = entry_id_map_[name];
+      std::vector<int64_t> shape(data_entry_[entry_id]->shape,
+                                 data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim);
+      auto dims = VectorToTrtDims(shape);
+
       profile->setDimensions(name, nvinfer1::OptProfileSelector::kOPT, dims);
       profile->setDimensions(name, nvinfer1::OptProfileSelector::kMAX, dims);
+      // Set minimum batch size to 1 when dynamic batching is used.
+      if (network_->getInput(i)->getDimensions().nbDims >= 1 &&
+          network_->getInput(i)->getDimensions().d[0] == -1) {
+        dims.d[0] = 1;
+      }
+      profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN, dims);
     }
     config_->addOptimizationProfile(profile);
   }
diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
index 6358e59ce3bc5..5562f853383c6 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
@@ -140,6 +140,12 @@ class TensorRTRuntime : public JSONRuntimeBase {
           const std::string name = nodes_[nid].GetOpName() + "_" + std::to_string(j);
           int binding_index = engine->getBindingIndex(name.c_str());
           ICHECK_NE(binding_index, -1);
+          if (!use_implicit_batch_) {
+            std::vector<int64_t> shape(data_entry_[eid]->shape,
+                                       data_entry_[eid]->shape + data_entry_[eid]->ndim);
+            auto dims = VectorToTrtDims(shape);
+            ICHECK(context->setBindingDimensions(binding_index, dims));
+          }
           if (data_entry_[eid]->device.device_type == kDLCUDA) {
             bindings[binding_index] = data_entry_[eid]->data;
           } else {
@@ -300,7 +306,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
     helper.DeclareField("inputs", &engine_and_context.inputs);
     helper.DeclareField("outputs", &engine_and_context.outputs);
     helper.ReadAllFields(&reader);
-    const int batch_size = 1;
+    const int batch_size = GetBatchSize();
     trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] = engine_and_context;
     return true;
   }
diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py
index 59f1c3aa4d682..3f57df5a5f4ac 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -1251,33 +1251,35 @@ def test_tensorrt_dynamic_batch_conv():
     x_data = np.ones([max(batches_to_test)] + list(x_shape)[1:]).astype("float32")
     k_shape = (16, 32, 3, 3)
     params = {"kernel": np.random.uniform(-1, 1, k_shape).astype("float32")}
-    result_arr = [{"cuda": {}, "llvm": {}} for _ in range(len(batches_to_test))]
-    for use_trt in [True, False]:
-        x = relay.var("x", shape=x_shape, dtype="float32")
-        kernel = relay.var("kernel", shape=k_shape, dtype="float32")
-        out = relay.nn.conv2d(x, kernel, channels=16, kernel_size=(3, 3), groups=1)
-        f = relay.Function([x, kernel], out)
-        mod = tvm.IRModule()
-        mod["main"] = f
-        if use_trt:
-            mod, _ = tensorrt.partition_for_tensorrt(mod, params)
-
+    for use_implicit_batch in [True, False]:
+        result_arr = [{"cuda": {}, "llvm": {}} for _ in range(len(batches_to_test))]
+        for use_trt in [True, False]:
+            x = relay.var("x", shape=x_shape, dtype="float32")
+            kernel = relay.var("kernel", shape=k_shape, dtype="float32")
+            out = relay.nn.conv2d(x, kernel, channels=16, kernel_size=(3, 3), groups=1)
+            f = relay.Function([x, kernel], out)
+            mod = tvm.IRModule()
+            mod["main"] = f
+            if use_trt:
+                mod, config = tensorrt.partition_for_tensorrt(
+                    mod, params, use_implicit_batch=use_implicit_batch
+                )
+            if not skip_runtime_test():
+                for target in ["llvm", "cuda"]:
+                    with tvm.transform.PassContext(
+                        opt_level=3, config={"relay.ext.tensorrt.options": config}
+                    ):
+                        relay_exec = relay.create_executor(
+                            "vm", mod=mod, device=tvm.device(target), target=target
+                        )
+                        for i, batch_size in enumerate(batches_to_test):
+                            result_arr[i][target][use_trt] = relay_exec.evaluate()(
+                                x_data[:batch_size, ...], **params
+                            )
         if not skip_runtime_test():
-            for target in ["llvm", "cuda"]:
-                with relay.build_config(opt_level=3):
-                    relay_exec = relay.create_executor(
-                        "vm", mod=mod, device=tvm.cpu(0), target="llvm"
-                    )
-
-                    for i, batch_size in enumerate(batches_to_test):
-                        result_arr[i][target][use_trt] = relay_exec.evaluate()(
-                            x_data[:batch_size, ...], **params
-                        )
-
-    if not skip_runtime_test():
-        for i in range(len(batches_to_test)):
-            for target in ["llvm", "cuda"]:
-                assert_result_dict_holds(result_arr[i][target])
+            for i in range(len(batches_to_test)):
+                for target in ["llvm", "cuda"]:
+                    assert_result_dict_holds(result_arr[i][target])


 def test_maskrcnn_resnet50() -> None:
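
For reviewers less familiar with TensorRT's explicit-batch mode, the builder and runtime hunks above amount to roughly the following, sketched here against TensorRT's Python API (not part of this patch; `build_dynamic_batch_engine` and its parameters are illustrative names, and the calls are the TRT 7.x-era bindings matching the C++ API used above):

```python
import tensorrt as trt

def build_dynamic_batch_engine(builder, network, input_shape):
    """Sketch: build an engine whose profile accepts batch sizes 1..input_shape[0]."""
    config = builder.create_builder_config()
    profile = builder.create_optimization_profile()
    for i in range(network.num_inputs):
        inp = network.get_input(i)
        opt_shape = list(input_shape)  # concrete shape of the data being bound
        min_shape = list(input_shape)
        # Mirror tensorrt_builder.cc: kOPT/kMAX use the actual input shape,
        # while kMIN lowers the batch dim to 1 when the network input is dynamic (-1).
        if len(inp.shape) >= 1 and inp.shape[0] == -1:
            min_shape[0] = 1
        profile.set_shape(inp.name, min=min_shape, opt=opt_shape, max=opt_shape)
    config.add_optimization_profile(profile)
    return builder.build_engine(network, config)

# Mirror tensorrt_runtime.cc: with explicit batch, the runtime must tell the
# execution context the actual input shape before binding the buffers, e.g.:
#   context = engine.create_execution_context()
#   context.set_binding_dimensions(engine.get_binding_index("x"), actual_shape)
```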
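From the user's side, the new `use_implicit_batch` option is exercised roughly as below, a condensed sketch of the test above assuming a CUDA-capable, TensorRT-enabled build; the concrete shapes and batch sizes are illustrative, since `x_shape` and `batches_to_test` are defined above the hunk:

```python
import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib import tensorrt

# Batch dim is relay.Any(), so one compiled module serves several batch sizes.
x = relay.var("x", shape=(relay.Any(), 32, 8, 8), dtype="float32")
kernel = relay.var("kernel", shape=(16, 32, 3, 3), dtype="float32")
out = relay.nn.conv2d(x, kernel, channels=16, kernel_size=(3, 3), groups=1)
mod = tvm.IRModule()
mod["main"] = relay.Function([x, kernel], out)
params = {"kernel": np.random.uniform(-1, 1, (16, 32, 3, 3)).astype("float32")}

# use_implicit_batch=False selects TensorRT's explicit-batch (dynamic) mode.
mod, config = tensorrt.partition_for_tensorrt(mod, params, use_implicit_batch=False)
with tvm.transform.PassContext(
    opt_level=3, config={"relay.ext.tensorrt.options": config}
):
    relay_exec = relay.create_executor(
        "vm", mod=mod, device=tvm.device("cuda"), target="cuda"
    )

# With the optimization profile above, the engine built for the largest batch
# can also serve smaller batches instead of triggering a rebuild per size.
for batch_size in [7, 1, 5]:
    data = np.ones((batch_size, 32, 8, 8), dtype="float32")
    result = relay_exec.evaluate()(data, **params)
```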