This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[1.x][FEATURE] CUDA graphs support #19142

Merged
merged 15 commits on Sep 19, 2020
10 changes: 10 additions & 0 deletions docs/static_site/src/pages/api/faq/env_var.md
@@ -134,6 +134,16 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
* MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD
- Values: Int ```(default=<value of MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN>)```
- The maximum number of nodes in the subgraph executed in bulk during training (not inference) in the backward pass.
* MXNET_ENABLE_CUDA_GRAPHS
- Values: 0(false) or 1(true) ```(default=0)```
- If set to `1`, MXNet will use CUDA graphs to execute models on the GPU whenever possible.
- CUDA graphs execution requires either a symbolic model or a Gluon model hybridized with the options `static_alloc` and `static_shape` set to `True`.
* MXNET_CUDA_GRAPHS_VERBOSE
- Values: 0(false) or 1(true) ```(default=0)```
- If set to `1`, the CUDA graphs executor will print information about the graphs being captured and executed.
* MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES
- Values: Int ```(default=0)```
- The maximum number of log messages generated by the CUDA graphs executor. A sketch of how these flags can be read follows this list.
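
The flags above are ordinary environment variables. As a minimal sketch of how such flags are conventionally read inside MXNet sources (assuming the usual `dmlc::GetEnv` helper; the actual call sites of this PR are not reproduced here):

```c++
#include <dmlc/parameter.h>  // provides dmlc::GetEnv

// Hypothetical illustration only: the local variable names are made up,
// and the defaults mirror the documentation above.
const bool cuda_graphs_enabled =
    dmlc::GetEnv("MXNET_ENABLE_CUDA_GRAPHS", false);
const bool cuda_graphs_verbose =
    dmlc::GetEnv("MXNET_CUDA_GRAPHS_VERBOSE", false);
const int cuda_graphs_max_log_entries =
    dmlc::GetEnv("MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES", 0);
```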

## Control the Data Communication

13 changes: 13 additions & 0 deletions include/mxnet/op_attr_types.h
@@ -362,6 +362,19 @@ using FNeedCalibrateInput = std::function<std::vector<int> (const NodeAttrs& attrs)>;
*/
using FNeedCalibrateOutput = std::function<std::vector<int> (const NodeAttrs& attrs)>;

#if MXNET_USE_CUDA

/*!
* \brief Register a function to determine whether
* the operator implementation is compatible
* with CUDA graphs. Compatibility requires the
* execution to stay the same as long as the shapes
* and types of the inputs stay the same.
*/
using FIsCUDAGraphsCompatible = std::function<bool (const NodeAttrs& attrs, const bool is_train)>;

#endif

} // namespace mxnet

#endif // MXNET_OP_ATTR_TYPES_H_
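
NNVM operator attributes such as this one are attached with `set_attr` when an operator is registered. A minimal, hypothetical sketch of what a registration could look like (the operator chosen and the exact policy are illustrative assumptions, not lifted from this PR):

```c++
#if MXNET_USE_CUDA
NNVM_REGISTER_OP(Dropout)
.set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
  [](const NodeAttrs& attrs, const bool is_train) {
    // A captured CUDA graph replays exactly the same kernels on every
    // launch. Training-mode dropout must draw fresh random numbers per
    // invocation, so it is only safe to capture outside of training.
    return !is_train;
  });
#endif  // MXNET_USE_CUDA
```

An executor can then look the attribute up the same way `TIsMKLDNN` is queried in the executor code below, e.g. via `Op::GetAttr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible")`, and fall back to normal kernel launches for any node whose function returns `false`.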
64 changes: 36 additions & 28 deletions src/executor/attach_op_execs_pass.cc
@@ -45,8 +45,10 @@ namespace exec {
// FComputeExecutor and FStatefulComputeExecutor inherit from this class
class StorageFallbackOpExecutor : public OpExecutor {
public:
- explicit StorageFallbackOpExecutor(const std::vector<uint32_t> &mutate_idx)
- : mutate_idx_(mutate_idx) {}
+ explicit StorageFallbackOpExecutor(const NodeAttrs& attrs,
+ const DispatchMode& dispatch_mode,
+ const std::vector<uint32_t> &mutate_idx)
+ : OpExecutor(attrs, dispatch_mode), mutate_idx_(mutate_idx) {}

void Setup() override {
init_ = false;
@@ -136,11 +138,13 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
return state_;
}

- explicit StatefulComputeExecutor(const OpStatePtr& state,
+ explicit StatefulComputeExecutor(const NodeAttrs& attrs,
+ const DispatchMode dispatch_mode,
+ const OpStatePtr& state,
const FStatefulCompute& fcompute,
ExecType exec_type,
const std::vector<uint32_t> &mutate_idx)
- : StorageFallbackOpExecutor(mutate_idx),
+ : StorageFallbackOpExecutor(attrs, dispatch_mode, mutate_idx),
state_(state), fcompute_(fcompute), exec_type_(exec_type) {}

private:
@@ -159,7 +163,7 @@ class StatefulComputeExExecutor : public OpExecutor {
InvalidateOutputs(out_array, req);
// TODO(alex): (MXNET-847) Remove this fallback feature after subgraph implemented
const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
- if (!is_mkldnn.get(attrs_.op, false)) {
+ if (!is_mkldnn.get(attrs.op, false)) {
CreateDefaultInputs(in_array, &in_array_fallback);
fcompute_(state_, op_ctx, in_array_fallback, req, out_array);
return;
@@ -183,13 +187,14 @@ }
}

explicit StatefulComputeExExecutor(const NodeAttrs& attrs,
+ const DispatchMode& dispatch_mode,
const OpStatePtr& state,
const FStatefulComputeEx& fcompute,
ExecType exec_type)
- : attrs_(attrs), state_(state), fcompute_(fcompute), exec_type_(exec_type) {}
+ : OpExecutor(attrs, dispatch_mode), state_(state), fcompute_(fcompute),
+ exec_type_(exec_type) {}

private:
- NodeAttrs attrs_;
OpStatePtr state_;
FStatefulComputeEx fcompute_;
ExecType exec_type_;
@@ -206,22 +211,22 @@ class FComputeExecutor : public StorageFallbackOpExecutor {
InvalidateOutputs(out_array, req);
#endif
PreFCompute(is_gpu);
- fcompute_(attrs_, op_ctx, in_data_, req, out_data_);
+ fcompute_(attrs, op_ctx, in_data_, req, out_data_);
PostFCompute(is_gpu);
}

ExecType exec_type() const override {
return exec_type_;
}

- explicit FComputeExecutor(const NodeAttrs& attrs, FCompute fcompute,
- ExecType exec_type, const std::vector<uint32_t> &mutate_idx)
- : StorageFallbackOpExecutor(mutate_idx),
- attrs_(attrs), fcompute_(fcompute), exec_type_(exec_type) {
+ explicit FComputeExecutor(const NodeAttrs& attrs, const DispatchMode dispatch_mode,
+ FCompute fcompute, ExecType exec_type,
+ const std::vector<uint32_t> &mutate_idx)
+ : StorageFallbackOpExecutor(attrs, dispatch_mode, mutate_idx),
+ fcompute_(fcompute), exec_type_(exec_type) {
}

private:
- NodeAttrs attrs_;
FCompute fcompute_;
ExecType exec_type_;
};
@@ -235,13 +240,13 @@ class FComputeExExecutor : public OpExecutor {
InvalidateOutputs(out_array, req);
// TODO(alex): (MXNET-847) Remove this fallback feature after subgraph implemented
const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
- if (!is_mkldnn.get(attrs_.op, false)) {
+ if (!is_mkldnn.get(attrs.op, false)) {
CreateDefaultInputs(in_array, &in_array_fallback);
- fcompute_(attrs_, op_ctx, in_array_fallback, req, out_array);
+ fcompute_(attrs, op_ctx, in_array_fallback, req, out_array);
return;
}
#endif
- fcompute_(attrs_, op_ctx, in_array, req, out_array);
+ fcompute_(attrs, op_ctx, in_array, req, out_array);
}

void Setup() override {}
@@ -250,13 +255,12 @@ class FComputeExExecutor : public OpExecutor {
return exec_type_;
}

- explicit FComputeExExecutor(const NodeAttrs& attrs, FComputeEx fcompute,
- ExecType exec_type)
- : attrs_(attrs), fcompute_(fcompute), exec_type_(exec_type) {
+ explicit FComputeExExecutor(const NodeAttrs& attrs, const DispatchMode dispatch_mode,
+ FComputeEx fcompute, ExecType exec_type)
+ : OpExecutor(attrs, dispatch_mode), fcompute_(fcompute), exec_type_(exec_type) {
}

private:
- NodeAttrs attrs_;
FComputeEx fcompute_;
ExecType exec_type_;
};
@@ -310,15 +314,18 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state,
op, "FStatefulComputeEx", vctx[i]);
// FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
- ret[i] = std::make_shared<StatefulComputeExExecutor>(inode.source->attrs, state,
+ ret[i] = std::make_shared<StatefulComputeExExecutor>(inode.source->attrs,
+ dispatch_modes[i], state,
fcompute_ex, exec_type);
} else {
FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
op, "FStatefulCompute", vctx[i]);
CHECK(fcompute != nullptr)
<< "One of FStatefulCompute and FStatefulComputeEx must be registered "
<< "for stateful operator " << op->name;
- ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute,
+ ret[i] = std::make_shared<StatefulComputeExecutor>(inode.source->attrs,
+ dispatch_modes[i],
+ state, fcompute,
exec_type, mutate_index);
}
} else if (is_layer_backward.get(op, false)) {
@@ -331,26 +338,27 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state,
// FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
ret[i] = std::make_shared<StatefulComputeExExecutor>(
- inode.source->attrs, ret[fwd_id].get()->state(), fcompute_ex,
- exec_type);
+ inode.source->attrs, dispatch_modes[i], ret[fwd_id].get()->state(),
+ fcompute_ex, exec_type);
} else {
FStatefulCompute fcompute = common::GetFCompute<FStatefulCompute>(
op, "FStatefulCompute", vctx[i]);
CHECK(fcompute != nullptr)
<< "One of FStatefulCompute and FStatefulComputeEx must be registered "
<< "for stateful operator " << op->name;
- ret[i] = std::make_shared<StatefulComputeExecutor>(
- ret[fwd_id].get()->state(), fcompute, exec_type, mutate_index);
+ ret[i] = std::make_shared<StatefulComputeExecutor>(inode.source->attrs,
+ dispatch_modes[i], ret[fwd_id].get()->state(), fcompute, exec_type,
+ mutate_index);
}
} else {
FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", vctx[i]);
FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", vctx[i]);
if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
ret[i] = std::make_shared<FComputeExExecutor>(
- inode.source->attrs, fcomp_ex, exec_type);
+ inode.source->attrs, dispatch_modes[i], fcomp_ex, exec_type);
} else if (fcompute != nullptr) {
ret[i] = std::make_shared<FComputeExecutor>(
- inode.source->attrs, fcompute, exec_type, mutate_index);
+ inode.source->attrs, dispatch_modes[i], fcompute, exec_type, mutate_index);
} else {
LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name;
}
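
The recurring change in this file is that every concrete executor now forwards the node attributes and dispatch mode to the `OpExecutor` base class instead of keeping a private `attrs_` copy. A minimal sketch of the base-class shape these call sites imply (an assumption, not the verbatim header from the PR):

```c++
// Sketch only: NodeAttrs, DispatchMode, RunContext and ExecType come from
// the existing MXNet headers. attrs and dispatch_mode move into the base
// class so that a wrapper such as a CUDA graphs executor can inspect any
// node uniformly (e.g. to query FIsCUDAGraphsCompatible) without
// downcasting to a concrete executor type.
class OpExecutor {
 public:
  OpExecutor(NodeAttrs attrs, DispatchMode dispatch_mode)
      : attrs(std::move(attrs)), dispatch_mode(dispatch_mode) {}
  virtual ~OpExecutor() = default;

  NodeAttrs attrs;             // formerly a private attrs_ in each subclass
  DispatchMode dispatch_mode;  // how this node was dispatched

  virtual void Run(RunContext rctx, bool is_gpu) = 0;
  virtual void Setup() = 0;
  virtual ExecType exec_type() const = 0;
};
```

This is also why the method bodies above switched from `attrs_` to `attrs`: the member is now inherited from the base class.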