Skip to content

Commit

Permalink
Generalize binop inverse of active float
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 12, 2023
1 parent 5d3206b commit 6f7dd5b
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5653,8 +5653,22 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
return li;

} else if (auto arg = dyn_cast<BinaryOperator>(oval)) {
if (arg->getOpcode() == Instruction::FAdd)
return getNewFromOriginal(arg);
switch (mode) {
case DerivativeMode::ReverseModePrimal:
case DerivativeMode::ReverseModeCombined:
case DerivativeMode::ReverseModeGradient:
if (TR.query(arg)[{-1}].isFloat()) {
auto newv = getNewFromOriginal(arg);
IRBuilder<> bb(newv);
auto res =
applyChainRule(newv->getType(), bb, [&newv] { return newv; });
invertedPointers.insert(
std::make_pair((const Value *)oval, InvertedPointerVH(this, res)));
return res;
}
default:
break;
}

if (!arg->getType()->isIntOrIntVectorTy()) {
llvm::errs() << *oval << "\n";
Expand Down

0 comments on commit 6f7dd5b

Please sign in to comment.