diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 74d237a62f73d..d9238e41a28cc 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1869,6 +1869,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
   } else if (number_of_trt_nodes == number_of_ort_nodes) {
     LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider";
   } else {
+    sync_stream_after_enqueue_ = true;
     LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs;
   }
 
@@ -2387,7 +2388,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
       *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name,
             &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name],
             &builders_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
-            input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
+            input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
             dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
             runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_,
             dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_,
@@ -2415,6 +2416,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
     const std::unordered_map<std::string, size_t>& input_indexes = (trt_state->input_info)[0];
     const std::unordered_map<std::string, size_t>& output_indexes = (trt_state->output_info)[0];
     const std::unordered_map<std::string, size_t>& output_types = (trt_state->output_info)[1];
+    bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue;
    auto fused_node_name = trt_state->fused_node_name;
    auto& shape_ranges = trt_state->input_shape_ranges;
    auto trt_builder = trt_state->builder->get();
@@ -3022,6 +3024,10 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
     // Run TRT inference
     if (!trt_state->context->get()->enqueueV2(&buffers[0], stream, nullptr)) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP execution context enqueue failed.");
     }
+
+    if (sync_stream_after_enqueue) {
+      cudaStreamSynchronize(stream);
+    }
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ ... @@ struct TensorrtFuncState {
   std::vector<std::unordered_map<std::string, size_t>> input_info;
   std::vector<std::unordered_map<std::string, size_t>> output_info;
   std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>> input_shape_ranges;
+  bool sync_stream_after_enqueue = false;
   OrtMutex* tensorrt_mu_ptr = nullptr;
   bool fp16_enable = false;
   bool int8_enable = false;
@@ -262,6 +263,9 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   cudnnHandle_t external_cudnn_handle_ = nullptr;
   cublasHandle_t external_cublas_handle_ = nullptr;
 
+  // Call cudaStreamSynchronize() after TRT enqueueV2()/enqueueV3()
+  mutable bool sync_stream_after_enqueue_ = false;
+
   CUDAGraph cuda_graph_;
   bool is_graph_captured_ = false;
   int regular_run_count_before_graph_capture_ = 0;
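
For readers skimming the patch: it makes the TensorRT EP record, at GetCapability() time, whether the graph is partitioned (i.e. some nodes fall back to another EP or the CPU), plumbs that flag through TensorrtFuncState into the compiled compute function, and, when the flag is set, blocks on cudaStreamSynchronize() right after the asynchronous enqueueV2()/enqueueV3() call so that consumers outside the TRT stream only ever see completed outputs. The sketch below is a minimal standalone illustration of that synchronization pattern, not ORT code; FakeTrtSubgraph, the stream names, and the local flag are hypothetical stand-ins.

// sync_after_enqueue.cu -- toy model of the pattern this diff adds.
// The TRT "enqueue" is modeled by an async kernel launch on trt_stream;
// the consumer (another EP / the CPU in a partitioned graph) runs on a
// different stream, so without the conditional sync it can race the kernel.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void FakeTrtSubgraph(float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 2.0f * i;  // stand-in for the TRT engine's work
}

int main() {
  constexpr int n = 1024;
  float host_out[n] = {};
  float* device_out = nullptr;
  cudaStream_t trt_stream, consumer_stream;
  cudaStreamCreate(&trt_stream);
  cudaStreamCreate(&consumer_stream);
  cudaMalloc(&device_out, n * sizeof(float));

  // Mirrors sync_stream_after_enqueue_: true only when the graph is
  // partitioned, i.e. some consumer runs outside trt_stream.
  bool sync_stream_after_enqueue = true;

  // The "enqueue": asynchronous with respect to the host and other streams.
  FakeTrtSubgraph<<<(n + 255) / 256, 256, 0, trt_stream>>>(device_out, n);

  // The fix: block until the TRT work is done before anyone else reads it.
  if (sync_stream_after_enqueue) {
    cudaStreamSynchronize(trt_stream);
  }

  // Downstream consumer on a different stream; there is no implicit
  // ordering between two non-default streams, so with the flag false this
  // copy may observe stale or uninitialized device memory.
  cudaMemcpyAsync(host_out, device_out, n * sizeof(float),
                  cudaMemcpyDeviceToHost, consumer_stream);
  cudaStreamSynchronize(consumer_stream);
  printf("host_out[1] = %f\n", host_out[1]);  // expect 2.0

  cudaFree(device_out);
  cudaStreamDestroy(trt_stream);
  cudaStreamDestroy(consumer_stream);
  return 0;
}

A full cudaStreamSynchronize() stalls the host; recording a cudaEvent_t on the TRT stream and having consumers cudaStreamWaitEvent() on it would express the same dependency without that stall. The blanket sync is the simpler, safer choice when the EP cannot know who consumes its outputs, and the diff only pays the cost in the partitioned case, where correctness requires it.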