Skip to content

Commit

Permalink
[Comb][Fold] Fix idemp n^2 and bugfix + more optimization. (#7514)
Browse files Browse the repository at this point in the history
* Fix n^2 behavior with canonicalizeIdempotentInputs.
  Add arbitrary "depth" check to bound the search.
* Don't allow flattening or idempotent to search into operations
  defined in other blocks.
* Support removing duplicates even when operands come from other
  blocks. (or(x, y, x) -> or(x, x) regardless of their origin).
* Don't walk into operations with different two-state-ness in the
  idempotent operand canonicalizer.
* When creating new operation in idempotent canonicalizer,
  create it with matching two-state-ness.

Add tests for functional changes above.
  • Loading branch information
dtzSiFive authored Aug 13, 2024
1 parent ff40aab commit ebb2429
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 26 deletions.
70 changes: 44 additions & 26 deletions lib/Dialect/Comb/CombFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,16 @@ static inline ComplementMatcher<SubType> m_Complement(const SubType &subExpr) {
}

/// Return true if the op will be flattened afterwards. Op will be flattened if
/// it has a single user which has a same op type.
/// it has a single user which has a same op type. User must be in same block.
static bool shouldBeFlattened(Operation *op) {
assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
"must be commutative operations"));
if (op->hasOneUse()) {
auto *user = *op->getUsers().begin();
return user->getName() == op->getName() &&
op->getAttrOfType<UnitAttr>("twoState") ==
user->getAttrOfType<UnitAttr>("twoState");
user->getAttrOfType<UnitAttr>("twoState") &&
op->getBlock() == user->getBlock();
}
return false;
}
Expand Down Expand Up @@ -169,8 +170,11 @@ static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter) {

Value value = *element.current++;
auto *flattenOp = value.getDefiningOp();
// If not defined by a compatible operation of the same kind and
// from the same block, keep this as-is.
if (!flattenOp || flattenOp->getName() != op->getName() ||
flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>("twoState")) {
flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>("twoState") ||
flattenOp->getBlock() != op->getBlock()) {
newOperands.push_back(value);
continue;
}
Expand Down Expand Up @@ -933,47 +937,57 @@ static Value getCommonOperand(Op op) {
/// Example: `and(x, y, x, z)` -> `and(x, y, z)`
template <typename Op>
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter) {
// Depth limit to search, in operations. Chosen arbitrarily, keep small.
constexpr unsigned limit = 3;
auto inputs = op.getInputs();

llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
llvm::SmallDenseSet<Value, 8> checked;
llvm::SmallDenseSet<Op, 8> checked;
checked.insert(op);

llvm::SmallVector<Value, 8> worklist;
for (auto input : inputs) {
if (input != op)
worklist.push_back(input);
}
struct OpWithDepth {
Op op;
unsigned depth;
};
llvm::SmallVector<OpWithDepth, 8> worklist;

auto enqueue = [&worklist, &checked, &op](Value input, unsigned depth) {
// Add to worklist if within depth limit, is defined in the same block by
// the same kind of operation, has same two-state-ness, and not enqueued
// previously.
if (depth < limit && input.getParentBlock() == op->getBlock()) {
auto inputOp = input.template getDefiningOp<Op>();
if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
checked.insert(inputOp).second)
worklist.push_back({inputOp, depth + 1});
}
};

while (!worklist.empty()) {
auto element = worklist.pop_back_val();
for (auto input : uniqueInputs)
enqueue(input, 0);

if (auto idempotentOp = element.getDefiningOp<Op>()) {
for (auto input : idempotentOp.getInputs()) {
uniqueInputs.remove(input);
while (!worklist.empty()) {
auto item = worklist.pop_back_val();

if (checked.insert(input).second)
worklist.push_back(input);
}
for (auto input : item.op.getInputs()) {
uniqueInputs.remove(input);
enqueue(input, item.depth);
}
}

if (uniqueInputs.size() < inputs.size()) {
replaceOpWithNewOpAndCopyName<Op>(rewriter, op, op.getType(),
uniqueInputs.getArrayRef());
uniqueInputs.getArrayRef(),
op.getTwoState());
return true;
}

return false;
}

LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
if (hasOperandsOutsideOfBlock(&*op))
return failure();

auto inputs = op.getInputs();
auto size = inputs.size();
assert(size > 1 && "expected 2 or more operands, `fold` should handle this");

// and(x, and(...)) -> and(x, ...) -- flatten
if (tryFlatteningOperands(op, rewriter))
Expand All @@ -985,6 +999,10 @@ LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
return success();

if (hasOperandsOutsideOfBlock(&*op))
return failure();
assert(size > 1 && "expected 2 or more operands, `fold` should handle this");

// Patterns for and with a constant on RHS.
APInt value;
if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
Expand Down Expand Up @@ -1255,12 +1273,8 @@ static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1,
}

LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
if (hasOperandsOutsideOfBlock(&*op))
return failure();

auto inputs = op.getInputs();
auto size = inputs.size();
assert(size > 1 && "expected 2 or more operands");

// or(x, or(...)) -> or(x, ...) -- flatten
if (tryFlatteningOperands(op, rewriter))
Expand All @@ -1272,6 +1286,10 @@ LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
return success();

if (hasOperandsOutsideOfBlock(&*op))
return failure();
assert(size > 1 && "expected 2 or more operands");

// Patterns for or with a constant on RHS.
APInt value;
if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
Expand Down
53 changes: 53 additions & 0 deletions test/Dialect/Comb/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1553,6 +1553,59 @@ hw.module @OrMuxSameTrueValueAndZero(in %tag_0: i1, in %tag_1: i1, in %tag_2: i1
"terminator"(%add2) : (i32) -> ()
}) : () -> ()

// Flattening and idempotent-operand deduplication only look through `comb.or`
// operations defined in the same block as the user. Values defined in ^bb0 are
// expected to remain plain operands of the result in ^bb1.
// CHECK-LABEL: "test.acrossBlockCanonicalizationBarrierFlattenAndIdem"
// CHECK: ^bb1:
// CHECK-NEXT: %[[OUT:.+]] = comb.or %0, %1, %2 : i32
// CHECK-NEXT: "terminator"(%[[OUT]]) : (i32) -> ()
"test.acrossBlockCanonicalizationBarrierFlattenAndIdem"() ({
^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
%0 = comb.or %arg0, %arg1 : i32
%1 = comb.or %arg1, %arg2 : i32
%2 = comb.or %arg0, %arg2 : i32
"terminator"() : () -> ()
^bb1: // no predecessors
// Flatten and unique, but not across blocks.
// %3 and %4 live in this block, so they are expected to be folded into %5 and
// the repeated operands removed; %0, %1 and %2 come from ^bb0 and are expected
// to be kept as-is rather than expanded into %arg0/%arg1/%arg2.
%3 = comb.or %0, %1 : i32
%4 = comb.or %1, %2 : i32
%5 = comb.or %3, %4, %0, %1, %1, %2 : i32

"terminator"(%5) : (i32) -> ()
}) : () -> ()

// Duplicate operands are expected to be removed even when they are defined in
// another block, but the idempotent search must not look through an operation
// that is defined in another block.
// CHECK-LABEL: "test.acrossBlockCanonicalizationBarrierIdem"
// CHECK: ^bb1:
// CHECK-NEXT: %[[OUT1:.+]] = comb.or %0, %1 : i32
// CHECK-NEXT: %[[OUT2:.+]] = comb.or %[[OUT1]], %arg0 : i32
// CHECK-NEXT: "terminator"(%[[OUT1]], %[[OUT2]]) : (i32, i32) -> ()
"test.acrossBlockCanonicalizationBarrierIdem"() ({
^bb0(%arg0: i32, %arg1: i32, %arg2: i32):
%0 = comb.or %arg0, %arg1 : i32
%1 = comb.or %arg1, %arg2 : i32
"terminator"() : () -> ()
^bb1: // no predecessors
// %0 and %1 come from ^bb0; the repeated %1 is still expected to be dropped.
%2 = comb.or %0, %1, %1 : i32
// %0 and %1 are already inputs of %2 (same block), so they are expected to be
// dropped here. %arg0 is an input of %0, which is defined in ^bb0, so it is
// expected to stay.
%3 = comb.or %2, %0, %1, %arg0 : i32

"terminator"(%2, %3) : (i32, i32) -> ()
}) : () -> ()

// Check multi-operation idempotent operand deduplication.
// The search is only expected to walk through operand-defining ops whose
// two-state (`bin`) flag matches the op being canonicalized.
// CHECK-LABEL: @IdemTwoState
// CHECK-NEXT: %[[ZERO:.+]] = comb.or bin %cond, %val1
// CHECK-NEXT: %[[ONE:.+]] = comb.or bin %val1, %val2
// Don't allow dropping these (%0/%1) due to two-state differences.
// CHECK-NEXT: %[[TWO:.+]] = comb.or %[[ZERO]], %[[ONE]]
// New operation should preserve two-state-ness.
// CHECK-NEXT: %[[THREE:.+]] = comb.or bin %[[TWO]], %[[ZERO]], %[[ONE]]
// CHECK-NEXT: hw.output %[[ZERO]], %[[ONE]], %[[TWO]], %[[THREE]]
hw.module @IdemTwoState(in %cond: i32, in %val1: i32, in %val2: i32, out o1: i32, out o2: i32, out o3: i32, out o4: i32) {
%0 = comb.or bin %cond, %val1 : i32
%1 = comb.or bin %val1, %val2: i32
// %2 is the only non-`bin` or in this module.
%2 = comb.or %0, %1 : i32
// Expected result: %cond, %val1 and %val2 are dropped because they are inputs
// of the `bin` ops %0 and %1; %0 and %1 stay because they are only covered by
// the non-`bin` %2, which the search must not walk into.
%3 = comb.or bin %cond, %val1, %val2, %2, %0, %1 : i32
hw.output %0, %1, %2, %3: i32, i32, i32, i32
}

// CHECK-LABEL: hw.module @combineOppositeBinCmpIntoConstant
// CHECK: %[[TRUE:.+]] = hw.constant true
// CHECK: %[[FALSE:.+]] = hw.constant false
Expand Down

0 comments on commit ebb2429

Please sign in to comment.