From ebb2429224aa67858d4b9ad51dedd53e8577452c Mon Sep 17 00:00:00 2001 From: Will Dietz Date: Tue, 13 Aug 2024 09:15:17 -0500 Subject: [PATCH] [Comb][Fold] Fix idemp n^2 and bugfix + more optimization. (#7514) * Fix n^2 behavior with canonicalizeIdempotentInputs. Add arbitrary "depth" check to bound the search. * Don't allow the flattening or idempotent canonicalizers to search into operations defined in other blocks. * Support removing duplicates even when operands come from other blocks. (or(x, y, x) -> or(x, y) regardless of their origin). * Don't walk into operations with different two-state-ness in the idempotent operand canonicalizer. * When creating new operation in idempotent canonicalizer, create it with matching two-state-ness. Add tests for functional changes above. --- lib/Dialect/Comb/CombFolds.cpp | 70 ++++++++++++++++--------- test/Dialect/Comb/canonicalization.mlir | 53 +++++++++++++++++++ 2 files changed, 97 insertions(+), 26 deletions(-) diff --git a/lib/Dialect/Comb/CombFolds.cpp b/lib/Dialect/Comb/CombFolds.cpp index 913be97a76ff..96cb562b408c 100644 --- a/lib/Dialect/Comb/CombFolds.cpp +++ b/lib/Dialect/Comb/CombFolds.cpp @@ -122,7 +122,7 @@ static inline ComplementMatcher m_Complement(const SubType &subExpr) { } /// Return true if the op will be flattened afterwards. Op will be flattend if -/// it has a single user which has a same op type. +/// it has a single user which has a same op type. User must be in same block. 
static bool shouldBeFlattened(Operation *op) { assert((isa(op) && "must be commutative operations")); @@ -130,7 +130,8 @@ static bool shouldBeFlattened(Operation *op) { auto *user = *op->getUsers().begin(); return user->getName() == op->getName() && op->getAttrOfType("twoState") == - user->getAttrOfType("twoState"); + user->getAttrOfType("twoState") && + op->getBlock() == user->getBlock(); } return false; } @@ -169,8 +170,11 @@ static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter) { Value value = *element.current++; auto *flattenOp = value.getDefiningOp(); + // If not defined by a compatible operation of the same kind and + // from the same block, keep this as-is. if (!flattenOp || flattenOp->getName() != op->getName() || - flattenOp == op || binFlag != op->hasAttrOfType("twoState")) { + flattenOp == op || binFlag != op->hasAttrOfType("twoState") || + flattenOp->getBlock() != op->getBlock()) { newOperands.push_back(value); continue; } @@ -933,34 +937,48 @@ static Value getCommonOperand(Op op) { /// Example: `and(x, y, x, z)` -> `and(x, y, z)` template static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter) { + // Depth limit to search, in operations. Chosen arbitrarily, keep small. + constexpr unsigned limit = 3; auto inputs = op.getInputs(); llvm::SmallSetVector uniqueInputs(inputs.begin(), inputs.end()); - llvm::SmallDenseSet checked; + llvm::SmallDenseSet checked; checked.insert(op); - llvm::SmallVector worklist; - for (auto input : inputs) { - if (input != op) - worklist.push_back(input); - } + struct OpWithDepth { + Op op; + unsigned depth; + }; + llvm::SmallVector worklist; + + auto enqueue = [&worklist, &checked, &op](Value input, unsigned depth) { + // Add to worklist if within depth limit, is defined in the same block by + // the same kind of operation, has same two-state-ness, and not enqueued + // previously. 
+ if (depth < limit && input.getParentBlock() == op->getBlock()) { + auto inputOp = input.template getDefiningOp(); + if (inputOp && inputOp.getTwoState() == op.getTwoState() && + checked.insert(inputOp).second) + worklist.push_back({inputOp, depth + 1}); + } + }; - while (!worklist.empty()) { - auto element = worklist.pop_back_val(); + for (auto input : uniqueInputs) + enqueue(input, 0); - if (auto idempotentOp = element.getDefiningOp()) { - for (auto input : idempotentOp.getInputs()) { - uniqueInputs.remove(input); + while (!worklist.empty()) { + auto item = worklist.pop_back_val(); - if (checked.insert(input).second) - worklist.push_back(input); - } + for (auto input : item.op.getInputs()) { + uniqueInputs.remove(input); + enqueue(input, item.depth); } } if (uniqueInputs.size() < inputs.size()) { replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - uniqueInputs.getArrayRef()); + uniqueInputs.getArrayRef(), + op.getTwoState()); return true; } @@ -968,12 +986,8 @@ static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter) { } LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) { - if (hasOperandsOutsideOfBlock(&*op)) - return failure(); - auto inputs = op.getInputs(); auto size = inputs.size(); - assert(size > 1 && "expected 2 or more operands, `fold` should handle this"); // and(x, and(...)) -> and(x, ...) -- flatten if (tryFlatteningOperands(op, rewriter)) @@ -985,6 +999,10 @@ LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) { if (size > 1 && canonicalizeIdempotentInputs(op, rewriter)) return success(); + if (hasOperandsOutsideOfBlock(&*op)) + return failure(); + assert(size > 1 && "expected 2 or more operands, `fold` should handle this"); + // Patterns for and with a constant on RHS. 
APInt value; if (matchPattern(inputs.back(), m_ConstantInt(&value))) { @@ -1255,12 +1273,8 @@ static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1, } LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) { - if (hasOperandsOutsideOfBlock(&*op)) - return failure(); - auto inputs = op.getInputs(); auto size = inputs.size(); - assert(size > 1 && "expected 2 or more operands"); // or(x, or(...)) -> or(x, ...) -- flatten if (tryFlatteningOperands(op, rewriter)) @@ -1272,6 +1286,10 @@ LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) { if (size > 1 && canonicalizeIdempotentInputs(op, rewriter)) return success(); + if (hasOperandsOutsideOfBlock(&*op)) + return failure(); + assert(size > 1 && "expected 2 or more operands"); + // Patterns for and with a constant on RHS. APInt value; if (matchPattern(inputs.back(), m_ConstantInt(&value))) { diff --git a/test/Dialect/Comb/canonicalization.mlir b/test/Dialect/Comb/canonicalization.mlir index ea4f6af92f9b..faa7e4033452 100644 --- a/test/Dialect/Comb/canonicalization.mlir +++ b/test/Dialect/Comb/canonicalization.mlir @@ -1553,6 +1553,59 @@ hw.module @OrMuxSameTrueValueAndZero(in %tag_0: i1, in %tag_1: i1, in %tag_2: i1 "terminator"(%add2) : (i32) -> () }) : () -> () +// CHECK-LABEL: "test.acrossBlockCanonicalizationBarrierFlattenAndIdem" +// CHECK: ^bb1: +// CHECK-NEXT: %[[OUT:.+]] = comb.or %0, %1, %2 : i32 +// CHECK-NEXT: "terminator"(%[[OUT]]) : (i32) -> () +"test.acrossBlockCanonicalizationBarrierFlattenAndIdem"() ({ +^bb0(%arg0: i32, %arg1: i32, %arg2: i32): + %0 = comb.or %arg0, %arg1 : i32 + %1 = comb.or %arg1, %arg2 : i32 + %2 = comb.or %arg0, %arg2 : i32 + "terminator"() : () -> () +^bb1: // no predecessors + // Flatten and unique, but not across blocks. 
+ %3 = comb.or %0, %1 : i32 + %4 = comb.or %1, %2 : i32 + %5 = comb.or %3, %4, %0, %1, %1, %2 : i32 + + "terminator"(%5) : (i32) -> () +}) : () -> () + +// CHECK-LABEL: "test.acrossBlockCanonicalizationBarrierIdem" +// CHECK: ^bb1: +// CHECK-NEXT: %[[OUT1:.+]] = comb.or %0, %1 : i32 +// CHECK-NEXT: %[[OUT2:.+]] = comb.or %[[OUT1]], %arg0 : i32 +// CHECK-NEXT: "terminator"(%[[OUT1]], %[[OUT2]]) : (i32, i32) -> () +"test.acrossBlockCanonicalizationBarrierIdem"() ({ +^bb0(%arg0: i32, %arg1: i32, %arg2: i32): + %0 = comb.or %arg0, %arg1 : i32 + %1 = comb.or %arg1, %arg2 : i32 + "terminator"() : () -> () +^bb1: // no predecessors + %2 = comb.or %0, %1, %1 : i32 + %3 = comb.or %2, %0, %1, %arg0 : i32 + + "terminator"(%2, %3) : (i32, i32) -> () +}) : () -> () + +// Check multi-operation idempotent operand deduplication. +// CHECK-LABEL: @IdemTwoState +// CHECK-NEXT: %[[ZERO:.+]] = comb.or bin %cond, %val1 +// CHECK-NEXT: %[[ONE:.+]] = comb.or bin %val1, %val2 +// Don't allow dropping these (%0/%1) due to two-state differences. +// CHECK-NEXT: %[[TWO:.+]] = comb.or %[[ZERO]], %[[ONE]] +// New operation should preserve two-state-ness. +// CHECK-NEXT: %[[THREE:.+]] = comb.or bin %[[TWO]], %[[ZERO]], %[[ONE]] +// CHECK-NEXT: hw.output %[[ZERO]], %[[ONE]], %[[TWO]], %[[THREE]] +hw.module @IdemTwoState(in %cond: i32, in %val1: i32, in %val2: i32, out o1: i32, out o2: i32, out o3: i32, out o4: i32) { + %0 = comb.or bin %cond, %val1 : i32 + %1 = comb.or bin %val1, %val2: i32 + %2 = comb.or %0, %1 : i32 + %3 = comb.or bin %cond, %val1, %val2, %2, %0, %1 : i32 + hw.output %0, %1, %2, %3: i32, i32, i32, i32 +} + // CHECK-LABEL: hw.module @combineOppositeBinCmpIntoConstant // CHECK: %[[TRUE:.+]] = hw.constant true // CHECK: %[[FALSE:.+]] = hw.constant false