[TensorRT EP] Add stream sync after enqueue #18026

Merged · 6 commits · Oct 20, 2023
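This change makes the TensorRT execution provider call cudaStreamSynchronize() on its CUDA stream right after enqueueV2()/enqueueV3(), and only when the graph is partitioned so that some subgraphs run on other execution providers (see the first hunk below). As background on why that matters, here is a minimal standalone CUDA sketch, not ONNX Runtime code (the Fill kernel and all buffer names are illustrative): work enqueued on a stream runs asynchronously with respect to the host, so a consumer that reads the result right after enqueue must synchronize first.

#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for a TRT engine's enqueued work: fills a device buffer on a stream.
__global__ void Fill(float* data, int n, float value) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] = value;
}

int main() {
  const int n = 1 << 20;
  float* device_data = nullptr;
  float* host_data = nullptr;
  cudaMalloc(&device_data, n * sizeof(float));
  cudaMallocHost(&host_data, n * sizeof(float));  // pinned, so the copy below is truly asynchronous

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // "Enqueue": both the kernel and the device-to-host copy are asynchronous.
  Fill<<<(n + 255) / 256, 256, 0, stream>>>(device_data, n, 42.0f);
  cudaMemcpyAsync(host_data, device_data, n * sizeof(float), cudaMemcpyDeviceToHost, stream);

  // Without this sync, the host could read host_data before the kernel and
  // copy have finished; this is the kind of race the flag below guards
  // against when a non-TRT node consumes a TRT subgraph's output.
  cudaStreamSynchronize(stream);

  std::printf("host_data[0] = %f\n", host_data[0]);  // reliably prints 42.000000

  cudaFreeHost(host_data);
  cudaFree(device_data);
  cudaStreamDestroy(stream);
  return 0;
}

Compile with, e.g., nvcc -o sync_sketch sync_sketch.cu.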
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

@@ -1869,6 +1869,7 @@
   } else if (number_of_trt_nodes == number_of_ort_nodes) {
     LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider";
   } else {
+    sync_stream_after_enqueue_ = true;
     LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs;
   }

@@ -2387,7 +2388,7 @@
   *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name,
         &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name],
         &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
-        input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
+        input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
         dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
         runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_,
         dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_,

GitHub Actions / cpplint warning on the new line (onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:2391): Lines should be <= 120 characters long [whitespace/line_length] [2]
@@ -2415,6 +2416,7 @@
   const std::unordered_map<std::string, size_t>& input_indexes = (trt_state->input_info)[0];
   const std::unordered_map<std::string, size_t>& output_indexes = (trt_state->output_info)[0];
   const std::unordered_map<std::string, size_t>& output_types = (trt_state->output_info)[1];
+  bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue;
   auto fused_node_name = trt_state->fused_node_name;
   auto& shape_ranges = trt_state->input_shape_ranges;
   auto trt_builder = trt_state->builder->get();
@@ -3022,6 +3024,10 @@
     return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
   }

+  if (sync_stream_after_enqueue) {
+    cudaStreamSynchronize(stream);
+  }
+
   // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64
   for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
     const std::string& output_name = output_binding_names[i];
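Note on the design choice: cudaStreamSynchronize() blocks the calling host thread until all work previously queued on the stream has completed, so this sync serializes host and device at this point. Because sync_stream_after_enqueue_ is only set on the partitioned-graph path (first hunk above), graphs that run entirely on TensorRT keep their existing asynchronous behavior.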
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h

@@ -111,6 +111,7 @@ struct TensorrtFuncState {
   std::vector<std::unordered_map<std::string, size_t>> input_info;
   std::vector<std::unordered_map<std::string, size_t>> output_info;
   std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>> input_shape_ranges;
+  bool sync_stream_after_enqueue = false;
   OrtMutex* tensorrt_mu_ptr = nullptr;
   bool fp16_enable = false;
   bool int8_enable = false;
@@ -262,6 +263,9 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   cudnnHandle_t external_cudnn_handle_ = nullptr;
   cublasHandle_t external_cublas_handle_ = nullptr;

+  // Call cudaStreamSynchronize() after TRT enqueueV2()/enqueueV3()
+  mutable bool sync_stream_after_enqueue_ = false;
+
   CUDAGraph cuda_graph_;
   bool is_graph_captured_ = false;
   int regular_run_count_before_graph_capture_ = 0;
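In the header, sync_stream_after_enqueue_ is declared mutable, presumably because it is assigned during the partition check above, which runs inside a const member function (IExecutionProvider::GetCapability() is const). A minimal standalone C++ sketch of that idiom follows; Provider, Analyze, and needs_sync_ are hypothetical names, not ONNX Runtime API:

#include <iostream>

class Provider {
 public:
  // A const method may not assign to ordinary data members, but it may
  // assign to mutable ones, which is how a cached decision like this
  // flag can be recorded during analysis.
  bool Analyze(int subgraph_count) const {
    if (subgraph_count > 1) {
      needs_sync_ = true;  // legal only because needs_sync_ is mutable
    }
    return needs_sync_;
  }

 private:
  mutable bool needs_sync_ = false;
};

int main() {
  const Provider provider;
  std::cout << std::boolalpha << provider.Analyze(3) << "\n";  // prints: true
  return 0;
}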