Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
fix residual block fusion offload
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Feb 2, 2023
1 parent b90ab7b commit 13a73e5
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion src/relay/backend/contrib/cutlass/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n");

if (has_residual_block) {
// TODO(masahi): This code assumes that there is always a bias_add in a residual block.
ICHECK(func_args.size() >= 4);
CutlassPrint(conv2d_decl, "void* ptr_bias = (void*)(" + func_args[2] + "->data);\n");
CutlassPrint(conv2d_decl, "void* ptr_residual = (void*)(" + func_args[3] + "->data);\n");
Expand Down Expand Up @@ -794,9 +795,24 @@ class CodegenCutlass : public backend::MemoizedExprTranslator<std::vector<Output
residual_index = IsAncestor(rhs, lhs) ? 1 : 0;
}
const auto* non_residual_input = binop->args[!residual_index].as<CallNode>();
const auto residual_input = binop->args[residual_index];
const auto* conv2d_call = GetRootCall(non_residual_input, "nn.conv2d");
ICHECK(conv2d_call);
return GenerateBody(conv2d_call, pattern_name.value(), GetArgumentNames(caller),
auto call_args = GetArgumentNames(caller);
auto func_args = call_args;
if (call_args.size() == 3) {
// TODO(masahi): This code assumes that there is always a bias_add in a residual block.
for (size_t i = 0; i < call_args.size(); ++i) {
if (callee->params[i] == residual_input) {
auto residual_input_name = call_args[i];
func_args.push_back(residual_input_name);
}
}
} else {
ICHECK_EQ(func_args.size(), 4) << "Residual block fusion expects 4 input tensors: data, "
"weight, bias, and residual tensor.";
}
return GenerateBody(conv2d_call, pattern_name.value(), func_args,
Conv2dArgs(std::ref(attrs_)));
} else if (pattern_name == "cutlass.conv2d_transpose") {
const auto* conv2d_call = GetRootCall(callee->body.as<CallNode>(), 0,
Expand Down

0 comments on commit 13a73e5

Please sign in to comment.