diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 32a2ebaa2842..a2e6bce8cfea 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -87,6 +87,9 @@ def visit_call(self, call):
         if str(op) == "nn.conv2d":
             self.op_attrs = call.attrs

+        for arg in call.args:
+            self.visit(arg)
+

 def select_gemm_kernel(
     cutlass_profiler, MM, KK, NN, out_dtype, batched, profile_all, use_multiprocessing
@@ -213,6 +216,12 @@ def handle_conv2d(

     if op_type == "cutlass.conv2d":
         cutlass_op_def = out["opdef"]
+    elif op_type == "cutlass.conv2d_bias":
+        cutlass_op_def = out["opdef_bias"]
+    elif op_type == "cutlass.conv2d_bias_relu":
+        cutlass_op_def = out["opdef_bias_relu"]
+    elif op_type == "cutlass.conv2d_bias_sigmoid":
+        cutlass_op_def = out["opdef_bias_sigmoid"]
     else:
         raise ValueError("%s pattern is not implemented." % op_type)

diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py
index 8a886ff260b8..35308928cdab 100644
--- a/python/tvm/contrib/cutlass/conv2d_operation.py
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -143,6 +143,22 @@ class EmitConv2dInstance:
     """ Responsible for emitting a CUTLASS template definition."""

     def __init__(self):
+        self.epilogue_default = """
+    ${epilogue_functor}<
+      ${element_c},
+      ${epilogue_vector_length},
+      ${element_accumulator},
+      ${element_epilogue}
+    >"""
+        self.epilogue_no_beta_scaling = """
+    ${epilogue_functor}<
+      ${element_c},
+      ${epilogue_vector_length},
+      ${element_accumulator},
+      ${element_epilogue},
+      cutlass::epilogue::thread::ScaleType::NoBetaScaling
+    >"""
+
         self.template = """
 // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
 using ${operation_name} =
@@ -159,12 +175,7 @@ def __init__(self):
     cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
     cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
     cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
-    ${epilogue_functor}<
-      ${element_c},
-      ${epilogue_vector_length},
-      ${element_accumulator},
-      ${element_epilogue}
-    >,
+    ${epilogue},
     ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
     ${stages},
     ${math_operator},
@@ -175,7 +186,7 @@ def __init__(self):
 >::Kernel;
 """

-    def emit(self, operation):
+    def emit(self, operation, no_beta_scaling=True):
         """Instantiate a Conv2d kernel from given `operation`."""
         warp_shape = [
             int(
@@ -237,4 +248,12 @@ def emit(self, operation):
             "align_b": str(operation.B.alignment),
         }

-        return substitute_template(self.template, values)
+        template = substitute_template(
+            self.template,
+            {
+                "epilogue": self.epilogue_no_beta_scaling
+                if no_beta_scaling
+                else self.epilogue_default
+            },
+        )
+        return substitute_template(template, values)
diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py
index b0d0566d6fab..288f67f39287 100644
--- a/python/tvm/contrib/cutlass/gen_conv2d.py
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -83,15 +83,16 @@ def create_conv2d_operator(
         op_entry["op"] = op
         op_entry["src"] = profiler_emitter.emit(op_entry["opdef"], op.procedural_name())
         op_entry["name"] = op.procedural_name()
-        op_entry["runtime"] = 9999999

         # fused ops
-        for epilogue, opdef in zip(
+        for epilogue, opdef, no_bias_scaling in zip(
             [
                 EpilogueFunctor.LinearCombinationBias,
                 EpilogueFunctor.LinearCombinationRelu,
+                EpilogueFunctor.LinearCombinationSigmoid,
             ],
-            ["opdef_bias", "opdef_bias_relu"],
+            ["opdef_bias", "opdef_bias_relu", "opdef_bias_sigmoid"],
+            [True, True, False],
         ):
             op = Conv2dOperation(
                 ConvKind.Fprop,
@@ -107,7 +108,7 @@ def create_conv2d_operator(
                 swizzling_functor_,
             )

-            op_entry[opdef] = kernel_emitter.emit(op)
+            op_entry[opdef] = kernel_emitter.emit(op, no_bias_scaling)

             ret.append(op_entry)

diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py
index c171c5e23a89..7048c32fe1da 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -123,7 +123,6 @@ def create_gemm_operator(
             DataTypeTag[element_c],
             op.leading_dim(),
         )
-        op_entry["runtime"] = 9999999
         op_entry["tile_description"] = tile_description
         op_entry["alignment"] = alignment
         op_entry["data_type"] = data_type
diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py
index 902dc57100a9..8c3f5eb5df63 100644
--- a/python/tvm/contrib/cutlass/library.py
+++ b/python/tvm/contrib/cutlass/library.py
@@ -148,6 +148,7 @@ class EpilogueFunctor(enum.Enum):
     LinearCombinationRelu = enum_auto()
     LinearCombinationBias = enum_auto()
     LinearCombinationGelu = enum_auto()
+    LinearCombinationSigmoid = enum_auto()


 EpilogueFunctorTag = {
@@ -155,6 +156,7 @@ class EpilogueFunctor(enum.Enum):
     EpilogueFunctor.LinearCombinationRelu: "cutlass::epilogue::thread::LinearCombinationRelu",
     EpilogueFunctor.LinearCombinationBias: "cutlass::epilogue::thread::LinearCombination",
     EpilogueFunctor.LinearCombinationGelu: "cutlass::epilogue::thread::LinearCombinationGELU",
+    EpilogueFunctor.LinearCombinationSigmoid: "cutlass::epilogue::thread::LinearCombinationSigmoid",
 }

diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py
index 0a67581400ed..8fdd90ea109a 100644
--- a/python/tvm/relay/op/contrib/cutlass.py
+++ b/python/tvm/relay/op/contrib/cutlass.py
@@ -16,8 +16,9 @@
 # under the License.
 # pylint: disable=invalid-name
 """Patterns supported CUTLASS."""
-from tvm.ir.transform import Sequential
+from tvm.ir.transform import Sequential, PassContext
 from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name

 from ...dataflow_pattern import wildcard, is_op, is_constant

@@ -57,8 +58,25 @@ def make_batch_matmul_pattern():
     return is_op("nn.batch_matmul")(wildcard(), wildcard())


-def make_conv2d_pattern():
-    return is_op("nn.conv2d")(wildcard(), wildcard())
+def make_conv2d_pattern(with_bias=False, with_act=None):
+    """Create a pattern for conv2d op followed by activations."""
+    data = wildcard()
+    weight = wildcard()
+    bias = wildcard()
+    conv2d = is_op("nn.conv2d")(data, weight)
+    if with_bias:
+        add_or_bias_add = is_op("add") | is_op("nn.bias_add")
+        conv2d_out = add_or_bias_add(conv2d, bias)
+    else:
+        conv2d_out = conv2d
+
+    if with_act is not None:
+        if with_act == "relu":
+            return is_op("nn.relu")(conv2d_out)
+        if with_act == "sigmoid":
+            return is_op("sigmoid")(conv2d_out)
+
+    return conv2d_out


 def check_dtype(lhs, rhs):
@@ -109,7 +127,7 @@ def check_conv2d(call):
     return not is_depthwise_conv2d(IC, OC, conv2d.attrs.groups)


-def partition_for_cutlass(mod):
+def partition_for_cutlass(mod, params=None):
     """Partition the input module into CUTLASS-supported subgraphs."""
     dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
     dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm)
@@ -131,15 +149,40 @@ def partition_for_cutlass(mod):
         dense_bias_pat,
         dense_pat,
         ("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul),
-        # TODO(masahi): Add more conv2d patterns
+        (
+            "cutlass.conv2d_bias_relu",
+            make_conv2d_pattern(with_bias=True, with_act="relu"),
+            check_conv2d,
+        ),
+        (
+            "cutlass.conv2d_bias_sigmoid",
+            make_conv2d_pattern(with_bias=True, with_act="sigmoid"),
+            check_conv2d,
+        ),
+        ("cutlass.conv2d_bias", make_conv2d_pattern(with_bias=True), check_conv2d),
         ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
     ]
+
+    if params is not None:
+        mod["main"] = bind_params_by_name(mod["main"], params)
+        remove_bn_pass = Sequential(
+            [
+                transform.InferType(),
+                transform.SimplifyInference(),
+                transform.FoldConstant(),
+                transform.FoldScaleAxis(),
+            ]
+        )
+        with PassContext(opt_level=3):
+            mod = remove_bn_pass(mod)
+
     seq = Sequential(
         [
             transform.InferType(),
             transform.MergeComposite(cutlass_patterns),
             transform.AnnotateTarget(["cutlass"]),
-            transform.PartitionGraph(),
+            transform.PartitionGraph(bind_constants=False),
         ]
     )
+
     return seq(mod)
diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc
index c226da5864fc..d06ebaa896f4 100644
--- a/src/relay/backend/contrib/cutlass/codegen.cc
+++ b/src/relay/backend/contrib/cutlass/codegen.cc
@@ -263,6 +263,11 @@ Str2StrMap Conv2dArgs(const Map<String, ObjectRef>& attrs) {

 std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
                      const std::vector<std::string>& func_args) {
+  bool has_bias = attrs.at("op_type") == "cutlass.conv2d_bias" ||
+                  attrs.at("op_type") == "cutlass.conv2d_bias_relu" ||
+                  attrs.at("op_type") == "cutlass.conv2d_bias_sigmoid";
+  bool no_bias_scaling = attrs.at("op_type") != "cutlass.conv2d_bias_sigmoid";
+
   std::ostringstream conv2d_decl;
   CutlassPrint(conv2d_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n");
   CutlassPrint(conv2d_decl, "using ElementInputB = " + attrs.at("ElementInputB") + ";\n");
@@ -307,10 +312,18 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
   ICHECK(func_args.size() >= 2);
   CutlassPrint(conv2d_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n");
   CutlassPrint(conv2d_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n");
+  if (has_bias) {
+    ICHECK(func_args.size() >= 3);
+    CutlassPrint(conv2d_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + "->data);\n");
+  }
+
   CutlassPrint(conv2d_decl, "void* ptr_out = (void*)(out0->data);\n");
   CutlassPrint(conv2d_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n");
-  CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n");
-
+  if (has_bias && no_bias_scaling) {
+    CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n");
+  } else {
+    CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(1);\n");
+  }
   CutlassPrint(conv2d_decl, "using cutlass::layout::TensorNHWC;\n");
   CutlassPrint(conv2d_decl,
               "TensorNHWC layout_A(TensorNHWC::packed(cutlass::make_Coord(N, H, W, C)));\n");
@@ -322,9 +335,19 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
   CutlassPrint(conv2d_decl, " problem_size,\n");
   CutlassPrint(conv2d_decl, " {static_cast<ElementInputA*>(ptr_a), layout_A},\n");
   CutlassPrint(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b), layout_B},\n");
+  if (has_bias) {
+    CutlassPrint(
+        conv2d_decl,
+        " {static_cast<ElementOutput*>(ptr_c_bias), cutlass::layout::TensorNHWC::Stride(0)},\n");
+  } else {
+    CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
+  }
   CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
-  CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
-  CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
+  if (has_bias && no_bias_scaling) {
+    CutlassPrint(conv2d_decl, " {alpha}\n};\n");
+  } else {
+    CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
+  }

   CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n");
   CutlassPrint(conv2d_decl, "size_t workspace_size = conv2d_op.get_workspace_size(arguments);\n");
@@ -461,6 +484,27 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
       const auto* conv2d_call = GetRootCall(callee->body.as<CallNode>(), 0, {"nn.conv2d"});
       return GenerateBody(conv2d_call, "cutlass_conv2d", GetArgumentNames(caller),
                           Conv2dArgs(std::ref(attrs_)));
+    } else if (pattern_name == "cutlass.conv2d_bias") {
+      const CallNode* current_call = callee->body.as<CallNode>();
+      std::string add_or_bias_add = current_call->op.as<OpNode>()->name;
+      const auto* conv2d_call =
+          GetRootCall(callee->body.as<CallNode>(), 1, {"nn.conv2d", add_or_bias_add});
+      return GenerateBody(conv2d_call, "cutlass_conv2d_bias", GetArgumentNames(caller),
+                          Conv2dArgs(std::ref(attrs_)));
+    } else if (pattern_name == "cutlass.conv2d_bias_relu") {
+      const CallNode* current_call = callee->body.as<CallNode>();
+      std::string add_or_bias_add = current_call->args[0].as<CallNode>()->op.as<OpNode>()->name;
+      const auto* conv2d_call =
+          GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "nn.relu"});
+      return GenerateBody(conv2d_call, "cutlass_conv2d_bias_relu", GetArgumentNames(caller),
+                          Conv2dArgs(std::ref(attrs_)));
+    } else if (pattern_name == "cutlass.conv2d_bias_sigmoid") {
+      const CallNode* current_call = callee->body.as<CallNode>();
+      std::string add_or_bias_add = current_call->args[0].as<CallNode>()->op.as<OpNode>()->name;
+      const auto* conv2d_call =
+          GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "sigmoid"});
+      return GenerateBody(conv2d_call, "cutlass_conv2d_bias_sigmoid", GetArgumentNames(caller),
+                          Conv2dArgs(std::ref(attrs_)));
     }

     LOG(FATAL) << "Unknown composite function: " << pattern_name;
@@ -507,7 +551,9 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
       ret.decl = DenseOp(ext_func_id_, attribute_args, func_args);
     } else if (func_name == "cutlass_batch_matmul") {
       ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
-    } else if (func_name == "cutlass_conv2d") {
+    } else if (func_name == "cutlass_conv2d" || func_name == "cutlass_conv2d_bias" ||
+               func_name == "cutlass_conv2d_bias_relu" ||
+               func_name == "cutlass_conv2d_bias_sigmoid") {
       ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args);
     }

diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py
index fee84d252081..89099c86dc58 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -114,18 +114,30 @@ def get_conv2d_nchw(d_shape, w_shape, padding, out_dtype="float16"):
     data = relay.var("data", shape=d_shape, dtype="float16")
     weight = relay.var("weight", shape=w_shape, dtype="float16")
     out_channel = w_shape[0]
-    return tvm.IRModule.from_expr(
-        relay.nn.conv2d(
-            data=data,
-            weight=weight,
-            kernel_size=w_shape[2:],
-            channels=out_channel,
-            padding=padding,
-            out_dtype=out_dtype,
-        )
+    return relay.nn.conv2d(
+        data=data,
+        weight=weight,
+        kernel_size=w_shape[2:],
+        channels=out_channel,
+        padding=padding,
+        out_dtype=out_dtype,
     )


+def get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype="float16"):
+    conv2d = get_conv2d_nchw(d_shape, w_shape, padding, out_dtype=out_dtype)
+    bias = relay.var("bias", shape=(w_shape[0],), dtype=out_dtype)
+    return relay.nn.bias_add(conv2d, bias)
+
+
+def get_conv2d_nchw_bias_relu(d_shape, w_shape, padding, out_dtype="float16"):
+    return relay.nn.relu(get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype))
+
+
+def get_conv2d_nchw_bias_sigmoid(d_shape, w_shape, padding, out_dtype="float16"):
+    return relay.sigmoid(get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype))
+
+
 def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"):
     mod = partition_for_cutlass(mod)
     mod, num_cutlass_partition = tune_cutlass_kernels(
@@ -314,8 +326,8 @@ def convert_conv2d_layout(mod, desired_layouts):


 def verify_conv2d(
-    mod_nchw,  # can be dynamic batch
-    mod_ref,  # always static batch
+    expr_nchw,  # can be dynamic batch
+    expr_ref,  # always static batch
     d_shape,
     w_shape,
     sm=80,
@@ -327,10 +339,17 @@ def verify_conv2d(
     if not has_cutlass():
         return

+    mod_nchw = tvm.IRModule.from_expr(expr_nchw)
+    mod_ref = tvm.IRModule.from_expr(expr_ref)
+
+    typ = relay.transform.InferType()(mod_nchw)["main"].body.checked_type
+    out_dtype = typ.dtype
+
     np_data = np.random.uniform(-1, 1, d_shape).astype("float16")
     np_weight = np.random.uniform(-1, 1, w_shape).astype("float16")
+    np_bias = np.random.uniform(-1, 1, (w_shape[0],)).astype(out_dtype)

-    params = {"weight": np_weight}
+    params = {"weight": np_weight, "bias": np_bias}

     typ = relay.transform.InferType()(mod_nchw)["main"].body.checked_type
     use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape)
@@ -373,10 +392,10 @@ def verify_conv2d(


 def test_conv2d():
+    padding = (1, 1)
     for IC in [3, 16]:
         d_shape = (16, IC, 32, 32)
         w_shape = (32, IC, 3, 3)
-        padding = (1, 1)
         mod_nchw = get_conv2d_nchw(d_shape, w_shape, padding)

         verify_conv2d(
@@ -404,5 +423,26 @@ def test_conv2d():
         )


+def test_conv2d_fusion():
+    d_shape = (16, 16, 32, 32)
+    w_shape = (32, 16, 3, 3)
+    padding = (1, 1)
+
+    mod_nchw = get_conv2d_nchw_bias(d_shape, w_shape, padding)
+    verify_conv2d(
+        mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
+    )
+
+    mod_nchw = get_conv2d_nchw_bias_relu(d_shape, w_shape, padding)
+    verify_conv2d(
+        mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
+    )
+
+    mod_nchw = get_conv2d_nchw_bias_sigmoid(d_shape, w_shape, padding, out_dtype="float32")
+    verify_conv2d(
+        mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
+    )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
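
For reviewers who want to try the new fused patterns by hand, the sketch below builds a conv2d + bias_add + relu module and partitions it with the updated partition_for_cutlass(mod, params). It is illustrative only and sticks to APIs visible in this diff; the profiling and build steps (tune_cutlass_kernels and friends) are omitted, and the expectation that the partitioned module shows a function annotated Composite="cutlass.conv2d_bias_relu" is an assumption about how MergeComposite labels matched subgraphs.

# Illustrative sketch (not part of the change): exercise the conv2d_bias_relu pattern up to partitioning.
import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib.cutlass import partition_for_cutlass

d_shape, w_shape, padding = (16, 16, 32, 32), (32, 16, 3, 3), (1, 1)
data = relay.var("data", shape=d_shape, dtype="float16")
weight = relay.var("weight", shape=w_shape, dtype="float16")
bias = relay.var("bias", shape=(w_shape[0],), dtype="float16")
conv = relay.nn.conv2d(
    data, weight, kernel_size=w_shape[2:], channels=w_shape[0], padding=padding, out_dtype="float16"
)
mod = tvm.IRModule.from_expr(relay.nn.relu(relay.nn.bias_add(conv, bias)))

params = {
    "weight": np.random.uniform(-1, 1, w_shape).astype("float16"),
    "bias": np.random.uniform(-1, 1, (w_shape[0],)).astype("float16"),
}

# New in this PR: passing params binds them as constants and runs SimplifyInference/FoldConstant/
# FoldScaleAxis first, so the bias/activation chain can be matched as a single composite.
mod = partition_for_cutlass(mod, params)
print(mod)  # assumption: contains a function tagged Composite="cutlass.conv2d_bias_relu"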