From 73e22ff3d77db72bb9b6e22342417a5f4fe6afb4 Mon Sep 17 00:00:00 2001
From: Akshay Deodhar
Date: Tue, 28 May 2024 11:05:38 -0700
Subject: [PATCH] [Reassociate] Preserve NSW flags after expr tree rewriting
 (#93105)

We can guarantee NSW on all operands in a reassociated add expression
tree when:
- All adds in an add operator tree are NSW, AND either
  - All add operands are guaranteed to be nonnegative, OR
  - All adds are also NUW
- Alive2:
  - Nonnegative operands:
    - 3 operands: https://alive2.llvm.org/ce/z/G4XW6Q
    - 4 operands: https://alive2.llvm.org/ce/z/FWcZ6D
  - NUW NSW adds: https://alive2.llvm.org/ce/z/vRUxeC

---------

Co-authored-by: Nikita Popov
---
 .../llvm/Transforms/Scalar/Reassociate.h      | 12 ++-
 llvm/lib/Transforms/Scalar/Reassociate.cpp    | 35 +++++---
 llvm/test/Transforms/Reassociate/local-cse.ll | 40 +++++-----
 .../Transforms/Reassociate/reassoc-add-nsw.ll | 79 +++++++++++++++++++
 4 files changed, 132 insertions(+), 34 deletions(-)
 create mode 100644 llvm/test/Transforms/Reassociate/reassoc-add-nsw.ll

diff --git a/llvm/include/llvm/Transforms/Scalar/Reassociate.h b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
index f3a2e0f4380eb0..84d72df6fc4d81 100644
--- a/llvm/include/llvm/Transforms/Scalar/Reassociate.h
+++ b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
@@ -63,6 +63,16 @@ struct Factor {
   Factor(Value *Base, unsigned Power) : Base(Base), Power(Power) {}
 };
 
+struct OverflowTracking {
+  bool HasNUW;
+  bool HasNSW;
+  bool AllKnownNonNegative;
+  // Note: AllKnownNonNegative can be true in a case where one of the operands
+  // is negative, but one of the operators is not NSW. AllKnownNonNegative
+  // should not be used independently of HasNSW.
+  OverflowTracking() : HasNUW(true), HasNSW(true), AllKnownNonNegative(true) {}
+};
+
 class XorOpnd;
 
 } // end namespace reassociate
@@ -103,7 +113,7 @@ class ReassociatePass : public PassInfoMixin<ReassociatePass> {
   void ReassociateExpression(BinaryOperator *I);
   void RewriteExprTree(BinaryOperator *I,
                        SmallVectorImpl<reassociate::ValueEntry> &Ops,
-                       bool HasNUW);
+                       reassociate::OverflowTracking Flags);
   Value *OptimizeExpression(BinaryOperator *I,
                             SmallVectorImpl<reassociate::ValueEntry> &Ops);
   Value *OptimizeAdd(Instruction *I,
diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp
index d91320863e241d..c903e47a93cafd 100644
--- a/llvm/lib/Transforms/Scalar/Reassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp
@@ -471,7 +471,7 @@ using RepeatedValue = std::pair;
 static bool LinearizeExprTree(Instruction *I,
                               SmallVectorImpl<RepeatedValue> &Ops,
                               ReassociatePass::OrderedSet &ToRedo,
-                              bool &HasNUW) {
+                              reassociate::OverflowTracking &Flags) {
   assert((isa<UnaryOperator>(I) || isa<BinaryOperator>(I)) &&
          "Expected a UnaryOperator or BinaryOperator!");
   LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n');
@@ -512,6 +512,7 @@ static bool LinearizeExprTree(Instruction *I,
   using LeafMap = DenseMap;
   LeafMap Leaves; // Leaf -> Total weight so far.
   SmallVector LeafOrder; // Ensure deterministic leaf output order.
+  const DataLayout DL = I->getModule()->getDataLayout();
 
 #ifndef NDEBUG
   SmallPtrSet Visited; // For checking the iteration scheme.
@@ -520,8 +521,10 @@ static bool LinearizeExprTree(Instruction *I,
     std::pair P = Worklist.pop_back_val();
     I = P.first; // We examine the operands of this binary operator.
 
-    if (isa<OverflowingBinaryOperator>(I))
-      HasNUW &= I->hasNoUnsignedWrap();
+    if (isa<OverflowingBinaryOperator>(I)) {
+      Flags.HasNUW &= I->hasNoUnsignedWrap();
+      Flags.HasNSW &= I->hasNoSignedWrap();
+    }
 
     for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands.
       Value *Op = I->getOperand(OpIdx);
@@ -648,6 +651,8 @@ static bool LinearizeExprTree(Instruction *I,
     // Ensure the leaf is only output once.
     It->second = 0;
     Ops.push_back(std::make_pair(V, Weight));
+    if (Opcode == Instruction::Add && Flags.AllKnownNonNegative && Flags.HasNSW)
+      Flags.AllKnownNonNegative &= isKnownNonNegative(V, SimplifyQuery(DL));
   }
 
   // For nilpotent operations or addition there may be no operands, for example
@@ -666,7 +671,7 @@ static bool LinearizeExprTree(Instruction *I,
 /// linearized and optimized, emit them in-order.
 void ReassociatePass::RewriteExprTree(BinaryOperator *I,
                                       SmallVectorImpl<ValueEntry> &Ops,
-                                      bool HasNUW) {
+                                      OverflowTracking Flags) {
   assert(Ops.size() > 1 && "Single values should be used directly!");
 
   // Since our optimizations should never increase the number of operations, the
@@ -834,8 +839,12 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I,
       // Note that it doesn't hold for mul if one of the operands is zero.
       // TODO: We can preserve NUW flag if we prove that all mul operands
      // are non-zero.
-      if (HasNUW && ExpressionChangedStart->getOpcode() == Instruction::Add)
-        ExpressionChangedStart->setHasNoUnsignedWrap();
+      if (ExpressionChangedStart->getOpcode() == Instruction::Add) {
+        if (Flags.HasNUW)
+          ExpressionChangedStart->setHasNoUnsignedWrap();
+        if (Flags.HasNSW && (Flags.AllKnownNonNegative || Flags.HasNUW))
+          ExpressionChangedStart->setHasNoSignedWrap();
+      }
     }
   }
 
@@ -1192,8 +1201,8 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
     return nullptr;
 
   SmallVector Tree;
-  bool HasNUW = true;
-  MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts, HasNUW);
+  OverflowTracking Flags;
+  MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts, Flags);
   SmallVector Factors;
   Factors.reserve(Tree.size());
   for (unsigned i = 0, e = Tree.size(); i != e; ++i) {
@@ -1235,7 +1244,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
 
   if (!FoundFactor) {
     // Make sure to restore the operands to the expression tree.
-    RewriteExprTree(BO, Factors, HasNUW);
+    RewriteExprTree(BO, Factors, Flags);
     return nullptr;
   }
 
@@ -1247,7 +1256,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
     RedoInsts.insert(BO);
     V = Factors[0].Op;
   } else {
-    RewriteExprTree(BO, Factors, HasNUW);
+    RewriteExprTree(BO, Factors, Flags);
     V = BO;
   }
 
@@ -2373,8 +2382,8 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
   // First, walk the expression tree, linearizing the tree, collecting the
   // operand information.
   SmallVector Tree;
-  bool HasNUW = true;
-  MadeChange |= LinearizeExprTree(I, Tree, RedoInsts, HasNUW);
+  OverflowTracking Flags;
+  MadeChange |= LinearizeExprTree(I, Tree, RedoInsts, Flags);
   SmallVector Ops;
   Ops.reserve(Tree.size());
   for (const RepeatedValue &E : Tree)
@@ -2567,7 +2576,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
              dbgs() << '\n');
   // Now that we ordered and optimized the expressions, splat them back into
   // the expression tree, removing any unneeded nodes.
-  RewriteExprTree(I, Ops, HasNUW);
+  RewriteExprTree(I, Ops, Flags);
 }
 
 void
diff --git a/llvm/test/Transforms/Reassociate/local-cse.ll b/llvm/test/Transforms/Reassociate/local-cse.ll
index 4d0467e263f553..d0d609f022b46b 100644
--- a/llvm/test/Transforms/Reassociate/local-cse.ll
+++ b/llvm/test/Transforms/Reassociate/local-cse.ll
@@ -26,16 +26,16 @@ define void @chain_spanning_several_blocks(i64 %inv1, i64 %inv2, i64 %inv3, i64
 ; LOCAL_CSE-LABEL: define void @chain_spanning_several_blocks
 ; LOCAL_CSE-SAME: (i64 [[INV1:%.*]], i64 [[INV2:%.*]], i64 [[INV3:%.*]], i64 [[INV4:%.*]], i64 [[INV5:%.*]]) {
 ; LOCAL_CSE-NEXT:  bb1:
-; LOCAL_CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw i64 [[INV2]], [[INV1]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw nsw i64 [[INV2]], [[INV1]]
 ; LOCAL_CSE-NEXT:    br label [[BB2:%.*]]
 ; LOCAL_CSE:       bb2:
 ; LOCAL_CSE-NEXT:    [[VAL_BB2:%.*]] = call i64 @get_val()
-; LOCAL_CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV4]]
-; LOCAL_CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw i64 [[CHAIN_A1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_B1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV5]]
-; LOCAL_CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw i64 [[CHAIN_B1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_C0:%.*]] = add nuw i64 [[INV3]], [[INV1]]
-; LOCAL_CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_C0]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV4]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_B1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV5]]
+; LOCAL_CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_B1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_C0:%.*]] = add nuw nsw i64 [[INV3]], [[INV1]]
+; LOCAL_CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw nsw i64 [[CHAIN_C0]], [[VAL_BB2]]
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_A2]])
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_B2]])
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_C1]])
@@ -47,11 +47,11 @@ define void @chain_spanning_several_blocks(i64 %inv1, i64 %inv2, i64 %inv3, i64
 ; CSE-NEXT:    br label [[BB2:%.*]]
 ; CSE:       bb2:
 ; CSE-NEXT:    [[VAL_BB2:%.*]] = call i64 @get_val()
-; CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1]]
-; CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV2]]
+; CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw nsw i64 [[VAL_BB2]], [[INV1]]
+; CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV2]]
 ; CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV4]]
 ; CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV5]]
-; CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV3]]
+; CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV3]]
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_A2]])
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_B2]])
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_C1]])
@@ -90,19 +90,19 @@ define void @chain_spanning_several_blocks_no_entry_anchor() {
 ; LOCAL_CSE-NEXT:    br label [[BB1:%.*]]
 ; LOCAL_CSE:       bb1:
 ; LOCAL_CSE-NEXT:    [[INV1_BB1:%.*]] = call i64 @get_val()
-; LOCAL_CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw i64 [[INV1_BB1]], [[INV2_BB0]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw nsw i64 [[INV1_BB1]], [[INV2_BB0]]
 ; LOCAL_CSE-NEXT:    br label [[BB2:%.*]]
 ; LOCAL_CSE:       bb2:
 ; LOCAL_CSE-NEXT:    [[INV3_BB2:%.*]] = call i64 @get_val()
 ; LOCAL_CSE-NEXT:    [[INV4_BB2:%.*]] = call i64 @get_val()
 ; LOCAL_CSE-NEXT:    [[INV5_BB2:%.*]] = call i64 @get_val()
 ; LOCAL_CSE-NEXT:    [[VAL_BB2:%.*]] = call i64 @get_val()
-; LOCAL_CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV4_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw i64 [[CHAIN_A1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_B1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV5_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw i64 [[CHAIN_B1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_C0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1_BB1]]
-; LOCAL_CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_C0]], [[INV3_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV4_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_B1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV5_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_B1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_C0:%.*]] = add nuw nsw i64 [[VAL_BB2]], [[INV1_BB1]]
+; LOCAL_CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw nsw i64 [[CHAIN_C0]], [[INV3_BB2]]
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_A2]])
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_B2]])
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_C1]])
@@ -120,11 +120,11 @@ define void @chain_spanning_several_blocks_no_entry_anchor() {
 ; CSE-NEXT:    [[INV4_BB2:%.*]] = call i64 @get_val()
 ; CSE-NEXT:    [[INV5_BB2:%.*]] = call i64 @get_val()
 ; CSE-NEXT:    [[VAL_BB2:%.*]] = call i64 @get_val()
-; CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1_BB1]]
-; CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV2_BB0]]
+; CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw nsw i64 [[VAL_BB2]], [[INV1_BB1]]
+; CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV2_BB0]]
 ; CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV4_BB2]]
 ; CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV5_BB2]]
-; CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV3_BB2]]
+; CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV3_BB2]]
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_A2]])
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_B2]])
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_C1]])
diff --git a/llvm/test/Transforms/Reassociate/reassoc-add-nsw.ll b/llvm/test/Transforms/Reassociate/reassoc-add-nsw.ll
new file mode 100644
index 00000000000000..fcebc4980e6d7d
--- /dev/null
+++ b/llvm/test/Transforms/Reassociate/reassoc-add-nsw.ll
@@ -0,0 +1,79 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=reassociate -S | FileCheck %s
+define i32 @nsw_preserve_nonnegative(ptr %ptr0, ptr %ptr1, ptr %ptr2) {
+; CHECK-LABEL: define i32 @nsw_preserve_nonnegative(
+; CHECK-SAME: ptr [[PTR0:%.*]], ptr [[PTR1:%.*]], ptr [[PTR2:%.*]]) {
+; CHECK-NEXT:    [[V0:%.*]] = load i32, ptr [[PTR0]], align 4, !range [[RNG0:![0-9]+]]
+; CHECK-NEXT:    [[V1:%.*]] = load i32, ptr [[PTR1]], align 4, !range [[RNG0]]
+; CHECK-NEXT:    [[V2:%.*]] = load i32, ptr [[PTR2]], align 4, !range [[RNG0]]
+; CHECK-NEXT:    [[ADD0:%.*]] = add nsw i32 [[V1]], [[V0]]
+; CHECK-NEXT:    [[ADD1:%.*]] = add nsw i32 [[ADD0]], [[V2]]
+; CHECK-NEXT:    ret i32 [[ADD1]]
+;
+  %v0 = load i32, ptr %ptr0, !range !1
+  %v1 = load i32, ptr %ptr1, !range !1
+  %v2 = load i32, ptr %ptr2, !range !1
+  %add0 = add nsw i32 %v1, %v2
+  %add1 = add nsw i32 %add0, %v0
+  ret i32 %add1
+}
+
+define i32 @nsw_preserve_nuw_nsw(ptr %ptr0, ptr %ptr1, ptr %ptr2) {
+; CHECK-LABEL: define i32 @nsw_preserve_nuw_nsw(
+; CHECK-SAME: ptr [[PTR0:%.*]], ptr [[PTR1:%.*]], ptr [[PTR2:%.*]]) {
+; CHECK-NEXT:    [[V0:%.*]] = load i32, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[V1:%.*]] = load i32, ptr [[PTR1]], align 4
+; CHECK-NEXT:    [[V2:%.*]] = load i32, ptr [[PTR2]], align 4
+; CHECK-NEXT:    [[ADD0:%.*]] = add nuw nsw i32 [[V1]], [[V0]]
+; CHECK-NEXT:    [[ADD1:%.*]] = add nuw nsw i32 [[ADD0]], [[V2]]
+; CHECK-NEXT:    ret i32 [[ADD1]]
+;
+  %v0 = load i32, ptr %ptr0
+  %v1 = load i32, ptr %ptr1
+  %v2 = load i32, ptr %ptr2
+  %add0 = add nuw nsw i32 %v1, %v2
+  %add1 = add nuw nsw i32 %add0, %v0
+  ret i32 %add1
+}
+
+define i32 @nsw_dont_preserve_negative(ptr %ptr0, ptr %ptr1, ptr %ptr2) {
+; CHECK-LABEL: define i32 @nsw_dont_preserve_negative(
+; CHECK-SAME: ptr [[PTR0:%.*]], ptr [[PTR1:%.*]], ptr [[PTR2:%.*]]) {
+; CHECK-NEXT:    [[V0:%.*]] = load i32, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[V1:%.*]] = load i32, ptr [[PTR1]], align 4, !range [[RNG0]]
+; CHECK-NEXT:    [[V2:%.*]] = load i32, ptr [[PTR2]], align 4, !range [[RNG0]]
+; CHECK-NEXT:    [[ADD0:%.*]] = add i32 [[V1]], [[V0]]
+; CHECK-NEXT:    [[ADD1:%.*]] = add i32 [[ADD0]], [[V2]]
+; CHECK-NEXT:    ret i32 [[ADD1]]
+;
+  %v0 = load i32, ptr %ptr0
+  %v1 = load i32, ptr %ptr1, !range !1
+  %v2 = load i32, ptr %ptr2, !range !1
+  %add0 = add nsw i32 %v1, %v2
+  %add1 = add nsw i32 %add0, %v0
+  ret i32 %add1
+}
+
+define i32 @nsw_nopreserve_notallnsw(ptr %ptr0, ptr %ptr1, ptr %ptr2) {
+; CHECK-LABEL: define i32 @nsw_nopreserve_notallnsw(
+; CHECK-SAME: ptr [[PTR0:%.*]], ptr [[PTR1:%.*]], ptr [[PTR2:%.*]]) {
+; CHECK-NEXT:    [[V0:%.*]] = load i32, ptr [[PTR0]], align 4, !range [[RNG0:![0-9]+]]
+; CHECK-NEXT:    [[V1:%.*]] = load i32, ptr [[PTR1]], align 4, !range [[RNG0]]
+; CHECK-NEXT:    [[V2:%.*]] = load i32, ptr [[PTR2]], align 4, !range [[RNG0]]
+; CHECK-NEXT:    [[ADD0:%.*]] = add i32 [[V1]], [[V0]]
+; CHECK-NEXT:    [[ADD1:%.*]] = add i32 [[ADD0]], [[V2]]
+; CHECK-NEXT:    ret i32 [[ADD1]]
+;
+  %v0 = load i32, ptr %ptr0, !range !1
+  %v1 = load i32, ptr %ptr1, !range !1
+  %v2 = load i32, ptr %ptr2, !range !1
+  %add0 = add nsw i32 %v1, %v2
+  %add1 = add i32 %add0, %v0
+  ret i32 %add1
+}
+
+; Non-negative 32 bit integers
+!1 = !{i32 0, i32 2147483648}
+;.
+; CHECK: [[RNG0]] = !{i32 0, i32 -2147483648}
+;.
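The Alive2 links in the commit message remain the authoritative proofs of the transform. As an informal companion, the standalone C++ sketch below (illustration only, not part of the patch or of LLVM) exhaustively checks the nonnegative-operand argument at i8 width: whenever both adds of (a + b) + c are free of signed overflow and a, b, c are nonnegative, the reassociated a + (b + c) is too, because every partial sum of nonnegative values is bounded by the full sum.

// Exhaustive i8 sanity check of the nonnegative-operand reassociation rule.
// Not LLVM code; a hypothetical, self-contained illustration.
#include <cstdio>

int main() {
  long Counterexamples = 0;
  for (int A = 0; A <= 127; ++A)
    for (int B = 0; B <= 127; ++B)
      for (int C = 0; C <= 127; ++C) {
        // Original association: (A + B) + C, both adds NSW iff each result
        // fits in signed i8. Computation is done in 'int', so no UB here.
        bool OrigNSW = (A + B) <= 127 && (A + B + C) <= 127;
        if (!OrigNSW)
          continue;
        // Reassociated form: A + (B + C). Check the new intermediate add
        // and the final sum against the same i8 bound.
        bool NewNSW = (B + C) <= 127 && (A + (B + C)) <= 127;
        if (!NewNSW)
          ++Counterexamples;
      }
  std::printf("counterexamples: %ld\n", Counterexamples); // prints 0
  return 0;
}

This is the same subset-sum reasoning that lets RewriteExprTree call setHasNoSignedWrap() when Flags.HasNSW holds together with Flags.AllKnownNonNegative; the NUW+NSW case in the patch follows analogously from the unsigned bound.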