Skip to content

Commit

Permalink
[Reassociate] Preserve NSW flags after expr tree rewriting (#93105)
Browse files Browse the repository at this point in the history
We can guarantee NSW on all operands in a reassociated add expression
tree when:

- All adds in an add operator tree are NSW, AND either 
  - All add operands are guaranteed to be nonnegative, OR 
  - All adds are also NUW

- Alive2: 
- Nonnegative Operands
	- 3 operands: https://alive2.llvm.org/ce/z/G4XW6Q
	- 4 operands: https://alive2.llvm.org/ce/z/FWcZ6D
 - NUW NSW adds: https://alive2.llvm.org/ce/z/vRUxeC

---------

Co-authored-by: Nikita Popov <github@npopov.com>
  • Loading branch information
akshayrdeodhar and nikic authored May 28, 2024
1 parent f089996 commit 73e22ff
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 34 deletions.
12 changes: 11 additions & 1 deletion llvm/include/llvm/Transforms/Scalar/Reassociate.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,16 @@ struct Factor {
Factor(Value *Base, unsigned Power) : Base(Base), Power(Power) {}
};

struct OverflowTracking {
bool HasNUW;
bool HasNSW;
bool AllKnownNonNegative;
// Note: AllKnownNonNegative can be true in a case where one of the operands
// is negative, but one the operators is not NSW. AllKnownNonNegative should
// not be used independently of HasNSW
OverflowTracking() : HasNUW(true), HasNSW(true), AllKnownNonNegative(true) {}
};

class XorOpnd;

} // end namespace reassociate
Expand Down Expand Up @@ -103,7 +113,7 @@ class ReassociatePass : public PassInfoMixin<ReassociatePass> {
void ReassociateExpression(BinaryOperator *I);
void RewriteExprTree(BinaryOperator *I,
SmallVectorImpl<reassociate::ValueEntry> &Ops,
bool HasNUW);
reassociate::OverflowTracking Flags);
Value *OptimizeExpression(BinaryOperator *I,
SmallVectorImpl<reassociate::ValueEntry> &Ops);
Value *OptimizeAdd(Instruction *I,
Expand Down
35 changes: 22 additions & 13 deletions llvm/lib/Transforms/Scalar/Reassociate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ using RepeatedValue = std::pair<Value*, APInt>;
static bool LinearizeExprTree(Instruction *I,
SmallVectorImpl<RepeatedValue> &Ops,
ReassociatePass::OrderedSet &ToRedo,
bool &HasNUW) {
reassociate::OverflowTracking &Flags) {
assert((isa<UnaryOperator>(I) || isa<BinaryOperator>(I)) &&
"Expected a UnaryOperator or BinaryOperator!");
LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n');
Expand Down Expand Up @@ -512,6 +512,7 @@ static bool LinearizeExprTree(Instruction *I,
using LeafMap = DenseMap<Value *, APInt>;
LeafMap Leaves; // Leaf -> Total weight so far.
SmallVector<Value *, 8> LeafOrder; // Ensure deterministic leaf output order.
const DataLayout DL = I->getModule()->getDataLayout();

#ifndef NDEBUG
SmallPtrSet<Value *, 8> Visited; // For checking the iteration scheme.
Expand All @@ -520,8 +521,10 @@ static bool LinearizeExprTree(Instruction *I,
std::pair<Instruction*, APInt> P = Worklist.pop_back_val();
I = P.first; // We examine the operands of this binary operator.

if (isa<OverflowingBinaryOperator>(I))
HasNUW &= I->hasNoUnsignedWrap();
if (isa<OverflowingBinaryOperator>(I)) {
Flags.HasNUW &= I->hasNoUnsignedWrap();
Flags.HasNSW &= I->hasNoSignedWrap();
}

for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands.
Value *Op = I->getOperand(OpIdx);
Expand Down Expand Up @@ -648,6 +651,8 @@ static bool LinearizeExprTree(Instruction *I,
// Ensure the leaf is only output once.
It->second = 0;
Ops.push_back(std::make_pair(V, Weight));
if (Opcode == Instruction::Add && Flags.AllKnownNonNegative && Flags.HasNSW)
Flags.AllKnownNonNegative &= isKnownNonNegative(V, SimplifyQuery(DL));
}

// For nilpotent operations or addition there may be no operands, for example
Expand All @@ -666,7 +671,7 @@ static bool LinearizeExprTree(Instruction *I,
/// linearized and optimized, emit them in-order.
void ReassociatePass::RewriteExprTree(BinaryOperator *I,
SmallVectorImpl<ValueEntry> &Ops,
bool HasNUW) {
OverflowTracking Flags) {
assert(Ops.size() > 1 && "Single values should be used directly!");

// Since our optimizations should never increase the number of operations, the
Expand Down Expand Up @@ -834,8 +839,12 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I,
// Note that it doesn't hold for mul if one of the operands is zero.
// TODO: We can preserve NUW flag if we prove that all mul operands
// are non-zero.
if (HasNUW && ExpressionChangedStart->getOpcode() == Instruction::Add)
ExpressionChangedStart->setHasNoUnsignedWrap();
if (ExpressionChangedStart->getOpcode() == Instruction::Add) {
if (Flags.HasNUW)
ExpressionChangedStart->setHasNoUnsignedWrap();
if (Flags.HasNSW && (Flags.AllKnownNonNegative || Flags.HasNUW))
ExpressionChangedStart->setHasNoSignedWrap();
}
}
}

Expand Down Expand Up @@ -1192,8 +1201,8 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
return nullptr;

SmallVector<RepeatedValue, 8> Tree;
bool HasNUW = true;
MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts, HasNUW);
OverflowTracking Flags;
MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts, Flags);
SmallVector<ValueEntry, 8> Factors;
Factors.reserve(Tree.size());
for (unsigned i = 0, e = Tree.size(); i != e; ++i) {
Expand Down Expand Up @@ -1235,7 +1244,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {

if (!FoundFactor) {
// Make sure to restore the operands to the expression tree.
RewriteExprTree(BO, Factors, HasNUW);
RewriteExprTree(BO, Factors, Flags);
return nullptr;
}

Expand All @@ -1247,7 +1256,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
RedoInsts.insert(BO);
V = Factors[0].Op;
} else {
RewriteExprTree(BO, Factors, HasNUW);
RewriteExprTree(BO, Factors, Flags);
V = BO;
}

Expand Down Expand Up @@ -2373,8 +2382,8 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
// First, walk the expression tree, linearizing the tree, collecting the
// operand information.
SmallVector<RepeatedValue, 8> Tree;
bool HasNUW = true;
MadeChange |= LinearizeExprTree(I, Tree, RedoInsts, HasNUW);
OverflowTracking Flags;
MadeChange |= LinearizeExprTree(I, Tree, RedoInsts, Flags);
SmallVector<ValueEntry, 8> Ops;
Ops.reserve(Tree.size());
for (const RepeatedValue &E : Tree)
Expand Down Expand Up @@ -2567,7 +2576,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
dbgs() << '\n');
// Now that we ordered and optimized the expressions, splat them back into
// the expression tree, removing any unneeded nodes.
RewriteExprTree(I, Ops, HasNUW);
RewriteExprTree(I, Ops, Flags);
}

void
Expand Down
40 changes: 20 additions & 20 deletions llvm/test/Transforms/Reassociate/local-cse.ll
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@ define void @chain_spanning_several_blocks(i64 %inv1, i64 %inv2, i64 %inv3, i64
; LOCAL_CSE-LABEL: define void @chain_spanning_several_blocks
; LOCAL_CSE-SAME: (i64 [[INV1:%.*]], i64 [[INV2:%.*]], i64 [[INV3:%.*]], i64 [[INV4:%.*]], i64 [[INV5:%.*]]) {
; LOCAL_CSE-NEXT: bb1:
; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[INV2]], [[INV1]]
; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw nsw i64 [[INV2]], [[INV1]]
; LOCAL_CSE-NEXT: br label [[BB2:%.*]]
; LOCAL_CSE: bb2:
; LOCAL_CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val()
; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV4]]
; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw i64 [[CHAIN_A1]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV5]]
; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw i64 [[CHAIN_B1]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add nuw i64 [[INV3]], [[INV1]]
; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_C0]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV4]]
; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV5]]
; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_B1]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add nuw nsw i64 [[INV3]], [[INV1]]
; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw nsw i64 [[CHAIN_C0]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]])
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]])
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]])
Expand All @@ -47,11 +47,11 @@ define void @chain_spanning_several_blocks(i64 %inv1, i64 %inv2, i64 %inv3, i64
; CSE-NEXT: br label [[BB2:%.*]]
; CSE: bb2:
; CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val()
; CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1]]
; CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV2]]
; CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw nsw i64 [[VAL_BB2]], [[INV1]]
; CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV2]]
; CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV4]]
; CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV5]]
; CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV3]]
; CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV3]]
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]])
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]])
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]])
Expand Down Expand Up @@ -90,19 +90,19 @@ define void @chain_spanning_several_blocks_no_entry_anchor() {
; LOCAL_CSE-NEXT: br label [[BB1:%.*]]
; LOCAL_CSE: bb1:
; LOCAL_CSE-NEXT: [[INV1_BB1:%.*]] = call i64 @get_val()
; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[INV1_BB1]], [[INV2_BB0]]
; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw nsw i64 [[INV1_BB1]], [[INV2_BB0]]
; LOCAL_CSE-NEXT: br label [[BB2:%.*]]
; LOCAL_CSE: bb2:
; LOCAL_CSE-NEXT: [[INV3_BB2:%.*]] = call i64 @get_val()
; LOCAL_CSE-NEXT: [[INV4_BB2:%.*]] = call i64 @get_val()
; LOCAL_CSE-NEXT: [[INV5_BB2:%.*]] = call i64 @get_val()
; LOCAL_CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val()
; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV4_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw i64 [[CHAIN_A1]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV5_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw i64 [[CHAIN_B1]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1_BB1]]
; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_C0]], [[INV3_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV4_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV5_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_B1]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add nuw nsw i64 [[VAL_BB2]], [[INV1_BB1]]
; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw nsw i64 [[CHAIN_C0]], [[INV3_BB2]]
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]])
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]])
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]])
Expand All @@ -120,11 +120,11 @@ define void @chain_spanning_several_blocks_no_entry_anchor() {
; CSE-NEXT: [[INV4_BB2:%.*]] = call i64 @get_val()
; CSE-NEXT: [[INV5_BB2:%.*]] = call i64 @get_val()
; CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val()
; CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1_BB1]]
; CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV2_BB0]]
; CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw nsw i64 [[VAL_BB2]], [[INV1_BB1]]
; CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV2_BB0]]
; CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV4_BB2]]
; CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV5_BB2]]
; CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV3_BB2]]
; CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV3_BB2]]
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]])
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]])
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]])
Expand Down
79 changes: 79 additions & 0 deletions llvm/test/Transforms/Reassociate/reassoc-add-nsw.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
; RUN: opt < %s -passes=reassociate -S | FileCheck %s
define i32 @nsw_preserve_nonnegative(ptr %ptr0, ptr %ptr1, ptr %ptr2) {
; CHECK-LABEL: define i32 @nsw_preserve_nonnegative(
; CHECK-SAME: ptr [[PTR0:%.*]], ptr [[PTR1:%.*]], ptr [[PTR2:%.*]]) {
; CHECK-NEXT: [[V0:%.*]] = load i32, ptr [[PTR0]], align 4, !range [[RNG0:![0-9]+]]
; CHECK-NEXT: [[V1:%.*]] = load i32, ptr [[PTR1]], align 4, !range [[RNG0]]
; CHECK-NEXT: [[V2:%.*]] = load i32, ptr [[PTR2]], align 4, !range [[RNG0]]
; CHECK-NEXT: [[ADD0:%.*]] = add nsw i32 [[V1]], [[V0]]
; CHECK-NEXT: [[ADD1:%.*]] = add nsw i32 [[ADD0]], [[V2]]
; CHECK-NEXT: ret i32 [[ADD1]]
;
%v0 = load i32, ptr %ptr0, !range !1
%v1 = load i32, ptr %ptr1, !range !1
%v2 = load i32, ptr %ptr2, !range !1
%add0 = add nsw i32 %v1, %v2
%add1 = add nsw i32 %add0, %v0
ret i32 %add1
}

define i32 @nsw_preserve_nuw_nsw(ptr %ptr0, ptr %ptr1, ptr %ptr2) {
; CHECK-LABEL: define i32 @nsw_preserve_nuw_nsw(
; CHECK-SAME: ptr [[PTR0:%.*]], ptr [[PTR1:%.*]], ptr [[PTR2:%.*]]) {
; CHECK-NEXT: [[V0:%.*]] = load i32, ptr [[PTR0]], align 4
; CHECK-NEXT: [[V1:%.*]] = load i32, ptr [[PTR1]], align 4
; CHECK-NEXT: [[V2:%.*]] = load i32, ptr [[PTR2]], align 4
; CHECK-NEXT: [[ADD0:%.*]] = add nuw nsw i32 [[V1]], [[V0]]
; CHECK-NEXT: [[ADD1:%.*]] = add nuw nsw i32 [[ADD0]], [[V2]]
; CHECK-NEXT: ret i32 [[ADD1]]
;
%v0 = load i32, ptr %ptr0
%v1 = load i32, ptr %ptr1
%v2 = load i32, ptr %ptr2
%add0 = add nuw nsw i32 %v1, %v2
%add1 = add nuw nsw i32 %add0, %v0
ret i32 %add1
}

define i32 @nsw_dont_preserve_negative(ptr %ptr0, ptr %ptr1, ptr %ptr2) {
; CHECK-LABEL: define i32 @nsw_dont_preserve_negative(
; CHECK-SAME: ptr [[PTR0:%.*]], ptr [[PTR1:%.*]], ptr [[PTR2:%.*]]) {
; CHECK-NEXT: [[V0:%.*]] = load i32, ptr [[PTR0]], align 4
; CHECK-NEXT: [[V1:%.*]] = load i32, ptr [[PTR1]], align 4, !range [[RNG0]]
; CHECK-NEXT: [[V2:%.*]] = load i32, ptr [[PTR2]], align 4, !range [[RNG0]]
; CHECK-NEXT: [[ADD0:%.*]] = add i32 [[V1]], [[V0]]
; CHECK-NEXT: [[ADD1:%.*]] = add i32 [[ADD0]], [[V2]]
; CHECK-NEXT: ret i32 [[ADD1]]
;
%v0 = load i32, ptr %ptr0
%v1 = load i32, ptr %ptr1, !range !1
%v2 = load i32, ptr %ptr2, !range !1
%add0 = add nsw i32 %v1, %v2
%add1 = add nsw i32 %add0, %v0
ret i32 %add1
}

define i32 @nsw_nopreserve_notallnsw(ptr %ptr0, ptr %ptr1, ptr %ptr2) {
; CHECK-LABEL: define i32 @nsw_nopreserve_notallnsw(
; CHECK-SAME: ptr [[PTR0:%.*]], ptr [[PTR1:%.*]], ptr [[PTR2:%.*]]) {
; CHECK-NEXT: [[V0:%.*]] = load i32, ptr [[PTR0]], align 4, !range [[RNG0:![0-9]+]]
; CHECK-NEXT: [[V1:%.*]] = load i32, ptr [[PTR1]], align 4, !range [[RNG0]]
; CHECK-NEXT: [[V2:%.*]] = load i32, ptr [[PTR2]], align 4, !range [[RNG0]]
; CHECK-NEXT: [[ADD0:%.*]] = add i32 [[V1]], [[V0]]
; CHECK-NEXT: [[ADD1:%.*]] = add i32 [[ADD0]], [[V2]]
; CHECK-NEXT: ret i32 [[ADD1]]
;
%v0 = load i32, ptr %ptr0, !range !1
%v1 = load i32, ptr %ptr1, !range !1
%v2 = load i32, ptr %ptr2, !range !1
%add0 = add nsw i32 %v1, %v2
%add1 = add i32 %add0, %v0
ret i32 %add1
}

; Positive 32 bit integers
!1 = !{i32 0, i32 2147483648}
;.
; CHECK: [[RNG0]] = !{i32 0, i32 -2147483648}
;.

0 comments on commit 73e22ff

Please sign in to comment.