Skip to content

Commit

Permalink
Fix int/float memset (#1421)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Sep 12, 2023
1 parent 8025707 commit 65d6481
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 10 deletions.
14 changes: 4 additions & 10 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2939,11 +2939,8 @@ class AdjointGenerator
if (start != 0) {
Value *idxs[] = {
ConstantInt::get(Type::getInt32Ty(op0->getContext()), start)};
op0 = BuilderZ.CreateInBoundsGEP(
PointerType::get(
Type::getInt8Ty(op0->getContext()),
cast<PointerType>(op0->getType())->getAddressSpace()),
op0, idxs);
op0 = BuilderZ.CreateInBoundsGEP(Type::getInt8Ty(op0->getContext()),
op0, idxs);
}
SmallVector<Value *, 4> args = {op0, op1, length};
if (op3)
Expand Down Expand Up @@ -2979,11 +2976,8 @@ class AdjointGenerator
if (start != 0) {
Value *idxs[] = {
ConstantInt::get(Type::getInt32Ty(op0->getContext()), start)};
op0 = BuilderZ.CreateInBoundsGEP(
PointerType::get(
Type::getInt8Ty(op0->getContext()),
cast<PointerType>(op0->getType())->getAddressSpace()),
op0, idxs);
op0 = BuilderZ.CreateInBoundsGEP(Type::getInt8Ty(op0->getContext()),
op0, idxs);
}
SmallVector<Value *, 4> args = {op0, op1l, length};
if (op3l)
Expand Down
39 changes: 39 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/memset-intfloat.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,%simplifycfg)" -S | FileCheck %s

; Variadic entry point; @df below calls it with @f as the first argument.
; It has no body in this file.
declare void @__enzyme_autodiff(...)

; LLVM memset intrinsic: (destination i8*, fill byte, i64 length, i1 volatile flag).
declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1)

; External function with no body here; @f calls it once, just before returning.
declare void @g()

; Function under differentiation: zero-fills 16 bytes at %x, then reads the
; two 8-byte halves back with different TBAA tags.
; The expected output at the bottom of this file zeroes shadow bytes 8..15
; before the call to @g and shadow bytes 0..7 after it.
define void @f(i8* %x) {
; Zero all 16 bytes of the buffer (non-volatile).
call void @llvm.memset.p0i8.i64(i8* %x, i8 0, i64 16, i1 false)
%xp = bitcast i8* %x to double*
; First half (bytes 0..7): loaded with the "double" TBAA tag (!7).
%flt = load double, double* %xp, align 8, !tbaa !7
; Second half (bytes 8..15): element 1 of the double* view.
%g = getelementptr inbounds double, double* %xp, i32 1
; The load type is double, but the TBAA tag is "long" (!4).
; NOTE(review): presumably this makes type analysis treat bytes 8..15 as
; integer data while bytes 0..7 are floating point -- confirm.
%int = load double, double* %g, align 8, !tbaa !4
; Call carrying the "enzyme_inactive" string attribute; the expected output
; keeps exactly one call to @g between the two shadow memsets.
call void @g() "enzyme_inactive"
ret void
}

; Driver: requests differentiation of @f through @__enzyme_autodiff.
; Both pointers are passed after the "enzyme_dup" metadata marker.
; NOTE(review): presumably %x is the primal buffer and %xp its shadow
; (the expected output names the generated second parameter %"x'") -- confirm.
define void @df(double* %x, double* %xp) {
tail call void (...) @__enzyme_autodiff(i8* bitcast (void (i8*)* @f to i8*), metadata !"enzyme_dup", double* %x, double* %xp)
ret void
}

; TBAA nodes used as access tags by the two loads in @f.
; Each of !4, !7 and !5 lists a type name, a parent node, and an i64 0 field.
; !4 ("long") tags the load of bytes 8..15.
!4 = !{!"long", !5, i64 0}
; !7 ("double") tags the load of bytes 0..7.
!7 = !{!"double", !5, i64 0}
; !5 is the shared parent of !4 and !7.
!5 = !{!"omnipotent char", !6, i64 0}
; !6 is the root node, holding only a name.
!6 = !{!"Simple C++ TBAA"}

; CHECK: define internal void @diffef(i8* %x, i8* %"x'")
; CHECK-NEXT: invert:
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %x, i8 0, i64 16, i1 false)
; CHECK-NEXT: %0 = getelementptr inbounds i8, i8* %"x'", i32 8
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %0, i8 0, i64 8, i1 false)
; The call to @g below separates the forward pass from the reverse pass
; CHECK-NEXT: call void @g()
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %"x'", i8 0, i64 8, i1 false)
; CHECK-NEXT: ret void
; CHECK-NEXT: }

0 comments on commit 65d6481

Please sign in to comment.