Skip to content

Commit

Permalink
return early if output arg is runtime inactive
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Sep 21, 2023
1 parent 8f1cacf commit f895a31
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1237,6 +1237,7 @@ void emit_rev_rewrite_rules(const StringMap<TGPattern> &patternMap,
const auto rules = pattern.getRules();
const auto activeArgs = pattern.getActiveArgs();
const bool lv23 = pattern.isBLASLevel2or3();
const auto mutArgSet = pattern.getMutableArgs();

// If any of the rule uses DiffeRet, the primary function has a ret val
// and we should emit the code for handling it.
Expand Down Expand Up @@ -1323,6 +1324,30 @@ void emit_rev_rewrite_rules(const StringMap<TGPattern> &patternMap,
}
os << " }\n";

// Blas functions return one float XOR modify one output arg.
// If we have runtimeActivity and the output arg is inactive,
// we don't need to do anything here and can return early.
if (mutArgSet.size() == 1) {
for (auto pos : mutArgSet) {
auto name = nameVec[pos];
os << " Value *rt_inactive_out = rt_inactive_" << name << ";\n";
os << " if (EnzymeRuntimeActivityCheck && cacheMode) {\n"
<< " BasicBlock *current = Builder2.GetInsertBlock();\n"
<< " auto bb_name = Builder2.GetInsertBlock()->getName();\n"
<< " auto earlyRetBlock = gutils->addReverseBlock(current, bb_name "
"+ \".early.return\");\n"
<< " auto continueBlock = gutils->addReverseBlock(earlyRetBlock, "
"bb_name + \".continue\");\n"
<< " Builder2.CreateCondBr(rt_inactive_out, earlyRetBlock, "
"continueBlock);\n"
<< " Builder2.SetInsertPoint(earlyRetBlock);\n"
<< " Builder2.CreateRetVoid();\n"
<< " Builder2.SetInsertPoint(continueBlock);\n"
<< " }\n";
break;
}
}

// now we can use it to transpose our trans arguments if they exist
for (size_t i = (lv23 ? 1 : 0); i < nameVec.size(); i++) {
auto name = nameVec[i];
Expand Down

0 comments on commit f895a31

Please sign in to comment.