Skip to content

Commit

Permalink
Simplify insert/extract, as required for address prop
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 13, 2023
1 parent 65d6481 commit 4ddf79f
Show file tree
Hide file tree
Showing 18 changed files with 225 additions and 258 deletions.
45 changes: 45 additions & 0 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,52 @@ void PreProcessCache::AlwaysInline(Function *NewF) {
}
}

// Simplify all extractions to use inserted values, if possible.
void simplifyExtractions(Function *NewF) {
// First rewrite/remove any extractions
for (auto &BB : *NewF) {
IRBuilder<> B(&BB);
auto first = BB.begin();
auto last = BB.empty() ? BB.end() : std::prev(BB.end());
for (auto it = first; it != last;) {
auto inst = &*it;
// We iterate first here, since we may delete the instruction
// in the body
++it;
if (auto E = dyn_cast<ExtractValueInst>(inst)) {
auto rep = GradientUtils::extractMeta(B, E->getAggregateOperand(),
E->getIndices(), E->getName(),
/*fallback*/ false);
if (rep) {
E->replaceAllUsesWith(rep);
E->eraseFromParent();
}
}
}
}
// Now that there may be unused insertions, delete them. We keep a list of
// todo's since deleting an insertvalue may cause a different insertvalue to
// have no uses
SmallVector<InsertValueInst *, 1> todo;
for (auto &BB : *NewF) {
for (auto &inst : BB)
if (auto I = dyn_cast<InsertValueInst>(&inst)) {
if (I->getNumUses() == 0)
todo.push_back(I);
}
}
while (todo.size()) {
auto I = todo.pop_back_val();
auto op = I->getAggregateOperand();
I->eraseFromParent();
if (auto I2 = dyn_cast<InsertValueInst>(op))
if (I2->getNumUses() == 0)
todo.push_back(I2);
}
}

void PreProcessCache::LowerAllocAddr(Function *NewF) {
simplifyExtractions(NewF);
SmallVector<Instruction *, 1> Todo;
for (auto &BB : *NewF) {
for (auto &I : BB) {
Expand Down
6 changes: 5 additions & 1 deletion enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4804,7 +4804,7 @@ Value *GradientUtils::extractMeta(IRBuilder<> &Builder, Value *Agg,

Value *GradientUtils::extractMeta(IRBuilder<> &Builder, Value *Agg,
ArrayRef<unsigned> off_init,
const Twine &name) {
const Twine &name, bool fallback) {
std::vector<unsigned> off(off_init.begin(), off_init.end());
while (off.size() != 0) {
if (auto Ins = dyn_cast<InsertValueInst>(Agg)) {
Expand Down Expand Up @@ -4843,6 +4843,10 @@ Value *GradientUtils::extractMeta(IRBuilder<> &Builder, Value *Agg,
}
if (off.size() == 0)
return Agg;

if (!fallback)
return nullptr;

if (Agg->getType()->isVectorTy() && off.size() == 1)
return Builder.CreateExtractElement(Agg, off[0], name);

Expand Down
11 changes: 10 additions & 1 deletion enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -501,11 +501,20 @@ class GradientUtils : public CacheUtility {

llvm::Type *getShadowType(llvm::Type *ty);

//! Helper routine to extract a nested element from a struct/array. This is
// a one dimensional special case of the multi-dim extractMeta below.
static llvm::Value *extractMeta(llvm::IRBuilder<> &Builder, llvm::Value *Agg,
unsigned off, const llvm::Twine &name = "");

//! Helper routine to extract a nested element from a struct/array. Unlike the
// LLVM instruction, this will attempt to re-use the inserted value, if it
// exists, rather than always creating a new instruction. If fallback is
// true (the default), it will create an instruction if it fails to find an
// appropriate existing value, otherwise it returns nullptr.
static llvm::Value *extractMeta(llvm::IRBuilder<> &Builder, llvm::Value *Agg,
llvm::ArrayRef<unsigned> off,
const llvm::Twine &name = "");
const llvm::Twine &name = "",
bool fallback = true);

static llvm::Value *recursiveFAdd(llvm::IRBuilder<> &B, llvm::Value *lhs,
llvm::Value *rhs,
Expand Down
1 change: 0 additions & 1 deletion enzyme/test/Enzyme/ForwardMode/invptrint.ll
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ bb:
; CHECK: define internal { { double*, i64 }, double } @fwddiffez0({ double*, i64 } %const, double %act, double %"act'")
; CHECK-NEXT: bb:
; CHECK-NEXT: %"res'ipiv" = insertvalue { { double*, i64 }, double } zeroinitializer, { double*, i64 } %const, 0
; CHECK-NEXT: %res = insertvalue { { double*, i64 }, double } undef, { double*, i64 } %const, 0
; CHECK-NEXT: %"res2'ipiv" = insertvalue { { double*, i64 }, double } %"res'ipiv", double %"act'", 1
; CHECK-NEXT: ret { { double*, i64 }, double } %"res2'ipiv"
; CHECK-NEXT: }
1 change: 0 additions & 1 deletion enzyme/test/Enzyme/ForwardMode/invptrint2.ll
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ bb:
; CHECK-NEXT: store double 0.000000e+00, double* %3
; CHECK-NEXT: %4 = load { double*, i64, double }, { double*, i64, double }* %0
; CHECK-NEXT: %"res'ipiv" = insertvalue { { double*, i64, double }, double } zeroinitializer, { double*, i64, double } %4, 0
; CHECK-NEXT: %res = insertvalue { { double*, i64, double }, double } undef, { double*, i64, double } %const, 0
; CHECK-NEXT: %"res2'ipiv" = insertvalue { { double*, i64, double }, double } %"res'ipiv", double %"act'", 1
; CHECK-NEXT: ret { { double*, i64, double }, double } %"res2'ipiv"
; CHECK-NEXT: }
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/ForwardModeVector/fabs.ll
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ declare double @llvm.fabs.f64(double)
; CHECK-NEXT: %[[i2:.+]] = select {{(fast )?}}i1 %[[i1]], double -1.000000e+00, double 1.000000e+00
; CHECK-NEXT: %[[i0:.+]] = extractvalue [2 x double] %"x'", 0
; CHECK-NEXT: %[[i3:.+]] = fmul fast double %[[i0]], %[[i2]]
; CHECK-NEXT: %[[i4:.+]] = insertvalue [2 x double] undef, double %[[i3]], 0
; CHECK-NEXT: %[[i5:.+]] = extractvalue [2 x double] %"x'", 1
; CHECK-NEXT: %[[i6:.+]] = fmul fast double %[[i5]], %[[i2]]
; CHECK-NEXT: %[[i4:.+]] = insertvalue [2 x double] undef, double %[[i3]], 0
; CHECK-NEXT: %[[i7:.+]] = insertvalue [2 x double] %[[i4]], double %[[i6]], 1
; CHECK-NEXT: ret [2 x double] %[[i7]]
; CHECK-NEXT: }
; CHECK-NEXT: }
33 changes: 6 additions & 27 deletions enzyme/test/Enzyme/ForwardModeVector/globallower.ll
Original file line number Diff line number Diff line change
Expand Up @@ -31,51 +31,30 @@ entry:

; CHECK: define {{[^@]+}}@fwddiffe3mulglobal(double [[X:%.*]], [3 x double] %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %"global'ipa" = alloca double, align 8
; CHECK-NEXT: %"global'ipa1" = alloca double, align 8
; CHECK-NEXT: %"global'ipa2" = alloca double, align 8
; CHECK-NEXT: [[TMP0:%.*]] = bitcast double* %"global'ipa" to i8*
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull align 8 [[TMP0]], i8 0, i64 8, i1 false)
; CHECK-NEXT: [[TMP1:%.*]] = bitcast double* %"global'ipa1" to i8*
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull align 8 [[TMP1]], i8 0, i64 8, i1 false)
; CHECK-NEXT: [[TMP2:%.*]] = bitcast double* %"global'ipa2" to i8*
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull align 8 [[TMP2]], i8 0, i64 8, i1 false)
; CHECK-NEXT: %"global_local.0.copyload'ipl" = load double, double* %"global'ipa", align 8
; CHECK-NEXT: %"global_local.0.copyload'ipl3" = load double, double* %"global'ipa1", align 8
; CHECK-NEXT: %"global_local.0.copyload'ipl4" = load double, double* %"global'ipa2", align 8
; CHECK-NEXT: [[GLOBAL_LOCAL_0_COPYLOAD:%.*]] = load double, double* @global, align 8
; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[GLOBAL_LOCAL_0_COPYLOAD]], [[X]]
; CHECK-NEXT: [[TMP4:%.*]] = fmul fast double %"global_local.0.copyload'ipl", [[X]]
; CHECK-NEXT: [[TMP8:%.*]] = fmul fast double %"global_local.0.copyload'ipl3", [[X]]
; CHECK-NEXT: [[TMP12:%.*]] = fmul fast double %"global_local.0.copyload'ipl4", [[X]]
; CHECK-NEXT: [[TMP3:%.*]] = extractvalue [3 x double] %"x'", 0
; CHECK-NEXT: [[TMP5:%.*]] = fmul fast double [[TMP3]], [[GLOBAL_LOCAL_0_COPYLOAD]]
; CHECK-NEXT: [[TMP7:%.*]] = extractvalue [3 x double] %"x'", 1
; CHECK-NEXT: [[TMP9:%.*]] = fmul fast double [[TMP7]], [[GLOBAL_LOCAL_0_COPYLOAD]]
; CHECK-NEXT: [[TMP11:%.*]] = extractvalue [3 x double] %"x'", 2
; CHECK-NEXT: [[TMP13:%.*]] = fmul fast double [[TMP11]], [[GLOBAL_LOCAL_0_COPYLOAD]]
; CHECK-NEXT: [[TMP6:%.*]] = fadd fast double [[TMP4]], [[TMP5]]
; CHECK-NEXT: [[TMP10:%.*]] = fadd fast double [[TMP8]], [[TMP9]]
; CHECK-NEXT: [[TMP14:%.*]] = fadd fast double [[TMP12]], [[TMP13]]
; CHECK-NEXT: [[MUL2:%.*]] = fmul fast double [[MUL]], [[MUL]]

; CHECK-NEXT: [[TMP15:%.*]] = fmul fast double [[TMP6]], [[MUL]]
; CHECK-NEXT: [[TMP19:%.*]] = fmul fast double [[TMP10]], [[MUL]]
; CHECK-NEXT: [[TMP23:%.*]] = fmul fast double [[TMP14]], [[MUL]]
; CHECK-NEXT: [[TMP15:%.*]] = fmul fast double [[TMP5]], [[MUL]]
; CHECK-NEXT: [[TMP19:%.*]] = fmul fast double [[TMP9]], [[MUL]]
; CHECK-NEXT: [[TMP23:%.*]] = fmul fast double [[TMP13]], [[MUL]]

; CHECK-NEXT: [[TMP16:%.*]] = fmul fast double [[TMP6]], [[MUL]]
; CHECK-NEXT: [[TMP20:%.*]] = fmul fast double [[TMP10]], [[MUL]]
; CHECK-NEXT: [[TMP24:%.*]] = fmul fast double [[TMP14]], [[MUL]]
; CHECK-NEXT: [[TMP16:%.*]] = fmul fast double [[TMP5]], [[MUL]]
; CHECK-NEXT: [[TMP20:%.*]] = fmul fast double [[TMP9]], [[MUL]]
; CHECK-NEXT: [[TMP24:%.*]] = fmul fast double [[TMP13]], [[MUL]]

; CHECK-NEXT: [[TMP17:%.*]] = fadd fast double [[TMP15]], [[TMP16]]
; CHECK-NEXT: [[TMP18:%.*]] = insertvalue [3 x double] undef, double [[TMP17]], 0
; CHECK-NEXT: [[TMP21:%.*]] = fadd fast double [[TMP19]], [[TMP20]]
; CHECK-NEXT: [[TMP22:%.*]] = insertvalue [3 x double] [[TMP18]], double [[TMP21]], 1
; CHECK-NEXT: [[TMP25:%.*]] = fadd fast double [[TMP23]], [[TMP24]]
; CHECK-NEXT: [[TMP26:%.*]] = insertvalue [3 x double] [[TMP22]], double [[TMP25]], 2
; CHECK-NEXT: store double [[TMP17]], double* %"global'ipa", align 8
; CHECK-NEXT: store double [[TMP21]], double* %"global'ipa1", align 8
; CHECK-NEXT: store double [[TMP25]], double* %"global'ipa2", align 8
; CHECK-NEXT: store double [[MUL2]], double* @global, align 8
; CHECK-NEXT: ret [3 x double] [[TMP26]]
;
2 changes: 0 additions & 2 deletions enzyme/test/Enzyme/ForwardModeVector/invertselect.ll
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,9 @@ attributes #0 = { noinline }
; CHECK-NEXT: [[TMP2:%.*]] = extractvalue [3 x float*] %"a'", 0
; CHECK-NEXT: [[TMP3:%.*]] = extractvalue [3 x float*] %"b'", 0
; CHECK-NEXT: %"a.b'ipse" = select i1 [[CMP]], float* [[TMP2]], float* [[TMP3]]
; CHECK-NEXT: [[TMP4:%.*]] = insertvalue [3 x float*] undef, float* %"a.b'ipse", 0
; CHECK-NEXT: [[TMP5:%.*]] = extractvalue [3 x float*] %"a'", 1
; CHECK-NEXT: [[TMP6:%.*]] = extractvalue [3 x float*] %"b'", 1
; CHECK-NEXT: %"a.b'ipse1" = select i1 [[CMP]], float* [[TMP5]], float* [[TMP6]]
; CHECK-NEXT: [[TMP7:%.*]] = insertvalue [3 x float*] [[TMP4]], float* %"a.b'ipse1", 1
; CHECK-NEXT: [[TMP8:%.*]] = extractvalue [3 x float*] %"a'", 2
; CHECK-NEXT: [[TMP9:%.*]] = extractvalue [3 x float*] %"b'", 2
; CHECK-NEXT: %"a.b'ipse2" = select i1 [[CMP]], float* [[TMP8]], float* [[TMP9]]
Expand Down
22 changes: 11 additions & 11 deletions enzyme/test/Enzyme/ForwardModeVector/log1p.ll
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ declare double @log1p(double)

; CHECK: define internal [3 x double] @fwddiffe3tester(double %x, [3 x double] %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = fadd fast double %x, 1.000000e+00
; CHECK-NEXT: %1 = extractvalue [3 x double] %"x'", 0
; CHECK-NEXT: %2 = fdiv fast double %1, %0
; CHECK-NEXT: %3 = insertvalue [3 x double] undef, double %2, 0
; CHECK-NEXT: %4 = extractvalue [3 x double] %"x'", 1
; CHECK-NEXT: %5 = fdiv fast double %4, %0
; CHECK-NEXT: %6 = insertvalue [3 x double] %3, double %5, 1
; CHECK-NEXT: %7 = extractvalue [3 x double] %"x'", 2
; CHECK-NEXT: %8 = fdiv fast double %7, %0
; CHECK-NEXT: %9 = insertvalue [3 x double] %6, double %8, 2
; CHECK-NEXT: ret [3 x double] %9
; CHECK-NEXT: %[[i0:.+]] = fadd fast double %x, 1.000000e+00
; CHECK-NEXT: %[[i1:.+]] = extractvalue [3 x double] %"x'", 0
; CHECK-NEXT: %[[i2:.+]] = fdiv fast double %[[i1]], %[[i0]]
; CHECK-NEXT: %[[i4:.+]] = extractvalue [3 x double] %"x'", 1
; CHECK-NEXT: %[[i5:.+]] = fdiv fast double %[[i4]], %[[i0]]
; CHECK-NEXT: %[[i7:.+]] = extractvalue [3 x double] %"x'", 2
; CHECK-NEXT: %[[i8:.+]] = fdiv fast double %[[i7]], %[[i0]]
; CHECK-NEXT: %[[i3:.+]] = insertvalue [3 x double] undef, double %[[i2]], 0
; CHECK-NEXT: %[[i6:.+]] = insertvalue [3 x double] %[[i3]], double %[[i5]], 1
; CHECK-NEXT: %[[i9:.+]] = insertvalue [3 x double] %[[i6]], double %[[i8]], 2
; CHECK-NEXT: ret [3 x double] %[[i9]]
; CHECK-NEXT: }
6 changes: 0 additions & 6 deletions enzyme/test/Enzyme/ForwardModeVector/ptr-eq.ll
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,10 @@ entry:
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = extractvalue [3 x double*] %"x'", 0
; CHECK-NEXT: %"val'ipl" = load double, double* [[TMP0]]
; CHECK-NEXT: [[TMP1:%.*]] = insertvalue [3 x double] undef, double %"val'ipl", 0
; CHECK-NEXT: [[TMP2:%.*]] = extractvalue [3 x double*] %"x'", 1
; CHECK-NEXT: %"val'ipl1" = load double, double* [[TMP2]]
; CHECK-NEXT: [[TMP3:%.*]] = insertvalue [3 x double] [[TMP1]], double %"val'ipl1", 1
; CHECK-NEXT: [[TMP4:%.*]] = extractvalue [3 x double*] %"x'", 2
; CHECK-NEXT: %"val'ipl2" = load double, double* [[TMP4]]
; CHECK-NEXT: [[TMP5:%.*]] = insertvalue [3 x double] [[TMP3]], double %"val'ipl2", 2
; CHECK-NEXT: [[VAL:%.*]] = load double, double* [[X]]
; CHECK-NEXT: [[TMP6:%.*]] = extractvalue [3 x double*] %"y'", 0
; CHECK-NEXT: store double %"val'ipl", double* [[TMP6]]
Expand All @@ -41,13 +38,10 @@ entry:
; CHECK-NEXT: store double [[VAL]], double* [[Y]]
; CHECK-NEXT: [[TMP12:%.*]] = extractvalue [3 x double*] %"x'", 0
; CHECK-NEXT: %"ptr'ipc" = bitcast double* [[TMP12]] to i8*
; CHECK-NEXT: [[TMP13:%.*]] = insertvalue [3 x i8*] undef, i8* %"ptr'ipc", 0
; CHECK-NEXT: [[TMP14:%.*]] = extractvalue [3 x double*] %"x'", 1
; CHECK-NEXT: %"ptr'ipc3" = bitcast double* [[TMP14]] to i8*
; CHECK-NEXT: [[TMP15:%.*]] = insertvalue [3 x i8*] [[TMP13]], i8* %"ptr'ipc3", 1
; CHECK-NEXT: [[TMP16:%.*]] = extractvalue [3 x double*] %"x'", 2
; CHECK-NEXT: %"ptr'ipc4" = bitcast double* [[TMP16]] to i8*
; CHECK-NEXT: [[TMP17:%.*]] = insertvalue [3 x i8*] [[TMP15]], i8* %"ptr'ipc4", 2
; CHECK-NEXT: [[PTR:%.*]] = bitcast double* [[X]] to i8*
; CHECK-NEXT: call void @free(i8* [[PTR]])
; CHECK-NEXT: [[TMPZ4:%.*]] = icmp ne i8* [[PTR]], %"ptr'ipc"
Expand Down
Loading

0 comments on commit 4ddf79f

Please sign in to comment.