Skip to content

Commit

Permalink
[CUDA EP] Add warning logs when adding memcpy nodes (#18032)
Browse files Browse the repository at this point in the history
Memcpy nodes could have a negative impact on performance; they also prevent ORT from running the CUDA graph.

Here we add a warning log for the CUDA EP when this happens. It could help
with troubleshooting. For example, when the CUDA graph cannot run, we can
check the logs to find out where the Memcpy nodes were inserted. (Although
this is also possible by saving the optimized model, that requires more time
and disk space.)

Note that the warning is per graph. When there are subgraphs, we might
see multiple warnings if the issue happens in multiple graphs.

Example logs:
```
2023-10-19 20:58:10.678176531 [I:onnxruntime:, transformer_memcpy.cc:329 AddCopyNode] Add MemcpyFromHost after input_ids for CUDAExecutionProvider
2023-10-19 20:58:10.678198702 [I:onnxruntime:, transformer_memcpy.cc:329 AddCopyNode] Add MemcpyFromHost after /text_model/ArgMax_output_0 for CUDAExecutionProvider
2023-10-19 20:58:10.678211727 [I:onnxruntime:, transformer_memcpy.cc:329 AddCopyNode] Add MemcpyFromHost after /text_model/Gather_3_output_0 for CUDAExecutionProvider
2023-10-19 20:58:10.678257903 [W:onnxruntime:, transformer_memcpy.cc:74 ApplyImpl] 3 Memcpy nodes are added to the graph main_graph for CUDAExecutionProvider. It might have negative impact on performance (including unable to run CUDA graph). Set session_options.log_severity_level=1 to see the detail logs before this message.
```
  • Loading branch information
tianleiwu authored Oct 24, 2023
1 parent 555b2af commit 688524a
Showing 1 changed file with 30 additions and 10 deletions.
40 changes: 30 additions & 10 deletions onnxruntime/core/optimizer/transformer_memcpy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "transformer_memcpy.h"
#include "core/common/logging/logging.h"
#include "core/framework/kernel_registry_manager.h"
#include "core/framework/execution_providers.h"
#include "core/framework/utils.h"
Expand All @@ -16,12 +17,12 @@ class TransformerMemcpyImpl {
TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider)
: graph_(graph), provider_(provider) {}

bool ModifyGraph(const KernelRegistryManager& schema_registries);
bool ModifyGraph(const KernelRegistryManager& schema_registries, const logging::Logger& logger, int& copy_node_counter);

private:
void ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed);
void BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries);
void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input);
void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger);
bool ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed);

private:
Expand Down Expand Up @@ -61,11 +62,21 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st

// very simple GraphTransformer that uses TransformerMemcpyImpl for each graph
// and mainly provides the subgraph recursion functionality
common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger) const {
for (auto& provider : provider_types_) {
if (!utils::ProviderIsCpuBased(provider)) {
TransformerMemcpyImpl copy_impl(graph, provider);
auto current_modified = copy_impl.ModifyGraph(registry_manager_);

int copy_node_counter = 0;
auto current_modified = copy_impl.ModifyGraph(registry_manager_, logger, copy_node_counter);
if (copy_node_counter > 0 && provider == kCudaExecutionProvider) {
LOGS(logger, WARNING) << copy_node_counter << " Memcpy nodes are added to the graph " << graph.Name()
<< " for " << provider
<< ". It might have negative impact on performance (including unable to run CUDA graph). "
<< "Set session_options.log_severity_level=1 to see the detail logs before this message.";
}

modified = modified || current_modified;
break;
}
Expand Down Expand Up @@ -111,7 +122,9 @@ This transformer does not currently optimize copies between, e.g., two different
*/

bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_registries) {
bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_registries,
const logging::Logger& logger,
int& copy_node_counter) {
bool modified = false;
InitializedTensorSet initializers_consumed;
// find defs that require copy
Expand All @@ -137,19 +150,22 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
// For inputs we need to create a copy node only when the input is connected to both provider
// and non-provider nodes. Otherwise utils::CopyInputsAcrossDevices() will do the job.
if (provider_input_defs_.count(arg) && non_provider_input_defs_.count(arg)) {
AddCopyNode(const_cast<onnxruntime::NodeArg*>(arg), true);
AddCopyNode(const_cast<onnxruntime::NodeArg*>(arg), true, logger);
copy_node_counter++;
modified = true;
}

for (auto arg : non_provider_output_defs_)
if (provider_input_defs_.count(arg)) {
AddCopyNode(arg, true);
AddCopyNode(arg, true, logger);
copy_node_counter++;
modified = true;
}

for (auto arg : provider_output_defs_)
if (non_provider_input_defs_.count(arg)) {
AddCopyNode(arg, false);
AddCopyNode(arg, false, logger);
copy_node_counter++;
modified = true;
}

Expand All @@ -176,7 +192,8 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
// (the name will be the same as the parent node's implicit input)
const auto* node_arg_in_current_graph_level = *provider_input_defs_.find(arg);

AddCopyNode(const_cast<onnxruntime::NodeArg*>(node_arg_in_current_graph_level), true);
AddCopyNode(const_cast<onnxruntime::NodeArg*>(node_arg_in_current_graph_level), true, logger);
copy_node_counter++;
modified = true;
}
}
Expand Down Expand Up @@ -297,7 +314,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co
}
}

void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input) {
void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger) {
// create unique name for new def
std::string new_def_name = graph_.GenerateNodeArgName(arg->Name() + "_" + provider_);

Expand All @@ -309,6 +326,9 @@ void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input
std::string new_node_name = graph_.GenerateNodeName("Memcpy");

const auto op_name = is_input ? "MemcpyFromHost" : "MemcpyToHost";
LOGS(logger, INFO) << "Add " << op_name << (is_input ? " after " : " before ") << arg->Name()
<< " for " << provider_;

auto& new_node = graph_.AddNode(new_node_name, op_name, "Copy from/to host memory",
std::vector<onnxruntime::NodeArg*>{src_arg},
std::vector<onnxruntime::NodeArg*>{dst_arg});
Expand Down

0 comments on commit 688524a

Please sign in to comment.