diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index e5929a1fd763..a5ed6c36d203 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -6094,6 +6094,11 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, mode == DerivativeMode::ReverseModeCombined); assert(val->getName() != ""); + { + auto found = incoming_available.find(val); + if (found != incoming_available.end()) + return found->second; + } if (isa(val)) { return val; } @@ -6121,7 +6126,6 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, } auto inst = cast(val); - assert(inst->getName() != ""); if (inversionAllocs && inst->getParent() == inversionAllocs) { return val; } @@ -6418,7 +6422,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, auto li2obj = getBaseObject(li2->getPointerOperand()); if (liobj == li2obj && DT.dominates(li2, li)) { - auto orig2 = isOriginal(li2); + auto orig2 = dyn_cast_or_null(isOriginal(li2)); if (!orig2) continue; @@ -6427,8 +6431,8 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, // llvm::errs() << "found potential candidate loads: oli:" // << *origInst << " oli2: " << *orig2 << "\n"; - auto scev1 = SE.getSCEV(li->getPointerOperand()); - auto scev2 = SE.getSCEV(li2->getPointerOperand()); + auto scev1 = SE.getSCEV(origInst->getPointerOperand()); + auto scev2 = SE.getSCEV(orig2->getPointerOperand()); // llvm::errs() << " scev1: " << *scev1 << " scev2: " << *scev2 // << "\n"; @@ -6449,11 +6453,12 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, if (auto ar1 = dyn_cast(scev1)) { if (auto ar2 = dyn_cast(scev2)) { - if (ar1->getStart() != SE.getCouldNotCompute() && + if (ar1->getStart() != OrigSE.getCouldNotCompute() && ar1->getStart() == ar2->getStart() && - ar1->getStepRecurrence(SE) != SE.getCouldNotCompute() && - ar1->getStepRecurrence(SE) == - ar2->getStepRecurrence(SE)) { + ar1->getStepRecurrence(OrigSE) != + OrigSE.getCouldNotCompute() && + ar1->getStepRecurrence(OrigSE) == + ar2->getStepRecurrence(OrigSE)) { LoopContext l1; getContext(ar1->getLoop()->getHeader(), l1); @@ -6848,7 +6853,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, } } - auto scev1 = SE.getSCEV(li->getPointerOperand()); + auto scev1 = OrigSE.getSCEV(origInst->getPointerOperand()); // Store in memcpy opt Value *lim = nullptr; BasicBlock *ctx = nullptr; @@ -6856,12 +6861,12 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, Value *offset = nullptr; if (auto ar1 = dyn_cast(scev1)) { if (auto step = - dyn_cast(ar1->getStepRecurrence(SE))) { + dyn_cast(ar1->getStepRecurrence(OrigSE))) { if (step->getAPInt() != loadSize) goto noSpeedCache; LoopContext l1; - getContext(ar1->getLoop()->getHeader(), l1); + getContext(getNewFromOriginal(ar1->getLoop()->getHeader()), l1); if (l1.dynamic) goto noSpeedCache; @@ -6886,40 +6891,60 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM, lim = v.CreateAdd(lim, ConstantInt::get(lim->getType(), 1), "", true, true); - SmallVector toErase; { #if LLVM_VERSION_MAJOR >= 12 - SCEVExpander Exp(SE, - ctx->getParent()->getParent()->getDataLayout(), - "enzyme"); -#else - fake::SCEVExpander Exp( - SE, ctx->getParent()->getParent()->getDataLayout(), - "enzyme"); -#endif - Exp.setInsertPoint(l1.header->getTerminator()); - Value *start0 = Exp.expandCodeFor( - ar1->getStart(), li->getPointerOperand()->getType()); - start = unwrapM(start0, v, - /*available*/ ValueToValueMapTy(), - UnwrapMode::AttemptFullUnwrapWithLookup); - std::set todo = {start0}; - while (todo.size()) { - Value *now = *todo.begin(); - todo.erase(now); - if (Instruction *inst = dyn_cast(now)) { - if (inst != start && inst->getNumUses() == 0 && - Exp.isInsertedInstruction(inst)) { - for (auto &op : inst->operands()) { - todo.insert(op); - } - toErase.push_back(inst); - } + Value *start0; + SmallVector InsertedInstructions; + { + SCEVExpander OrigExp( + OrigSE, ctx->getParent()->getParent()->getDataLayout(), + "enzyme"); + + OrigExp.setInsertPoint( + isOriginal(l1.header)->getTerminator()); + + start0 = OrigExp.expandCodeFor( + ar1->getStart(), li->getPointerOperand()->getType()); + InsertedInstructions = OrigExp.getAllInsertedInstructions(); + } + + ValueToValueMapTy available; + for (const auto &pair : originalToNewFn) { + available[pair.first] = pair.second; + } + + // Sort so that later instructions do not dominate earlier + // instructions. + llvm::stable_sort(InsertedInstructions, + [this](Instruction *A, Instruction *B) { + return OrigDT.dominates(A, B); + }); + for (auto a : InsertedInstructions) { + assert(!isa(a)); + auto uw = cast( + unwrapM(a, v, available, UnwrapMode::AttemptSingleUnwrap, + /*scope*/ nullptr, /*cache*/ false)); + for (size_t i = 0; i < uw->getNumOperands(); i++) { + auto op = uw->getOperand(i); + if (auto arg = dyn_cast(op)) + assert(arg->getParent() == newFunc); + else if (auto inst = dyn_cast(op)) + assert(inst->getParent()->getParent() == newFunc); } + available[a] = uw; + unwrappedLoads.erase(cast(uw)); } + + start = available[start0]; + assert(start); + + available.clear(); + for (auto I : llvm::reverse(InsertedInstructions)) { + assert(I->getNumUses() == 0); + I->eraseFromParent(); + } +#endif } - for (auto a : toErase) - erase(a); if (!start) goto noSpeedCache;