Skip to content

Commit

Permalink
[TensorRT EP] Add stream sync after enqueue (#18026)
Browse files Browse the repository at this point in the history
If the model is partitioned into TRT subgraphs and CUDA EP nodes, we
observed a CUDA stream synchronization issue when multithreading. Calling
the stream sync API after enqueue can solve this issue without adding much
performance overhead.
  • Loading branch information
chilo-ms authored Oct 20, 2023
1 parent 020824e commit 2f57625
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -1869,6 +1869,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
} else if (number_of_trt_nodes == number_of_ort_nodes) {
LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider";
} else {
sync_stream_after_enqueue_ = true;
LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs;
}

Expand Down Expand Up @@ -2387,7 +2388,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
*p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name,
&parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name],
&networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_,
dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_,
Expand Down Expand Up @@ -2415,6 +2416,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
const std::unordered_map<std::string, size_t>& input_indexes = (trt_state->input_info)[0];
const std::unordered_map<std::string, size_t>& output_indexes = (trt_state->output_info)[0];
const std::unordered_map<std::string, size_t>& output_types = (trt_state->output_info)[1];
bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue;
auto fused_node_name = trt_state->fused_node_name;
auto& shape_ranges = trt_state->input_shape_ranges;
auto trt_builder = trt_state->builder->get();
Expand Down Expand Up @@ -3022,6 +3024,10 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
}

if (sync_stream_after_enqueue) {
cudaStreamSynchronize(stream);
}

// Cast INT64 input to INT32 because TensorRT doesn't fully support INT64
for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
const std::string& output_name = output_binding_names[i];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ struct TensorrtFuncState {
std::vector<std::unordered_map<std::string, size_t>> input_info;
std::vector<std::unordered_map<std::string, size_t>> output_info;
std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>> input_shape_ranges;
bool sync_stream_after_enqueue = false;
OrtMutex* tensorrt_mu_ptr = nullptr;
bool fp16_enable = false;
bool int8_enable = false;
Expand Down Expand Up @@ -262,6 +263,9 @@ class TensorrtExecutionProvider : public IExecutionProvider {
cudnnHandle_t external_cudnn_handle_ = nullptr;
cublasHandle_t external_cublas_handle_ = nullptr;

// Call cudaStreamSynchronize() after TRT enqueueV2()/enqueueV3()
mutable bool sync_stream_after_enqueue_ = false;

CUDAGraph cuda_graph_;
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;
Expand Down

0 comments on commit 2f57625

Please sign in to comment.