From 13a73e5b5ac1f58cf900daa6648a7f444d24b190 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 2 Feb 2023 12:09:02 +0900 Subject: [PATCH] fix residual block fusion offload --- src/relay/backend/contrib/cutlass/codegen.cc | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index b434280031..853a90154d 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -339,6 +339,7 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs, CutlassPrint(conv2d_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n"); if (has_residual_block) { + // TODO(masahi): This code assumes that there is always a bias_add in a residual block. ICHECK(func_args.size() >= 4); CutlassPrint(conv2d_decl, "void* ptr_bias = (void*)(" + func_args[2] + "->data);\n"); CutlassPrint(conv2d_decl, "void* ptr_residual = (void*)(" + func_args[3] + "->data);\n"); @@ -794,9 +795,24 @@ class CodegenCutlass : public backend::MemoizedExprTranslatorargs[!residual_index].as(); + const auto residual_input = binop->args[residual_index]; const auto* conv2d_call = GetRootCall(non_residual_input, "nn.conv2d"); ICHECK(conv2d_call); - return GenerateBody(conv2d_call, pattern_name.value(), GetArgumentNames(caller), + auto call_args = GetArgumentNames(caller); + auto func_args = call_args; + if (call_args.size() == 3) { + // TODO(masahi): This code assumes that there is always a bias_add in a residual block. + for (size_t i = 0; i < call_args.size(); ++i) { + if (callee->params[i] == residual_input) { + auto residual_input_name = call_args[i]; + func_args.push_back(residual_input_name); + } + } + } else { + ICHECK_EQ(func_args.size(), 4) << "Residual block fusion expects 4 input tensors: data, " + "weight, bias, and residual tensor."; + } + return GenerateBody(conv2d_call, pattern_name.value(), func_args, Conv2dArgs(std::ref(attrs_))); } else if (pattern_name == "cutlass.conv2d_transpose") { const auto* conv2d_call = GetRootCall(callee->body.as(), 0,