Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Reassociate] Preserve NSW flags after expr tree rewriting #93105

Merged
merged 9 commits into from
May 28, 2024
13 changes: 12 additions & 1 deletion llvm/include/llvm/Transforms/Scalar/Reassociate.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ struct Factor {
Factor(Value *Base, unsigned Power) : Base(Base), Power(Power) {}
};

/// Overflow facts gathered while linearizing an expression tree, consulted
/// when the tree is rewritten to decide which wrap flags may be re-applied.
/// Every field starts out optimistic (true) and is only ever cleared.
struct OverflowTracking {
  /// Cleared once a visited overflowing operator lacks the `nuw` flag.
  bool HasNUW = true;
  /// Cleared once a visited overflowing operator lacks the `nsw` flag.
  bool HasNSW = true;
  /// Cleared once a leaf of an add tree is not known to be non-negative.
  ///
  /// Note: AllKnownNonNegative can still be true when one of the operands is
  /// negative, if one of the operators is not NSW: leaves are only inspected
  /// while HasNSW holds. AllKnownNonNegative must therefore never be used
  /// independently of HasNSW.
  bool AllKnownNonNegative = true;

  OverflowTracking() = default;
};

class XorOpnd;

} // end namespace reassociate
Expand Down Expand Up @@ -103,7 +114,7 @@ class ReassociatePass : public PassInfoMixin<ReassociatePass> {
void ReassociateExpression(BinaryOperator *I);
void RewriteExprTree(BinaryOperator *I,
SmallVectorImpl<reassociate::ValueEntry> &Ops,
bool HasNUW);
reassociate::OverflowTracking Flags);
Value *OptimizeExpression(BinaryOperator *I,
SmallVectorImpl<reassociate::ValueEntry> &Ops);
Value *OptimizeAdd(Instruction *I,
Expand Down
37 changes: 24 additions & 13 deletions llvm/lib/Transforms/Scalar/Reassociate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ using RepeatedValue = std::pair<Value*, APInt>;
static bool LinearizeExprTree(Instruction *I,
SmallVectorImpl<RepeatedValue> &Ops,
ReassociatePass::OrderedSet &ToRedo,
bool &HasNUW) {
reassociate::OverflowTracking &Flags) {
assert((isa<UnaryOperator>(I) || isa<BinaryOperator>(I)) &&
"Expected a UnaryOperator or BinaryOperator!");
LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n');
Expand Down Expand Up @@ -512,6 +512,7 @@ static bool LinearizeExprTree(Instruction *I,
using LeafMap = DenseMap<Value *, APInt>;
LeafMap Leaves; // Leaf -> Total weight so far.
SmallVector<Value *, 8> LeafOrder; // Ensure deterministic leaf output order.
const DataLayout DL = I->getModule()->getDataLayout();

#ifndef NDEBUG
SmallPtrSet<Value *, 8> Visited; // For checking the iteration scheme.
Expand All @@ -520,8 +521,10 @@ static bool LinearizeExprTree(Instruction *I,
std::pair<Instruction*, APInt> P = Worklist.pop_back_val();
I = P.first; // We examine the operands of this binary operator.

if (isa<OverflowingBinaryOperator>(I))
HasNUW &= I->hasNoUnsignedWrap();
if (isa<OverflowingBinaryOperator>(I)) {
Flags.HasNUW &= I->hasNoUnsignedWrap();
Flags.HasNSW &= I->hasNoSignedWrap();
}

for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands.
Value *Op = I->getOperand(OpIdx);
Expand Down Expand Up @@ -648,6 +651,10 @@ static bool LinearizeExprTree(Instruction *I,
// Ensure the leaf is only output once.
It->second = 0;
Ops.push_back(std::make_pair(V, Weight));
if (Opcode == Instruction::Add && Flags.AllKnownNonNegative &&
Flags.HasNSW) {
dtcxzyw marked this conversation as resolved.
Show resolved Hide resolved
Flags.AllKnownNonNegative &= isKnownNonNegative(V, SimplifyQuery(DL));
}
}

// For nilpotent operations or addition there may be no operands, for example
Expand All @@ -666,7 +673,7 @@ static bool LinearizeExprTree(Instruction *I,
/// linearized and optimized, emit them in-order.
void ReassociatePass::RewriteExprTree(BinaryOperator *I,
SmallVectorImpl<ValueEntry> &Ops,
bool HasNUW) {
OverflowTracking Flags) {
assert(Ops.size() > 1 && "Single values should be used directly!");

// Since our optimizations should never increase the number of operations, the
Expand Down Expand Up @@ -834,8 +841,12 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I,
// Note that it doesn't hold for mul if one of the operands is zero.
// TODO: We can preserve NUW flag if we prove that all mul operands
// are non-zero.
if (HasNUW && ExpressionChangedStart->getOpcode() == Instruction::Add)
ExpressionChangedStart->setHasNoUnsignedWrap();
if (ExpressionChangedStart->getOpcode() == Instruction::Add) {
if (Flags.HasNUW)
ExpressionChangedStart->setHasNoUnsignedWrap();
if (Flags.HasNSW && (Flags.AllKnownNonNegative || Flags.HasNUW))
ExpressionChangedStart->setHasNoSignedWrap();
}
}
}

Expand Down Expand Up @@ -1192,8 +1203,8 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
return nullptr;

SmallVector<RepeatedValue, 8> Tree;
bool HasNUW = true;
MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts, HasNUW);
OverflowTracking Flags;
MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts, Flags);
SmallVector<ValueEntry, 8> Factors;
Factors.reserve(Tree.size());
for (unsigned i = 0, e = Tree.size(); i != e; ++i) {
Expand Down Expand Up @@ -1235,7 +1246,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {

if (!FoundFactor) {
// Make sure to restore the operands to the expression tree.
RewriteExprTree(BO, Factors, HasNUW);
RewriteExprTree(BO, Factors, Flags);
return nullptr;
}

Expand All @@ -1247,7 +1258,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
RedoInsts.insert(BO);
V = Factors[0].Op;
} else {
RewriteExprTree(BO, Factors, HasNUW);
RewriteExprTree(BO, Factors, Flags);
V = BO;
}

Expand Down Expand Up @@ -2373,8 +2384,8 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
// First, walk the expression tree, linearizing the tree, collecting the
// operand information.
SmallVector<RepeatedValue, 8> Tree;
bool HasNUW = true;
MadeChange |= LinearizeExprTree(I, Tree, RedoInsts, HasNUW);
OverflowTracking Flags;
MadeChange |= LinearizeExprTree(I, Tree, RedoInsts, Flags);
SmallVector<ValueEntry, 8> Ops;
Ops.reserve(Tree.size());
for (const RepeatedValue &E : Tree)
Expand Down Expand Up @@ -2567,7 +2578,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
dbgs() << '\n');
// Now that we ordered and optimized the expressions, splat them back into
// the expression tree, removing any unneeded nodes.
RewriteExprTree(I, Ops, HasNUW);
RewriteExprTree(I, Ops, Flags);
}

void
Expand Down
40 changes: 20 additions & 20 deletions llvm/test/Transforms/Reassociate/local-cse.ll
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@ define void @chain_spanning_several_blocks(i64 %inv1, i64 %inv2, i64 %inv3, i64
; LOCAL_CSE-LABEL: define void @chain_spanning_several_blocks
; LOCAL_CSE-SAME: (i64 [[INV1:%.*]], i64 [[INV2:%.*]], i64 [[INV3:%.*]], i64 [[INV4:%.*]], i64 [[INV5:%.*]]) {
; LOCAL_CSE-NEXT: bb1:
; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[INV2]], [[INV1]]
; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw nsw i64 [[INV2]], [[INV1]]
; LOCAL_CSE-NEXT: br label [[BB2:%.*]]
; LOCAL_CSE: bb2:
; LOCAL_CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val()
; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV4]]
; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw i64 [[CHAIN_A1]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV5]]
; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw i64 [[CHAIN_B1]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add nuw i64 [[INV3]], [[INV1]]
; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_C0]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV4]]
; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV5]]
; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_B1]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add nuw nsw i64 [[INV3]], [[INV1]]
; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw nsw i64 [[CHAIN_C0]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]])
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]])
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]])
Expand All @@ -47,11 +47,11 @@ define void @chain_spanning_several_blocks(i64 %inv1, i64 %inv2, i64 %inv3, i64
; CSE-NEXT: br label [[BB2:%.*]]
; CSE: bb2:
; CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val()
; CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1]]
; CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV2]]
; CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw nsw i64 [[VAL_BB2]], [[INV1]]
; CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV2]]
; CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV4]]
; CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV5]]
; CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV3]]
; CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV3]]
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]])
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]])
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]])
Expand Down Expand Up @@ -90,19 +90,19 @@ define void @chain_spanning_several_blocks_no_entry_anchor() {
; LOCAL_CSE-NEXT: br label [[BB1:%.*]]
; LOCAL_CSE: bb1:
; LOCAL_CSE-NEXT: [[INV1_BB1:%.*]] = call i64 @get_val()
; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[INV1_BB1]], [[INV2_BB0]]
; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw nsw i64 [[INV1_BB1]], [[INV2_BB0]]
; LOCAL_CSE-NEXT: br label [[BB2:%.*]]
; LOCAL_CSE: bb2:
; LOCAL_CSE-NEXT: [[INV3_BB2:%.*]] = call i64 @get_val()
; LOCAL_CSE-NEXT: [[INV4_BB2:%.*]] = call i64 @get_val()
; LOCAL_CSE-NEXT: [[INV5_BB2:%.*]] = call i64 @get_val()
; LOCAL_CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val()
; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV4_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw i64 [[CHAIN_A1]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV5_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw i64 [[CHAIN_B1]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1_BB1]]
; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_C0]], [[INV3_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV4_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV5_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_B1]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add nuw nsw i64 [[VAL_BB2]], [[INV1_BB1]]
; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw nsw i64 [[CHAIN_C0]], [[INV3_BB2]]
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]])
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]])
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]])
Expand All @@ -120,11 +120,11 @@ define void @chain_spanning_several_blocks_no_entry_anchor() {
; CSE-NEXT: [[INV4_BB2:%.*]] = call i64 @get_val()
; CSE-NEXT: [[INV5_BB2:%.*]] = call i64 @get_val()
; CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val()
; CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1_BB1]]
; CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV2_BB0]]
; CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw nsw i64 [[VAL_BB2]], [[INV1_BB1]]
; CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV2_BB0]]
; CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV4_BB2]]
; CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV5_BB2]]
; CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV3_BB2]]
; CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV3_BB2]]
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]])
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]])
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]])
Expand Down
79 changes: 79 additions & 0 deletions llvm/test/Transforms/Reassociate/reassoc-add-nsw.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
; RUN: opt < %s -passes=reassociate -S | FileCheck %s
; Every operand is loaded with !range !1 (so it is known non-negative) and
; both adds carry nsw. The expected output reorders the operands and keeps
; nsw on both rewritten adds.
define i32 @nsw_preserve_nonnegative(ptr %ptr0, ptr %ptr1, ptr %ptr2) {
; CHECK-LABEL: define i32 @nsw_preserve_nonnegative(
; CHECK-SAME: ptr [[PTR0:%.*]], ptr [[PTR1:%.*]], ptr [[PTR2:%.*]]) {
; CHECK-NEXT: [[V0:%.*]] = load i32, ptr [[PTR0]], align 4, !range [[RNG0:![0-9]+]]
; CHECK-NEXT: [[V1:%.*]] = load i32, ptr [[PTR1]], align 4, !range [[RNG0]]
; CHECK-NEXT: [[V2:%.*]] = load i32, ptr [[PTR2]], align 4, !range [[RNG0]]
; CHECK-NEXT: [[ADD0:%.*]] = add nsw i32 [[V1]], [[V0]]
; CHECK-NEXT: [[ADD1:%.*]] = add nsw i32 [[ADD0]], [[V2]]
; CHECK-NEXT: ret i32 [[ADD1]]
;
%v0 = load i32, ptr %ptr0, !range !1
%v1 = load i32, ptr %ptr1, !range !1
%v2 = load i32, ptr %ptr2, !range !1
%add0 = add nsw i32 %v1, %v2
%add1 = add nsw i32 %add0, %v0
ret i32 %add1
}

; No range information is attached to the loads, but both adds carry nuw and
; nsw together. The expected output keeps both flags on the rewritten adds.
define i32 @nsw_preserve_nuw_nsw(ptr %ptr0, ptr %ptr1, ptr %ptr2) {
; CHECK-LABEL: define i32 @nsw_preserve_nuw_nsw(
; CHECK-SAME: ptr [[PTR0:%.*]], ptr [[PTR1:%.*]], ptr [[PTR2:%.*]]) {
; CHECK-NEXT: [[V0:%.*]] = load i32, ptr [[PTR0]], align 4
; CHECK-NEXT: [[V1:%.*]] = load i32, ptr [[PTR1]], align 4
; CHECK-NEXT: [[V2:%.*]] = load i32, ptr [[PTR2]], align 4
; CHECK-NEXT: [[ADD0:%.*]] = add nuw nsw i32 [[V1]], [[V0]]
; CHECK-NEXT: [[ADD1:%.*]] = add nuw nsw i32 [[ADD0]], [[V2]]
; CHECK-NEXT: ret i32 [[ADD1]]
;
%v0 = load i32, ptr %ptr0
%v1 = load i32, ptr %ptr1
%v2 = load i32, ptr %ptr2
%add0 = add nuw nsw i32 %v1, %v2
%add1 = add nuw nsw i32 %add0, %v0
ret i32 %add1
}

; %v0 is loaded without !range, so it is not known to be non-negative, and the
; adds carry only nsw (no nuw). The expected output has no wrap flags on the
; rewritten adds.
define i32 @nsw_dont_preserve_negative(ptr %ptr0, ptr %ptr1, ptr %ptr2) {
; CHECK-LABEL: define i32 @nsw_dont_preserve_negative(
; CHECK-SAME: ptr [[PTR0:%.*]], ptr [[PTR1:%.*]], ptr [[PTR2:%.*]]) {
; CHECK-NEXT: [[V0:%.*]] = load i32, ptr [[PTR0]], align 4
; CHECK-NEXT: [[V1:%.*]] = load i32, ptr [[PTR1]], align 4, !range [[RNG0]]
; CHECK-NEXT: [[V2:%.*]] = load i32, ptr [[PTR2]], align 4, !range [[RNG0]]
; CHECK-NEXT: [[ADD0:%.*]] = add i32 [[V1]], [[V0]]
; CHECK-NEXT: [[ADD1:%.*]] = add i32 [[ADD0]], [[V2]]
; CHECK-NEXT: ret i32 [[ADD1]]
;
%v0 = load i32, ptr %ptr0
%v1 = load i32, ptr %ptr1, !range !1
%v2 = load i32, ptr %ptr2, !range !1
%add0 = add nsw i32 %v1, %v2
%add1 = add nsw i32 %add0, %v0
ret i32 %add1
}
nikic marked this conversation as resolved.
Show resolved Hide resolved

; Every operand is loaded with !range !1 (known non-negative), but the outer
; add lacks nsw, so not all operators in the tree are nsw. The expected output
; has no wrap flags on the rewritten adds.
define i32 @nsw_nopreserve_notallnsw(ptr %ptr0, ptr %ptr1, ptr %ptr2) {
; CHECK-LABEL: define i32 @nsw_nopreserve_notallnsw(
; CHECK-SAME: ptr [[PTR0:%.*]], ptr [[PTR1:%.*]], ptr [[PTR2:%.*]]) {
; CHECK-NEXT: [[V0:%.*]] = load i32, ptr [[PTR0]], align 4, !range [[RNG0:![0-9]+]]
; CHECK-NEXT: [[V1:%.*]] = load i32, ptr [[PTR1]], align 4, !range [[RNG0]]
; CHECK-NEXT: [[V2:%.*]] = load i32, ptr [[PTR2]], align 4, !range [[RNG0]]
; CHECK-NEXT: [[ADD0:%.*]] = add i32 [[V1]], [[V0]]
; CHECK-NEXT: [[ADD1:%.*]] = add i32 [[ADD0]], [[V2]]
; CHECK-NEXT: ret i32 [[ADD1]]
;
%v0 = load i32, ptr %ptr0, !range !1
%v1 = load i32, ptr %ptr1, !range !1
%v2 = load i32, ptr %ptr2, !range !1
%add0 = add nsw i32 %v1, %v2
%add1 = add i32 %add0, %v0
ret i32 %add1
}

; Non-negative 32-bit integers: the half-open range [0, 2^31)
!1 = !{i32 0, i32 2147483648}
;.
; CHECK: [[RNG0]] = !{i32 0, i32 -2147483648}
;.
Loading