[CUTLASS] Support conv2d activation fusion #9746

Merged · 2 commits · Dec 16, 2021
Changes from all commits
9 changes: 9 additions & 0 deletions python/tvm/contrib/cutlass/build.py
@@ -87,6 +87,9 @@ def visit_call(self, call):
if str(op) == "nn.conv2d":
self.op_attrs = call.attrs

for arg in call.args:
self.visit(arg)


def select_gemm_kernel(
cutlass_profiler, MM, KK, NN, out_dtype, batched, profile_all, use_multiprocessing
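The two added lines matter because Relay's Python `ExprVisitor` stops its default traversal once `visit_call` is overridden; without the explicit recursion, attrs on a `nn.conv2d` nested under `bias_add`/`relu` would never be recorded. A minimal standalone sketch of the same pattern (the class name `AttrExtractor` is illustrative, not the PR's):

```python
from tvm.relay.expr_functor import ExprVisitor


class AttrExtractor(ExprVisitor):
    """Record the attrs of the first nn.conv2d found anywhere in an expression."""

    def __init__(self):
        super().__init__()
        self.op_attrs = None

    def visit_call(self, call):
        if str(call.op) == "nn.conv2d":
            self.op_attrs = call.attrs
        # Overriding visit_call replaces the default traversal, so recurse
        # into the arguments by hand or nested calls are never visited.
        for arg in call.args:
            self.visit(arg)
```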
@@ -213,6 +216,12 @@ def handle_conv2d(

if op_type == "cutlass.conv2d":
cutlass_op_def = out["opdef"]
elif op_type == "cutlass.conv2d_bias":
cutlass_op_def = out["opdef_bias"]
elif op_type == "cutlass.conv2d_bias_relu":
cutlass_op_def = out["opdef_bias_relu"]
elif op_type == "cutlass.conv2d_bias_sigmoid":
cutlass_op_def = out["opdef_bias_sigmoid"]
else:
raise ValueError("%s pattern is not implemented." % op_type)

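The if/elif chain above is a straight lookup from composite pattern name to the profiler's op-definition key; an equivalent table-driven sketch (the names `OPDEF_KEY` and `select_opdef`, and the layout of the profiler's `out` dict, are as implied above, not new API):

```python
# Pattern name -> key into the dict ("out") returned by the CUTLASS profiler.
OPDEF_KEY = {
    "cutlass.conv2d": "opdef",
    "cutlass.conv2d_bias": "opdef_bias",
    "cutlass.conv2d_bias_relu": "opdef_bias_relu",
    "cutlass.conv2d_bias_sigmoid": "opdef_bias_sigmoid",
}


def select_opdef(out, op_type):
    if op_type not in OPDEF_KEY:
        raise ValueError("%s pattern is not implemented." % op_type)
    return out[OPDEF_KEY[op_type]]
```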
35 changes: 27 additions & 8 deletions python/tvm/contrib/cutlass/conv2d_operation.py
@@ -143,6 +143,22 @@ class EmitConv2dInstance:
""" Responsible for emitting a CUTLASS template definition."""

def __init__(self):
self.epilogue_default = """
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>"""
self.epilogue_no_beta_scaling = """
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue},
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>"""

self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name} =
@@ -159,12 +175,7 @@ def __init__(self):
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
${epilogue},
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
${stages},
${math_operator},
@@ -175,7 +186,7 @@ def __init__(self):
>::Kernel;
"""

def emit(self, operation):
def emit(self, operation, no_beta_scaling=True):
"""Instantiate a Conv2d kernel from given `operation`."""
warp_shape = [
int(
@@ -237,4 +248,12 @@ def emit(self, operation):
"align_b": str(operation.B.alignment),
}

return substitute_template(self.template, values)
template = substitute_template(
self.template,
{
"epilogue": self.epilogue_no_beta_scaling
if no_beta_scaling
else self.epilogue_default
},
)
return substitute_template(template, values)
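The emit change works because the template is filled in two passes: the first pass splices one of the two epilogue snippets in place of `${epilogue}`, and the placeholders inside that snippet survive to be filled by the second pass. A self-contained sketch of that behavior (this re-implementation is illustrative; TVM ships its own `substitute_template`):

```python
import re


def substitute_template(template, values):
    """Replace each ${key} with values[key]; unknown keys are left intact
    so a later substitution pass can fill them."""
    return re.sub(
        r"\$\{(\w+)\}",
        lambda m: str(values.get(m.group(1), m.group(0))),
        template,
    )


kernel = "cutlass::conv::kernel::DefaultConv2dFprop<..., ${epilogue}, ...>"
# Pass 1: choose the epilogue variant (its inner placeholders survive).
kernel = substitute_template(
    kernel,
    {"epilogue": "${epilogue_functor}<${element_c}, "
                 "cutlass::epilogue::thread::ScaleType::NoBetaScaling>"},
)
# Pass 2: fill in the concrete values.
kernel = substitute_template(
    kernel,
    {"epilogue_functor": "cutlass::epilogue::thread::LinearCombinationRelu",
     "element_c": "cutlass::half_t"},
)
print(kernel)
```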
9 changes: 5 additions & 4 deletions python/tvm/contrib/cutlass/gen_conv2d.py
@@ -83,15 +83,16 @@ def create_conv2d_operator(
op_entry["op"] = op
op_entry["src"] = profiler_emitter.emit(op_entry["opdef"], op.procedural_name())
op_entry["name"] = op.procedural_name()
op_entry["runtime"] = 9999999

# fused ops
for epilogue, opdef in zip(
for epilogue, opdef, no_bias_scaling in zip(
[
EpilogueFunctor.LinearCombinationBias,
EpilogueFunctor.LinearCombinationRelu,
EpilogueFunctor.LinearCombinationSigmoid,
],
["opdef_bias", "opdef_bias_relu"],
["opdef_bias", "opdef_bias_relu", "opdef_bias_sigmoid"],
[True, True, False],
):
op = Conv2dOperation(
ConvKind.Fprop,
@@ -107,7 +108,7 @@ def create_conv2d_operator(
swizzling_functor_,
)

op_entry[opdef] = kernel_emitter.emit(op)
op_entry[opdef] = kernel_emitter.emit(op, no_bias_scaling)

ret.append(op_entry)

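Laid out as records, the zip above pairs each fused epilogue with its op-definition key and whether the kernel is emitted with the no-beta-scaling epilogue (only the sigmoid variant keeps the scaled-beta form). This restatement is purely illustrative:

```python
# (epilogue functor,           opdef key,            emit with no_beta_scaling)
FUSED_CONV2D_VARIANTS = [
    ("LinearCombinationBias",    "opdef_bias",         True),
    ("LinearCombinationRelu",    "opdef_bias_relu",    True),
    ("LinearCombinationSigmoid", "opdef_bias_sigmoid", False),
]
```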
1 change: 0 additions & 1 deletion python/tvm/contrib/cutlass/gen_gemm.py
@@ -123,7 +123,6 @@ def create_gemm_operator(
DataTypeTag[element_c],
op.leading_dim(),
)
op_entry["runtime"] = 9999999
op_entry["tile_description"] = tile_description
op_entry["alignment"] = alignment
op_entry["data_type"] = data_type
2 changes: 2 additions & 0 deletions python/tvm/contrib/cutlass/library.py
@@ -148,13 +148,15 @@ class EpilogueFunctor(enum.Enum):
LinearCombinationRelu = enum_auto()
LinearCombinationBias = enum_auto()
LinearCombinationGelu = enum_auto()
LinearCombinationSigmoid = enum_auto()


EpilogueFunctorTag = {
EpilogueFunctor.LinearCombination: "cutlass::epilogue::thread::LinearCombination",
EpilogueFunctor.LinearCombinationRelu: "cutlass::epilogue::thread::LinearCombinationRelu",
EpilogueFunctor.LinearCombinationBias: "cutlass::epilogue::thread::LinearCombination",
EpilogueFunctor.LinearCombinationGelu: "cutlass::epilogue::thread::LinearCombinationGELU",
EpilogueFunctor.LinearCombinationSigmoid: "cutlass::epilogue::thread::LinearCombinationSigmoid",
}


55 changes: 49 additions & 6 deletions python/tvm/relay/op/contrib/cutlass.py
@@ -16,8 +16,9 @@
# under the License.
# pylint: disable=invalid-name
"""Patterns supported CUTLASS."""
from tvm.ir.transform import Sequential
from tvm.ir.transform import Sequential, PassContext
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from ...dataflow_pattern import wildcard, is_op, is_constant


@@ -57,8 +58,25 @@ def make_batch_matmul_pattern():
return is_op("nn.batch_matmul")(wildcard(), wildcard())


def make_conv2d_pattern():
return is_op("nn.conv2d")(wildcard(), wildcard())
def make_conv2d_pattern(with_bias=False, with_act=None):
"""Create a pattern for dense op followed by activations."""
data = wildcard()
weight = wildcard()
bias = wildcard()
conv2d = is_op("nn.conv2d")(data, weight)
if with_bias:
add_or_bias_add = is_op("add") | is_op("nn.bias_add")
conv2d_out = add_or_bias_add(conv2d, bias)
else:
conv2d_out = conv2d

if with_act is not None:
if with_act == "relu":
return is_op("nn.relu")(conv2d_out)
if with_act == "sigmoid":
return is_op("sigmoid")(conv2d_out)

return conv2d_out


def check_dtype(lhs, rhs):
@@ -109,7 +127,7 @@ def check_conv2d(call):
return not is_depthwise_conv2d(IC, OC, conv2d.attrs.groups)


def partition_for_cutlass(mod):
def partition_for_cutlass(mod, params=None):
"""Partition the input module into CUTLASS-supported subgraphs."""
dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm)
@@ -131,15 +149,40 @@ def partition_for_cutlass(mod):
dense_bias_pat,
dense_pat,
("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul),
# TODO(masahi): Add more conv2d patterns
(
"cutlass.conv2d_bias_relu",
make_conv2d_pattern(with_bias=True, with_act="relu"),
check_conv2d,
),
(
"cutlass.conv2d_bias_sigmoid",
make_conv2d_pattern(with_bias=True, with_act="sigmoid"),
check_conv2d,
),
("cutlass.conv2d_bias", make_conv2d_pattern(with_bias=True), check_conv2d),
("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
]

if params is not None:
mod["main"] = bind_params_by_name(mod["main"], params)
remove_bn_pass = Sequential(
[
transform.InferType(),
transform.SimplifyInference(),
transform.FoldConstant(),
transform.FoldScaleAxis(),
]
)
with PassContext(opt_level=3):
mod = remove_bn_pass(mod)

seq = Sequential(
[
transform.InferType(),
transform.MergeComposite(cutlass_patterns),
transform.AnnotateTarget(["cutlass"]),
transform.PartitionGraph(),
transform.PartitionGraph(bind_constants=False),
]
)

return seq(mod)
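For reference, a minimal end-to-end sketch of the new entry point; the NHWC/OHWI layouts match what `check_conv2d` accepts, while the shapes and dtypes here are illustrative:

```python
import tvm
from tvm import relay
from tvm.relay.op.contrib.cutlass import partition_for_cutlass

# conv2d -> bias_add -> relu should now be merged into one composite
# function tagged "cutlass.conv2d_bias_relu".
data = relay.var("data", shape=(1, 32, 32, 16), dtype="float16")
weight = relay.var("weight", shape=(32, 3, 3, 16), dtype="float16")
bias = relay.var("bias", shape=(32,), dtype="float16")
conv = relay.nn.conv2d(
    data, weight, kernel_size=(3, 3), padding=(1, 1),
    data_layout="NHWC", kernel_layout="OHWI", out_dtype="float16",
)
out = relay.nn.relu(relay.nn.bias_add(conv, bias, axis=-1))
mod = tvm.IRModule.from_expr(relay.Function([data, weight, bias], out))

# Passing params is optional; when given, constants are bound first so the
# SimplifyInference/FoldConstant/FoldScaleAxis passes can fold BatchNorm away.
mod = partition_for_cutlass(mod)
print(mod)  # the fused chain appears as a Composite="cutlass.conv2d_bias_relu" function
```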
56 changes: 51 additions & 5 deletions src/relay/backend/contrib/cutlass/codegen.cc
@@ -263,6 +263,11 @@ Str2StrMap Conv2dArgs(const Map<String, ObjectRef>& attrs) {

std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
const std::vector<std::string>& func_args) {
bool has_bias = attrs.at("op_type") == "cutlass.conv2d_bias" ||
attrs.at("op_type") == "cutlass.conv2d_bias_relu" ||
attrs.at("op_type") == "cutlass.conv2d_bias_sigmoid";
bool no_bias_scaling = attrs.at("op_type") != "cutlass.conv2d_bias_sigmoid";

std::ostringstream conv2d_decl;
CutlassPrint(conv2d_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n");
CutlassPrint(conv2d_decl, "using ElementInputB = " + attrs.at("ElementInputB") + ";\n");
@@ -307,10 +312,18 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
ICHECK(func_args.size() >= 2);
CutlassPrint(conv2d_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n");
CutlassPrint(conv2d_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n");
if (has_bias) {
ICHECK(func_args.size() >= 3);
CutlassPrint(conv2d_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + "->data);\n");
}

CutlassPrint(conv2d_decl, "void* ptr_out = (void*)(out0->data);\n");
CutlassPrint(conv2d_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n");
CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n");

if (has_bias && no_bias_scaling) {
CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n");
} else {
CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(1);\n");
}
CutlassPrint(conv2d_decl, "using cutlass::layout::TensorNHWC;\n");
CutlassPrint(conv2d_decl,
"TensorNHWC layout_A(TensorNHWC::packed(cutlass::make_Coord(N, H, W, C)));\n");
@@ -322,9 +335,19 @@
CutlassPrint(conv2d_decl, " problem_size,\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementInputA*>(ptr_a), layout_A},\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b), layout_B},\n");
if (has_bias) {
CutlassPrint(
conv2d_decl,
" {static_cast<ElementOutput*>(ptr_c_bias), cutlass::layout::TensorNHWC::Stride(0)},\n");
} else {
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
}
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
if (has_bias && no_bias_scaling) {
CutlassPrint(conv2d_decl, " {alpha}\n};\n");
} else {
CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
}
CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n");

CutlassPrint(conv2d_decl, "size_t workspace_size = conv2d_op.get_workspace_size(arguments);\n");
@@ -461,6 +484,27 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
const auto* conv2d_call = GetRootCall(callee->body.as<CallNode>(), 0, {"nn.conv2d"});
return GenerateBody(conv2d_call, "cutlass_conv2d", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
} else if (pattern_name == "cutlass.conv2d_bias") {
const CallNode* current_call = callee->body.as<CallNode>();
std::string add_or_bias_add = current_call->op.as<OpNode>()->name;
const auto* conv2d_call =
GetRootCall(callee->body.as<CallNode>(), 1, {"nn.conv2d", add_or_bias_add});
return GenerateBody(conv2d_call, "cutlass_conv2d_bias", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
} else if (pattern_name == "cutlass.conv2d_bias_relu") {
const CallNode* current_call = callee->body.as<CallNode>();
std::string add_or_bias_add = current_call->args[0].as<CallNode>()->op.as<OpNode>()->name;
const auto* conv2d_call =
GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "nn.relu"});
return GenerateBody(conv2d_call, "cutlass_conv2d_bias_relu", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
} else if (pattern_name == "cutlass.conv2d_bias_sigmoid") {
const CallNode* current_call = callee->body.as<CallNode>();
std::string add_or_bias_add = current_call->args[0].as<CallNode>()->op.as<OpNode>()->name;
const auto* conv2d_call =
GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "sigmoid"});
return GenerateBody(conv2d_call, "cutlass_conv2d_bias_sigmoid", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
}

LOG(FATAL) << "Unknown composite function: " << pattern_name;
@@ -507,7 +551,9 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
ret.decl = DenseOp(ext_func_id_, attribute_args, func_args);
} else if (func_name == "cutlass_batch_matmul") {
ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
} else if (func_name == "cutlass_conv2d") {
} else if (func_name == "cutlass_conv2d" || func_name == "cutlass_conv2d_bias" ||
func_name == "cutlass_conv2d_bias_relu" ||
func_name == "cutlass_conv2d_bias_sigmoid") {
ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args);
}

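A note on the alpha/beta handling above: CUTLASS's linear-combination epilogues compute `D = act(alpha * Acc + beta * C)`, and the bias is fed in as the `C` tensor with a stride-0 NHWC layout so it broadcasts over the batch and spatial dimensions. The bias and bias+relu kernels use `ScaleType::NoBetaScaling`, where `C` is added unscaled and only `{alpha}` is passed, while the sigmoid kernel keeps the default epilogue, so `beta` must be 1 for the bias to survive. A small numeric sketch of the two forms (numpy stands in for the device-side math):

```python
import numpy as np

relu = lambda x: np.maximum(x, 0.0)
sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))

acc = np.array([1.0, -2.0])   # conv2d accumulator values
bias = np.array([0.5, 0.5])   # broadcast bias (the stride-0 "C" tensor)

# conv2d_bias / conv2d_bias_relu: NoBetaScaling epilogue, D = act(alpha*Acc + C)
d_relu = relu(1.0 * acc + bias)

# conv2d_bias_sigmoid: default epilogue, D = act(alpha*Acc + beta*C) with beta = 1
d_sigmoid = sigmoid(1.0 * acc + 1.0 * bias)

print(d_relu, d_sigmoid)
```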