Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SCEV memory error #1524

Merged
merged 2 commits into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ const char *KnownInactiveFunctionsContains[] = {
"__enzyme_pointer"};

const StringSet<> InactiveGlobals = {
"small_typeof",
"ompi_request_null",
"ompi_mpi_double",
"ompi_mpi_comm_world",
Expand Down
114 changes: 74 additions & 40 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6094,6 +6094,11 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
mode == DerivativeMode::ReverseModeCombined);

assert(val->getName() != "<badref>");
{
auto found = incoming_available.find(val);
if (found != incoming_available.end())
return found->second;
}
if (isa<Constant>(val)) {
return val;
}
Expand Down Expand Up @@ -6121,7 +6126,6 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
}

auto inst = cast<Instruction>(val);
assert(inst->getName() != "<badref>");
if (inversionAllocs && inst->getParent() == inversionAllocs) {
return val;
}
Expand Down Expand Up @@ -6418,7 +6422,7 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
auto li2obj = getBaseObject(li2->getPointerOperand());

if (liobj == li2obj && DT.dominates(li2, li)) {
auto orig2 = isOriginal(li2);
auto orig2 = dyn_cast_or_null<LoadInst>(isOriginal(li2));
if (!orig2)
continue;

Expand All @@ -6427,8 +6431,8 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
// llvm::errs() << "found potential candidate loads: oli:"
// << *origInst << " oli2: " << *orig2 << "\n";

auto scev1 = SE.getSCEV(li->getPointerOperand());
auto scev2 = SE.getSCEV(li2->getPointerOperand());
auto scev1 = SE.getSCEV(origInst->getPointerOperand());
auto scev2 = SE.getSCEV(orig2->getPointerOperand());
// llvm::errs() << " scev1: " << *scev1 << " scev2: " << *scev2
// << "\n";

Expand All @@ -6449,11 +6453,12 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,

if (auto ar1 = dyn_cast<SCEVAddRecExpr>(scev1)) {
if (auto ar2 = dyn_cast<SCEVAddRecExpr>(scev2)) {
if (ar1->getStart() != SE.getCouldNotCompute() &&
if (ar1->getStart() != OrigSE.getCouldNotCompute() &&
ar1->getStart() == ar2->getStart() &&
ar1->getStepRecurrence(SE) != SE.getCouldNotCompute() &&
ar1->getStepRecurrence(SE) ==
ar2->getStepRecurrence(SE)) {
ar1->getStepRecurrence(OrigSE) !=
OrigSE.getCouldNotCompute() &&
ar1->getStepRecurrence(OrigSE) ==
ar2->getStepRecurrence(OrigSE)) {

LoopContext l1;
getContext(ar1->getLoop()->getHeader(), l1);
Expand Down Expand Up @@ -6848,20 +6853,20 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
}
}

auto scev1 = SE.getSCEV(li->getPointerOperand());
auto scev1 = OrigSE.getSCEV(origInst->getPointerOperand());
// Store in memcpy opt
Value *lim = nullptr;
BasicBlock *ctx = nullptr;
Value *start = nullptr;
Value *offset = nullptr;
if (auto ar1 = dyn_cast<SCEVAddRecExpr>(scev1)) {
if (auto step =
dyn_cast<SCEVConstant>(ar1->getStepRecurrence(SE))) {
dyn_cast<SCEVConstant>(ar1->getStepRecurrence(OrigSE))) {
if (step->getAPInt() != loadSize)
goto noSpeedCache;

LoopContext l1;
getContext(ar1->getLoop()->getHeader(), l1);
getContext(getNewFromOriginal(ar1->getLoop()->getHeader()), l1);

if (l1.dynamic)
goto noSpeedCache;
Expand All @@ -6886,40 +6891,69 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
lim = v.CreateAdd(lim, ConstantInt::get(lim->getType(), 1), "",
true, true);

SmallVector<Instruction *, 4> toErase;
{
#if LLVM_VERSION_MAJOR >= 12
SCEVExpander Exp(SE,
ctx->getParent()->getParent()->getDataLayout(),
"enzyme");
#else
fake::SCEVExpander Exp(
SE, ctx->getParent()->getParent()->getDataLayout(),
"enzyme");
#endif
Exp.setInsertPoint(l1.header->getTerminator());
Value *start0 = Exp.expandCodeFor(
ar1->getStart(), li->getPointerOperand()->getType());
start = unwrapM(start0, v,
/*available*/ ValueToValueMapTy(),
UnwrapMode::AttemptFullUnwrapWithLookup);
std::set<Value *> todo = {start0};
while (todo.size()) {
Value *now = *todo.begin();
todo.erase(now);
if (Instruction *inst = dyn_cast<Instruction>(now)) {
if (inst != start && inst->getNumUses() == 0 &&
Exp.isInsertedInstruction(inst)) {
for (auto &op : inst->operands()) {
todo.insert(op);
}
toErase.push_back(inst);
}
Value *start0;
SmallVector<Instruction *, 32> InsertedInstructions;
{
SCEVExpander OrigExp(
OrigSE, ctx->getParent()->getParent()->getDataLayout(),
"enzyme");

OrigExp.setInsertPoint(
isOriginal(l1.header)->getTerminator());

start0 = OrigExp.expandCodeFor(
ar1->getStart(), li->getPointerOperand()->getType());
InsertedInstructions = OrigExp.getAllInsertedInstructions();
}

ValueToValueMapTy available;
for (const auto &pair : originalToNewFn) {
if (pair.first->getType() == pair.second->getType())
available[pair.first] = pair.second;
}

// Sort so that later instructions do not dominate earlier
// instructions.
llvm::stable_sort(InsertedInstructions,
[this](Instruction *A, Instruction *B) {
return OrigDT.dominates(A, B);
});
for (auto a : InsertedInstructions) {
assert(!isa<PHINode>(a));
auto uw = cast<Instruction>(
unwrapM(a, v, available, UnwrapMode::AttemptSingleUnwrap,
/*scope*/ nullptr, /*cache*/ false));
assert(uw->getType() == a->getType());
for (size_t i = 0; i < uw->getNumOperands(); i++) {
auto op = uw->getOperand(i);
if (auto arg = dyn_cast<Argument>(op))
assert(arg->getParent() == newFunc);
else if (auto inst = dyn_cast<Instruction>(op))
assert(inst->getParent()->getParent() == newFunc);
}
available[a] = uw;
unwrappedLoads.erase(cast<Instruction>(uw));
}

start =
isa<Constant>(start0) ? start0 : (Value *)available[start0];
if (!start) {
llvm::errs() << "old: " << *oldFunc << "\n";
llvm::errs() << "new: " << *newFunc << "\n";
llvm::errs() << "start0: " << *start0 << "\n";
}
assert(start);

available.clear();
for (auto I : llvm::reverse(InsertedInstructions)) {
assert(I->getNumUses() == 0);
OrigSE.forgetValue(I);
I->eraseFromParent();
}
#endif
}
for (auto a : toErase)
erase(a);

if (!start)
goto noSpeedCache;
Expand Down
15 changes: 10 additions & 5 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1914,7 +1914,8 @@ bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,

if (auto LI = dyn_cast<LoadInst>(maybeReader)) {
LoadBegin = SE.getSCEV(LI->getPointerOperand());
if (LoadBegin != SE.getCouldNotCompute()) {
if (LoadBegin != SE.getCouldNotCompute() &&
!LoadBegin->getType()->isIntegerTy()) {
auto &DL = maybeWriter->getModule()->getDataLayout();
auto width = cast<IntegerType>(DL.getIndexType(LoadBegin->getType()))
->getBitWidth();
Expand All @@ -1930,7 +1931,8 @@ bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
}
if (auto SI = dyn_cast<StoreInst>(maybeWriter)) {
StoreBegin = SE.getSCEV(SI->getPointerOperand());
if (StoreBegin != SE.getCouldNotCompute()) {
if (StoreBegin != SE.getCouldNotCompute() &&
!StoreBegin->getType()->isIntegerTy()) {
auto &DL = maybeWriter->getModule()->getDataLayout();
auto width = cast<IntegerType>(DL.getIndexType(StoreBegin->getType()))
->getBitWidth();
Expand All @@ -1948,7 +1950,8 @@ bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
}
if (auto MS = dyn_cast<MemSetInst>(maybeWriter)) {
StoreBegin = SE.getSCEV(MS->getArgOperand(0));
if (StoreBegin != SE.getCouldNotCompute()) {
if (StoreBegin != SE.getCouldNotCompute() &&
!StoreBegin->getType()->isIntegerTy()) {
if (auto Len = dyn_cast<ConstantInt>(MS->getArgOperand(2))) {
auto &DL = MS->getModule()->getDataLayout();
auto width = cast<IntegerType>(DL.getIndexType(StoreBegin->getType()))
Expand All @@ -1961,7 +1964,8 @@ bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
}
if (auto MS = dyn_cast<MemTransferInst>(maybeWriter)) {
StoreBegin = SE.getSCEV(MS->getArgOperand(0));
if (StoreBegin != SE.getCouldNotCompute()) {
if (StoreBegin != SE.getCouldNotCompute() &&
!StoreBegin->getType()->isIntegerTy()) {
if (auto Len = dyn_cast<ConstantInt>(MS->getArgOperand(2))) {
auto &DL = MS->getModule()->getDataLayout();
auto width = cast<IntegerType>(DL.getIndexType(StoreBegin->getType()))
Expand All @@ -1974,7 +1978,8 @@ bool overwritesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI,
}
if (auto MS = dyn_cast<MemTransferInst>(maybeReader)) {
LoadBegin = SE.getSCEV(MS->getArgOperand(1));
if (LoadBegin != SE.getCouldNotCompute()) {
if (LoadBegin != SE.getCouldNotCompute() &&
!LoadBegin->getType()->isIntegerTy()) {
if (auto Len = dyn_cast<ConstantInt>(MS->getArgOperand(2))) {
auto &DL = MS->getModule()->getDataLayout();
auto width = cast<IntegerType>(DL.getIndexType(LoadBegin->getType()))
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/ReverseMode/addrbug.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -loop-deletion -simplifycfg -correlated-propagation -adce -instsimplify -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,loop(loop-deletion),%simplifycfg,correlated-propagation,adce,instsimplify)" -enzyme-preopt=false -S | FileCheck %s
; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -loop-deletion -simplifycfg -correlated-propagation -adce -instsimplify -S | FileCheck %s; fi
; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,loop(loop-deletion),%simplifycfg,correlated-propagation,adce,instsimplify)" -enzyme-preopt=false -S | FileCheck %s; fi

; Function Attrs: nounwind
declare void @__enzyme_autodiff(i8*, ...)
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/ReverseMode/makememcpy1.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -loop-deletion -correlated-propagation -adce -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,loop(loop-deletion),correlated-propagation,adce,%simplifycfg)" -S | FileCheck %s
; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -loop-deletion -correlated-propagation -adce -simplifycfg -S | FileCheck %s; fi
; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,loop(loop-deletion),correlated-propagation,adce,%simplifycfg)" -S | FileCheck %s; fi

; This requires the additional optimization to create memcpy's

Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/ReverseMode/metacachelicm.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s
; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi

; Function Attrs: nounwind uwtable
define dso_local void @compute(double* noalias nocapture %data, i64* noalias nocapture readnone %array, double* noalias nocapture %out) #0 {
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/ReverseMode/metacachelicm2.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s
; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s; fi

; Function Attrs: nounwind uwtable
define dso_local void @compute(double* noalias nocapture %data, i64* noalias nocapture readonly %array, double* noalias nocapture %out) #0 {
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/Enzyme/ReverseMode/rwrloop.ll
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,11 @@ attributes #9 = { noreturn nounwind }

; CHECK: for.cond1.preheader: ; preds = %for.cond.cleanup3, %entry
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %for.cond.cleanup3 ], [ 0, %entry ]
; CHECK-NEXT: %[[a2:.+]] = mul {{(nuw nsw )?}}i64 %iv, 10
; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
; CHECK-NEXT: br i1 %cmp233, label %for.body4.lr.ph, label %for.cond.cleanup3

; CHECK: for.body4.lr.ph: ; preds = %for.cond1.preheader
; CHECK-NEXT: %[[a2:.+]] = mul {{(nuw nsw )?}}i64 %iv, 10
; CHECK-NEXT: %[[a3:.+]] = load i32, i32* %N, align 4, !tbaa !2, !alias.scope !8, !noalias !11, !invariant.group ![[INVG:[0-9]]]
; CHECK-NEXT: %[[a4:.+]] = getelementptr inbounds i32, i32* %[[malloccache12]], i64 %iv
; CHECK-NEXT: store i32 %[[a3]], i32* %[[a4]], align 4, !tbaa !2, !invariant.group ![[INVG]]
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/ReverseMode/sploop2.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,adce,loop(loop-deletion),correlated-propagation,%simplifycfg)" -S | FileCheck %s
; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -gt 11 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -adce -loop-deletion -correlated-propagation -simplifycfg -S | FileCheck %s; fi
; RUN: if [ %llvmver -gt 11 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,adce,loop(loop-deletion),correlated-propagation,%simplifycfg)" -S | FileCheck %s; fi

; This requires the memcpy optimization to run

Expand Down
Loading