Skip to content

Commit

Permalink
Fix SCEV memory error (#1524)
Browse files Browse the repository at this point in the history
* Fix SCEV memory error

* Add smalltypeof to inactive
  • Loading branch information
wsmoses authored Nov 4, 2023
1 parent dfb8b55 commit 7508ad4
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 56 deletions.
1 change: 1 addition & 0 deletions enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ const char *KnownInactiveFunctionsContains[] = {
"__enzyme_pointer"};

const StringSet<> InactiveGlobals = {
"small_typeof",
"ompi_request_null",
"ompi_mpi_double",
"ompi_mpi_comm_world",
Expand Down
114 changes: 74 additions & 40 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6094,6 +6094,11 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
mode == DerivativeMode::ReverseModeCombined);

assert(val->getName() != "<badref>");
{
auto found = incoming_available.find(val);
if (found != incoming_available.end())
return found->second;
}
if (isa<Constant>(val)) {
return val;
}
Expand Down Expand Up @@ -6121,7 +6126,6 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
}

auto inst = cast<Instruction>(val);
assert(inst->getName() != "<badref>");
if (inversionAllocs && inst->getParent() == inversionAllocs) {
return val;
}
Expand Down Expand Up @@ -6418,7 +6422,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
auto li2obj = getBaseObject(li2->getPointerOperand());

if (liobj == li2obj && DT.dominates(li2, li)) {
auto orig2 = isOriginal(li2);
auto orig2 = dyn_cast_or_null<LoadInst>(isOriginal(li2));
if (!orig2)
continue;

Expand All @@ -6427,8 +6431,8 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
// llvm::errs() << "found potential candidate loads: oli:"
// << *origInst << " oli2: " << *orig2 << "\n";

auto scev1 = SE.getSCEV(li->getPointerOperand());
auto scev2 = SE.getSCEV(li2->getPointerOperand());
auto scev1 = SE.getSCEV(origInst->getPointerOperand());
auto scev2 = SE.getSCEV(orig2->getPointerOperand());
// llvm::errs() << " scev1: " << *scev1 << " scev2: " << *scev2
// << "\n";

Expand All @@ -6449,11 +6453,12 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,

if (auto ar1 = dyn_cast<SCEVAddRecExpr>(scev1)) {
if (auto ar2 = dyn_cast<SCEVAddRecExpr>(scev2)) {
if (ar1->getStart() != SE.getCouldNotCompute() &&
if (ar1->getStart() != OrigSE.getCouldNotCompute() &&
ar1->getStart() == ar2->getStart() &&
ar1->getStepRecurrence(SE) != SE.getCouldNotCompute() &&
ar1->getStepRecurrence(SE) ==
ar2->getStepRecurrence(SE)) {
ar1->getStepRecurrence(OrigSE) !=
OrigSE.getCouldNotCompute() &&
ar1->getStepRecurrence(OrigSE) ==
ar2->getStepRecurrence(OrigSE)) {

LoopContext l1;
getContext(ar1->getLoop()->getHeader(), l1);
Expand Down Expand Up @@ -6848,20 +6853,20 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
}
}

auto scev1 = SE.getSCEV(li->getPointerOperand());
auto scev1 = OrigSE.getSCEV(origInst->getPointerOperand());
// Store in memcpy opt
Value *lim = nullptr;
BasicBlock *ctx = nullptr;
Value *start = nullptr;
Value *offset = nullptr;
if (auto ar1 = dyn_cast<SCEVAddRecExpr>(scev1)) {
if (auto step =
dyn_cast<SCEVConstant>(ar1->getStepRecurrence(SE))) {
dyn_cast<SCEVConstant>(ar1->getStepRecurrence(OrigSE))) {
if (step->getAPInt() != loadSize)
goto noSpeedCache;

LoopContext l1;
getContext(ar1->getLoop()->getHeader(), l1);
getContext(getNewFromOriginal(ar1->getLoop()->getHeader()), l1);

if (l1.dynamic)
goto noSpeedCache;
Expand All @@ -6886,40 +6891,69 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
lim = v.CreateAdd(lim, ConstantInt::get(lim->getType(), 1), "",
true, true);

SmallVector<Instruction *, 4> toErase;
{
#if LLVM_VERSION_MAJOR >= 12
SCEVExpander Exp(SE,
ctx->getParent()->getParent()->getDataLayout(),
"enzyme");
#else
fake::SCEVExpander Exp(
SE, ctx->getParent()->getParent()->getDataLayout(),
"enzyme");
#endif
Exp.setInsertPoint(l1.header->getTerminator());
Value *start0 = Exp.expandCodeFor(
ar1->getStart(), li->getPointerOperand()->getType());
start = unwrapM(start0, v,
/*available*/ ValueToValueMapTy(),
UnwrapMode::AttemptFullUnwrapWithLookup);
std::set<Value *> todo = {start0};
while (todo.size()) {
Value *now = *todo.begin();
todo.erase(now);
if (Instruction *inst = dyn_cast<Instruction>(now)) {
if (inst != start && inst->getNumUses() == 0 &&
Exp.isInsertedInstruction(inst)) {
for (auto &op : inst->operands()) {
todo.insert(op);
}
toErase.push_back(inst);
}
Value *start0;
SmallVector<Instruction *, 32> InsertedInstructions;
{
SCEVExpander OrigExp(
OrigSE, ctx->getParent()->getParent()->getDataLayout(),
"enzyme");

OrigExp.setInsertPoint(
isOriginal(l1.header)->getTerminator());

start0 = OrigExp.expandCodeFor(
ar1->getStart(), li->getPointerOperand()->getType());
InsertedInstructions = OrigExp.getAllInsertedInstructions();
}

ValueToValueMapTy available;
for (const auto &pair : originalToNewFn) {
if (pair.first->getType() == pair.second->getType())
available[pair.first] = pair.second;
}

// Sort so that later instructions do not dominate earlier
// instructions.
llvm::stable_sort(InsertedInstructions,
[this](Instruction *A, Instruction *B) {
return OrigDT.dominates(A, B);
});
for (auto a : InsertedInstructions) {
assert(!isa<PHINode>(a));
auto uw = cast<Instruction>(
unwrapM(a, v, available, UnwrapMode::AttemptSingleUnwrap,
/*scope*/ nullptr, /*cache*/ false));
assert(uw->getType() == a->getType());
for (size_t i = 0; i < uw->getNumOperands(); i++) {
auto op = uw->getOperand(i);
if (auto arg = dyn_cast<Argument>(op))
assert(arg->getParent() == newFunc);
else if (auto inst = dyn_cast<Instruction>(op))
assert(inst->getParent()->getParent() == newFunc);
}
available[a] = uw;
unwrappedLoads.erase(cast<Instruction>(uw));
}

start =
isa<Constant>(start0) ? start0 : (Value *)available[start0];
if (!start) {
llvm::errs() << "old: " << *oldFunc << "\n";
llvm::errs() << "new: " << *newFunc << "\n";
llvm::errs() << "start0: " << *start0 << "\n";
}
assert(start);

available.clear();
for (auto I : llvm::reverse(InsertedInstructions)) {
assert(I->getNumUses() == 0);
OrigSE.forgetValue(I);
I->eraseFromParent();
}
#endif
}
for (auto a : toErase)
erase(a);

if (!start)
goto noSpeedCache;
Expand Down
15 changes: 10 additions & 5 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1914,7 +1914,8 @@ bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,

if (auto LI = dyn_cast<LoadInst>(maybeReader)) {
LoadBegin = SE.getSCEV(LI->getPointerOperand());
if (LoadBegin != SE.getCouldNotCompute()) {
if (LoadBegin != SE.getCouldNotCompute() &&
!LoadBegin->getType()->isIntegerTy()) {
auto &DL = maybeWriter->getModule()->getDataLayout();
auto width = cast<IntegerType>(DL.getIndexType(LoadBegin->getType()))
->getBitWidth();
Expand All @@ -1930,7 +1931,8 @@ bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
}
if (auto SI = dyn_cast<StoreInst>(maybeWriter)) {
StoreBegin = SE.getSCEV(SI->getPointerOperand());
if (StoreBegin != SE.getCouldNotCompute()) {
if (StoreBegin != SE.getCouldNotCompute() &&
!StoreBegin->getType()->isIntegerTy()) {
auto &DL = maybeWriter->getModule()->getDataLayout();
auto width = cast<IntegerType>(DL.getIndexType(StoreBegin->getType()))
->getBitWidth();
Expand All @@ -1948,7 +1950,8 @@ bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
}
if (auto MS = dyn_cast<MemSetInst>(maybeWriter)) {
StoreBegin = SE.getSCEV(MS->getArgOperand(0));
if (StoreBegin != SE.getCouldNotCompute()) {
if (StoreBegin != SE.getCouldNotCompute() &&
!StoreBegin->getType()->isIntegerTy()) {
if (auto Len = dyn_cast<ConstantInt>(MS->getArgOperand(2))) {
auto &DL = MS->getModule()->getDataLayout();
auto width = cast<IntegerType>(DL.getIndexType(StoreBegin->getType()))
Expand All @@ -1961,7 +1964,8 @@ bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
}
if (auto MS = dyn_cast<MemTransferInst>(maybeWriter)) {
StoreBegin = SE.getSCEV(MS->getArgOperand(0));
if (StoreBegin != SE.getCouldNotCompute()) {
if (StoreBegin != SE.getCouldNotCompute() &&
!StoreBegin->getType()->isIntegerTy()) {
if (auto Len = dyn_cast<ConstantInt>(MS->getArgOperand(2))) {
auto &DL = MS->getModule()->getDataLayout();
auto width = cast<IntegerType>(DL.getIndexType(StoreBegin->getType()))
Expand All @@ -1974,7 +1978,8 @@ bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
}
if (auto MS = dyn_cast<MemTransferInst>(maybeReader)) {
LoadBegin = SE.getSCEV(MS->getArgOperand(1));
if (LoadBegin != SE.getCouldNotCompute()) {
if (LoadBegin != SE.getCouldNotCompute() &&
!LoadBegin->getType()->isIntegerTy()) {
if (auto Len = dyn_cast<ConstantInt>(MS->getArgOperand(2))) {
auto &DL = MS->getModule()->getDataLayout();
auto width = cast<IntegerType>(DL.getIndexType(LoadBegin->getType()))
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/ReverseMode/addrbug.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -loop-deletion -simplifycfg -correlated-propagation -adce -instsimplify -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,loop(loop-deletion),%simplifycfg,correlated-propagation,adce,instsimplify)" -enzyme-preopt=false -S | FileCheck %s
; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -loop-deletion -simplifycfg -correlated-propagation -adce -instsimplify -S | FileCheck %s; fi
; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,loop(loop-deletion),%simplifycfg,correlated-propagation,adce,instsimplify)" -enzyme-preopt=false -S | FileCheck %s; fi

; Function Attrs: nounwind
declare void @__enzyme_autodiff(i8*, ...)
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/ReverseMode/makememcpy1.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -loop-deletion -correlated-propagation -adce -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,loop(loop-deletion),correlated-propagation,adce,%simplifycfg)" -S | FileCheck %s
; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -loop-deletion -correlated-propagation -adce -simplifycfg -S | FileCheck %s; fi
; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,loop(loop-deletion),correlated-propagation,adce,%simplifycfg)" -S | FileCheck %s; fi

; This requires the additional optimization to create memcpy's

Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/ReverseMode/metacachelicm.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s
; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi

; Function Attrs: nounwind uwtable
define dso_local void @compute(double* noalias nocapture %data, i64* noalias nocapture readnone %array, double* noalias nocapture %out) #0 {
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/ReverseMode/metacachelicm2.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s
; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi

; Function Attrs: nounwind uwtable
define dso_local void @compute(double* noalias nocapture %data, i64* noalias nocapture readonly %array, double* noalias nocapture %out) #0 {
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/Enzyme/ReverseMode/rwrloop.ll
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,11 @@ attributes #9 = { noreturn nounwind }

; CHECK: for.cond1.preheader: ; preds = %for.cond.cleanup3, %entry
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %for.cond.cleanup3 ], [ 0, %entry ]
; CHECK-NEXT: %[[a2:.+]] = mul {{(nuw nsw )?}}i64 %iv, 10
; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
; CHECK-NEXT: br i1 %cmp233, label %for.body4.lr.ph, label %for.cond.cleanup3

; CHECK: for.body4.lr.ph: ; preds = %for.cond1.preheader
; CHECK-NEXT: %[[a2:.+]] = mul {{(nuw nsw )?}}i64 %iv, 10
; CHECK-NEXT: %[[a3:.+]] = load i32, i32* %N, align 4, !tbaa !2, !alias.scope !8, !noalias !11, !invariant.group ![[INVG:[0-9]]]
; CHECK-NEXT: %[[a4:.+]] = getelementptr inbounds i32, i32* %[[malloccache12]], i64 %iv
; CHECK-NEXT: store i32 %[[a3]], i32* %[[a4]], align 4, !tbaa !2, !invariant.group ![[INVG]]
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/ReverseMode/sploop2.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,adce,loop(loop-deletion),correlated-propagation,%simplifycfg)" -S | FileCheck %s
; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S | FileCheck %s; fi
; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,adce,loop(loop-deletion),correlated-propagation,%simplifycfg)" -S | FileCheck %s; fi

; This requires the memcpy optimization to run

Expand Down

0 comments on commit 7508ad4

Please sign in to comment.