From 65d6481bc17fc1b42c3185f6d6ad3a87555711ab Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 12 Sep 2023 13:35:45 -0500 Subject: [PATCH] Fix int/float memset (#1421) --- enzyme/Enzyme/AdjointGenerator.h | 14 ++----- .../Enzyme/ReverseMode/memset-intfloat.ll | 39 +++++++++++++++++++ 2 files changed, 43 insertions(+), 10 deletions(-) create mode 100644 enzyme/test/Enzyme/ReverseMode/memset-intfloat.ll diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 234eee67d4a2..5eb1aab40fa6 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -2939,11 +2939,8 @@ class AdjointGenerator if (start != 0) { Value *idxs[] = { ConstantInt::get(Type::getInt32Ty(op0->getContext()), start)}; - op0 = BuilderZ.CreateInBoundsGEP( - PointerType::get( - Type::getInt8Ty(op0->getContext()), - cast<PointerType>(op0->getType())->getAddressSpace()), - op0, idxs); + op0 = BuilderZ.CreateInBoundsGEP(Type::getInt8Ty(op0->getContext()), + op0, idxs); } SmallVector<Value *, 4> args = {op0, op1, length}; if (op3) @@ -2979,11 +2976,8 @@ class AdjointGenerator if (start != 0) { Value *idxs[] = { ConstantInt::get(Type::getInt32Ty(op0->getContext()), start)}; - op0 = BuilderZ.CreateInBoundsGEP( - PointerType::get( - Type::getInt8Ty(op0->getContext()), - cast<PointerType>(op0->getType())->getAddressSpace()), - op0, idxs); + op0 = BuilderZ.CreateInBoundsGEP(Type::getInt8Ty(op0->getContext()), + op0, idxs); } SmallVector<Value *, 4> args = {op0, op1l, length}; if (op3l) diff --git a/enzyme/test/Enzyme/ReverseMode/memset-intfloat.ll b/enzyme/test/Enzyme/ReverseMode/memset-intfloat.ll new file mode 100644 index 000000000000..4c9c76a3f712 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/memset-intfloat.ll @@ -0,0 +1,39 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,%simplifycfg)" -S | FileCheck %s + +declare 
void @__enzyme_autodiff(...) + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare void @g() + +define void @f(i8* %x) { + call void @llvm.memset.p0i8.i64(i8* %x, i8 0, i64 16, i1 false) + %xp = bitcast i8* %x to double* + %flt = load double, double* %xp, align 8, !tbaa !7 + %g = getelementptr inbounds double, double* %xp, i32 1 + %int = load double, double* %g, align 8, !tbaa !4 + call void @g() "enzyme_inactive" + ret void +} + +define void @df(double* %x, double* %xp) { + tail call void (...) @__enzyme_autodiff(i8* bitcast (void (i8*)* @f to i8*), metadata !"enzyme_dup", double* %x, double* %xp) + ret void +} + +!4 = !{!"long", !5, i64 0} +!7 = !{!"double", !5, i64 0} +!5 = !{!"omnipotent char", !6, i64 0} +!6 = !{!"Simple C++ TBAA"} + +; CHECK: define internal void @diffef(i8* %x, i8* %"x'") +; CHECK-NEXT: invert: +; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %x, i8 0, i64 16, i1 false) +; CHECK-NEXT: %0 = getelementptr inbounds i8, i8* %"x'", i32 8 +; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %0, i8 0, i64 8, i1 false) +; The G here is used to partition the forward pass from the reverse pass +; CHECK-NEXT: call void @g() +; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %"x'", i8 0, i64 8, i1 false) +; CHECK-NEXT: ret void +; CHECK-NEXT: }