From e0068a6482caceece4240b99070090abfee74971 Mon Sep 17 00:00:00 2001
From: Fabian Schuiki
Date: Sun, 13 Oct 2024 18:42:51 -0700
Subject: [PATCH] [Arc] Add dominance-aware pass to sink ops and merge scf.if
 ops

Add the `MergeIfs` pass to the Arc dialect. This pass covers a handful of
control flow optimizations that become valuable after `hw.module`s have been
linearized and lowered into `arc.model` ops:

- It moves operations closer to their earliest user and, where possible,
  sinks them into nested blocks if all uses lie within the same block.
- It merges adjacent `scf.if` operations that have the same condition.
- It moves operations located between two `scf.if` operations ahead of the
  first if op so that the two ifs become adjacent and can be merged.

The `MergeIfs` pass operates on SSACFG regions. It assigns an integer order
to each operation and uses that order to determine how far an operation can
be moved without going past its first use and without crossing interfering
side-effecting ops. The pass is side-effect aware; in particular, it uses the
non-aliasing of distinct `arc.state` and `arc.memory` values to track
read/write effects at a per-state level, which allows for fairly aggressive
optimization. Other side-effecting ops act as a hard barrier and are neither
moved nor moved over.

This pass supersedes the very effective `GroupResetsAndEnables` pass we have
been using until now. That pass, however, relies on the `LegalizeStateUpdate`
pass running at a later point; since `LegalizeStateUpdate` will be removed in
a future PR, this new pass becomes necessary.

This is a preparatory step for a later PR that overhauls the LowerState pass.
That rewrite will make the `arc.clock_tree` and `arc.passthrough` ops
obsolete, which is why they are not present in the tests.
---
 include/circt/Dialect/Arc/ArcPasses.h         |   1 -
 include/circt/Dialect/Arc/ArcPasses.td        |  26 +-
 lib/Dialect/Arc/Transforms/CMakeLists.txt     |   2 +-
 .../Arc/Transforms/GroupResetsAndEnables.cpp  | 224 ----------
 lib/Dialect/Arc/Transforms/MergeIfs.cpp       | 365 ++++++++++++++++
 .../Dialect/Arc/group-resets-and-enables.mlir | 396 ------------------
 test/Dialect/Arc/merge-ifs.mlir               | 283 +++++++++++++
 tools/arcilator/arcilator.cpp                 |   2 +-
 8 files changed, 669 insertions(+), 630 deletions(-)
 delete mode 100644 lib/Dialect/Arc/Transforms/GroupResetsAndEnables.cpp
 create mode 100644 lib/Dialect/Arc/Transforms/MergeIfs.cpp
 delete mode 100644 test/Dialect/Arc/group-resets-and-enables.mlir
 create mode 100644 test/Dialect/Arc/merge-ifs.mlir

diff --git a/include/circt/Dialect/Arc/ArcPasses.h b/include/circt/Dialect/Arc/ArcPasses.h
index bd185fa147ef..b31398d5898a 100644
--- a/include/circt/Dialect/Arc/ArcPasses.h
+++ b/include/circt/Dialect/Arc/ArcPasses.h
@@ -31,7 +31,6 @@ std::unique_ptr<mlir::Pass> createAllocateStatePass();
 std::unique_ptr<mlir::Pass> createArcCanonicalizerPass();
 std::unique_ptr<mlir::Pass> createDedupPass();
 std::unique_ptr<mlir::Pass> createFindInitialVectorsPass();
-std::unique_ptr<mlir::Pass> createGroupResetsAndEnablesPass();
 std::unique_ptr<mlir::Pass>
 createInferMemoriesPass(const InferMemoriesOptions &options = {});
 std::unique_ptr<mlir::Pass> createInlineArcsPass();
diff --git a/include/circt/Dialect/Arc/ArcPasses.td b/include/circt/Dialect/Arc/ArcPasses.td
index 6a89f69c7b02..8b3005b32e39 100644
--- a/include/circt/Dialect/Arc/ArcPasses.td
+++ b/include/circt/Dialect/Arc/ArcPasses.td
@@ -92,13 +92,6 @@ def FindInitialVectors : Pass<"arc-find-initial-vectors", "mlir::ModuleOp"> {
   ];
 }
 
-def GroupResetsAndEnables : Pass<"arc-group-resets-and-enables",
-                                 "mlir::ModuleOp"> {
-  let summary = "Group reset and enable conditions of lowered states";
-  let constructor = "circt::arc::createGroupResetsAndEnablesPass()";
-  let dependentDialects = ["arc::ArcDialect", "mlir::scf::SCFDialect"];
-}
-
 def InferMemories : Pass<"arc-infer-memories", "mlir::ModuleOp"> {
   let summary = "Convert `FIRRTL_Memory` instances to dedicated memory ops";
   let constructor = "circt::arc::createInferMemoriesPass()";
@@ -282,6 +275,25 @@ def MakeTables : Pass<"arc-make-tables", "mlir::ModuleOp"> {
   let dependentDialects = ["arc::ArcDialect"];
 }
 
+def MergeIfsPass : Pass<"arc-merge-ifs"> {
+  let summary = "Merge control flow structures";
+  let description = [{
+    This pass optimizes control flow in a few ways. It moves operations closer
+    to their earliest user, sinking them into nested blocks where all uses lie
+    in the same block. It merges adjacent `scf.if` operations with the same
+    condition. It also moves operations located between two `scf.if` operations
+    ahead of the first if op to allow them to be merged. The pass runs on any
+    SSACFG region nested under the operation it is applied to.
+  }];
+  let statistics = [
+    Statistic<"numOpsSunk", "sunk", "Ops sunk into blocks">,
+    Statistic<"numOpsMovedToUser", "moved-to-user", "Ops moved to first user">,
+    Statistic<"numIfsMerged", "ifs-merged", "Adjacent scf.if ops merged">,
+    Statistic<"numOpsMovedFromBetweenIfs", "moved-from-between-ifs",
+              "Ops moved from between ifs to enable merging">,
+  ];
+}
+
 def MuxToControlFlow : Pass<"arc-mux-to-control-flow", "mlir::ModuleOp"> {
   let summary = "Convert muxes with large independent fan-ins to if-statements";
   let constructor = "circt::arc::createMuxToControlFlowPass()";
diff --git a/lib/Dialect/Arc/Transforms/CMakeLists.txt b/lib/Dialect/Arc/Transforms/CMakeLists.txt
index d3690cd408be..b9362e2f1ff9 100644
--- a/lib/Dialect/Arc/Transforms/CMakeLists.txt
+++ b/lib/Dialect/Arc/Transforms/CMakeLists.txt
@@ -4,7 +4,6 @@ add_circt_dialect_library(CIRCTArcTransforms
   ArcCanonicalizer.cpp
   Dedup.cpp
   FindInitialVectors.cpp
-  GroupResetsAndEnables.cpp
   InferMemories.cpp
   InferStateProperties.cpp
   InlineArcs.cpp
@@ -17,6 +16,7 @@ add_circt_dialect_library(CIRCTArcTransforms
   LowerState.cpp
   LowerVectorizations.cpp
   MakeTables.cpp
+  MergeIfs.cpp
   MuxToControlFlow.cpp
   PrintCostModel.cpp
   SimplifyVariadicOps.cpp
diff --git a/lib/Dialect/Arc/Transforms/GroupResetsAndEnables.cpp b/lib/Dialect/Arc/Transforms/GroupResetsAndEnables.cpp
deleted file mode 100644
index e66c834a4b56..000000000000
--- a/lib/Dialect/Arc/Transforms/GroupResetsAndEnables.cpp
+++ /dev/null
@@ -1,224 +0,0 @@
-//===- GroupResetsAndEnables.cpp ------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "circt/Dialect/Arc/ArcOps.h" -#include "circt/Dialect/Arc/ArcPasses.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/Dominance.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "arc-group-resets-and-enables" - -namespace circt { -namespace arc { -#define GEN_PASS_DEF_GROUPRESETSANDENABLES -#include "circt/Dialect/Arc/ArcPasses.h.inc" -} // namespace arc -} // namespace circt - -using namespace circt; -using namespace arc; -using namespace mlir; - -//===----------------------------------------------------------------------===// -// Rewrite Patterns -//===----------------------------------------------------------------------===// - -namespace { - -struct ResetGroupingPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ClockTreeOp clockTreeOp, - PatternRewriter &rewriter) const override { - // Group similar resets into single IfOps - // Create a list of reset values and map from them to the states they reset - llvm::MapVector> resetMap; - - for (auto ifOp : clockTreeOp.getBody().getOps()) - if (ifOp.getResults().empty()) - resetMap[ifOp.getCondition()].push_back(ifOp); - - // TODO: Check that conflicting memory effects aren't being reordered - - // Combine IfOps - bool changed = false; - for (auto &[cond, oldOps] : resetMap) { - if (oldOps.size() <= 1) - continue; - scf::IfOp lastIfOp = oldOps.pop_back_val(); - for (auto thisOp : oldOps) { - // Inline the before and after region inside the original If - rewriter.eraseOp(thisOp.thenBlock()->getTerminator()); - rewriter.inlineBlockBefore(thisOp.thenBlock(), - &lastIfOp.thenBlock()->front()); - // Check we're not inlining an empty block - if (auto *elseBlock = thisOp.elseBlock()) { - rewriter.eraseOp(elseBlock->getTerminator()); - if (auto *lastElseBlock = lastIfOp.elseBlock()) { - rewriter.inlineBlockBefore(elseBlock, - &lastIfOp.elseBlock()->front()); - } else { - lastElseBlock = rewriter.createBlock(&lastIfOp.getElseRegion()); - rewriter.setInsertionPointToEnd(lastElseBlock); - auto yieldOp = rewriter.create( - lastElseBlock->getParentOp()->getLoc()); - rewriter.inlineBlockBefore(thisOp.elseBlock(), yieldOp); - } - } - rewriter.eraseOp(thisOp); - changed = true; - } - } - return success(changed); - } -}; - -struct EnableGroupingPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ClockTreeOp clockTreeOp, - PatternRewriter &rewriter) const override { - // Amass regions that we want to group enables in - SmallVector groupingRegions; - groupingRegions.push_back(&clockTreeOp.getBody()); - for (auto ifOp : clockTreeOp.getBody().getOps()) { - groupingRegions.push_back(&ifOp.getThenRegion()); - groupingRegions.push_back(&ifOp.getElseRegion()); - } - - bool changed = false; - for (auto *region : groupingRegions) { - llvm::MapVector> enableMap; - for (auto writeOp : region->getOps()) { - if (writeOp.getCondition()) - enableMap[writeOp.getCondition()].push_back(writeOp); - } - for (auto &[enable, writeOps] : enableMap) { - // Only group if multiple writes share an enable - if (writeOps.size() <= 1) - continue; - if (region->getParentOp()->hasTrait()) - 
rewriter.setInsertionPointToEnd(®ion->back()); - else - rewriter.setInsertionPoint(region->back().getTerminator()); - scf::IfOp ifOp = - rewriter.create(writeOps[0].getLoc(), enable, false); - for (auto writeOp : writeOps) { - rewriter.modifyOpInPlace(writeOp, [&]() { - writeOp->moveBefore(ifOp.thenBlock()->getTerminator()); - writeOp.getConditionMutable().erase(0); - }); - } - changed = true; - } - } - return success(changed); - } -}; - -/// Where possible without domination issues, group assignments inside IfOps and -/// return true if any operations were moved. -bool groupInRegion(Block *block, Operation *clockTreeOp, - PatternRewriter *rewriter) { - bool changed = false; - if (!block) - return false; - - SmallVector worklist; - // Don't walk as we don't want nested ops in order to restrict to IfOps - for (auto &op : block->getOperations()) { - worklist.push_back(&op); - } - while (!worklist.empty()) { - Operation *op = worklist.pop_back_val(); - mlir::DominanceInfo dom(op); - for (auto operand : op->getOperands()) { - Operation *definition = operand.getDefiningOp(); - if (definition == nullptr) - continue; - // Skip if the operand is already defined in this block or is - // defined out of the clock tree - if (definition->getBlock() == op->getBlock() || - !clockTreeOp->isAncestor(definition)) - continue; - if (llvm::any_of(definition->getUsers(), - [&](auto *user) { return !dom.dominates(op, user); })) - continue; - // For some currently unknown reason, just calling moveBefore - // directly has the same output but is much slower - rewriter->modifyOpInPlace(definition, - [&]() { definition->moveBefore(op); }); - changed = true; - worklist.push_back(definition); - } - } - return changed; -} - -struct GroupAssignmentsInIfPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(scf::IfOp ifOp, - PatternRewriter &rewriter) const override { - // Pull values only used in certain reset/enable cases into the appropriate - // IfOps - // Skip anything not in a ClockTreeOp - auto clockTreeOp = ifOp->getParentOfType(); - if (!clockTreeOp) - return failure(); - // Group assignments in each region and keep track of whether either - // grouping made changes - bool changed = groupInRegion(ifOp.thenBlock(), clockTreeOp, &rewriter) || - groupInRegion(ifOp.elseBlock(), clockTreeOp, &rewriter); - return success(changed); - } -}; - -} // namespace - -//===----------------------------------------------------------------------===// -// Pass Infrastructure -//===----------------------------------------------------------------------===// - -namespace { -struct GroupResetsAndEnablesPass - : public arc::impl::GroupResetsAndEnablesBase { - - void runOnOperation() override; - LogicalResult runOnModel(ModelOp modelOp); -}; -} // namespace - -void GroupResetsAndEnablesPass::runOnOperation() { - for (auto op : getOperation().getOps()) - if (failed(runOnModel(op))) - return signalPassFailure(); -} - -LogicalResult GroupResetsAndEnablesPass::runOnModel(ModelOp modelOp) { - LLVM_DEBUG(llvm::dbgs() << "Grouping resets and enables in `" - << modelOp.getName() << "`\n"); - - MLIRContext &context = getContext(); - RewritePatternSet patterns(&context); - patterns.add(&context); - - if (failed(applyPatternsAndFoldGreedily(modelOp, std::move(patterns)))) - return emitError(modelOp.getLoc(), - "GroupResetsAndEnables: greedy rewriter did not converge"); - - return success(); -} - -std::unique_ptr arc::createGroupResetsAndEnablesPass() { - return std::make_unique(); -} diff 
--git a/lib/Dialect/Arc/Transforms/MergeIfs.cpp b/lib/Dialect/Arc/Transforms/MergeIfs.cpp new file mode 100644 index 000000000000..900ad91d9951 --- /dev/null +++ b/lib/Dialect/Arc/Transforms/MergeIfs.cpp @@ -0,0 +1,365 @@ +//===- MergeIfs.cpp -------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/Arc/ArcOps.h" +#include "circt/Dialect/Arc/ArcPasses.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "arc-merge-ifs" + +namespace circt { +namespace arc { +#define GEN_PASS_DEF_MERGEIFSPASS +#include "circt/Dialect/Arc/ArcPasses.h.inc" +} // namespace arc +} // namespace circt + +using namespace mlir; +using namespace circt; +using namespace arc; + +namespace { +struct MergeIfsPass : public arc::impl::MergeIfsPassBase { + void runOnOperation() override; + using MergeIfsPassBase::MergeIfsPassBase; + using MergeIfsPassBase::numIfsMerged; + using MergeIfsPassBase::numOpsMovedFromBetweenIfs; + using MergeIfsPassBase::numOpsMovedToUser; + using MergeIfsPassBase::numOpsSunk; +}; + +/// A helper to perform the op sinking within a specific block. +struct Sinker { + MergeIfsPass &pass; + Block &rootBlock; + bool anyChanges; + + Sinker(MergeIfsPass &pass, Block &rootBlock) + : pass(pass), rootBlock(rootBlock) {} + LogicalResult run(); + void sinkOps(); + void mergeIfs(); +}; +} // namespace + +void MergeIfsPass::runOnOperation() { + // Go through the regions recursively, from outer regions to nested regions, + // and try to move/sink/merge ops in each. + auto result = getOperation()->walk([&](Region *region) { + if (region->hasOneBlock() && mlir::mayHaveSSADominance(*region)) + if (failed(Sinker(*this, region->front()).run())) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + if (result.wasInterrupted()) + signalPassFailure(); +} + +/// Iteratively sink ops into block and move them closer to their uses, and +/// merge adjacent `scf.if` operations. +LogicalResult Sinker::run() { + LLVM_DEBUG(llvm::dbgs() << "Running on block in " + << rootBlock.getParentOp()->getName() << "\n"); + do { + anyChanges = false; + sinkOps(); + mergeIfs(); + } while (anyChanges); + return success(); +} + +/// Return the state/memory value being written by an op. +static Value getPointerWrittenByOp(Operation *op) { + if (auto write = dyn_cast(op)) + return write.getState(); + if (auto write = dyn_cast(op)) + return write.getMemory(); + return {}; +} + +/// Return the state/memory value being read by an op. +static Value getPointerReadByOp(Operation *op) { + if (auto read = dyn_cast(op)) + return read.getState(); + if (auto read = dyn_cast(op)) + return read.getMemory(); + return {}; +} + +namespace { +/// An integer indicating the position of an operation in its parent block. The +/// first field is the initial order/position assigned. The second field is used +/// to order ops that were moved to the same location, which makes them have +/// same first field. +using OpOrder = std::pair; + +/// A helper that tracks an op and its order, and allows for convenient +/// substitution with another op that has a higher/lower order. 
+struct OpAndOrder { + Operation *op = nullptr; + OpOrder order = {0, 0}; + + explicit operator bool() const { return op; } + + /// Assign `other` if its order is lower than this op, or this op is null. + void minimize(const OpAndOrder &other) { + if (!op || (other.op && other.order < order)) + *this = other; + } + + /// Assign `other` if its order is higher than this op, or this op is null. + void maximize(const OpAndOrder &other) { + if (!op || (other.op && other.order > order)) + *this = other; + } +}; +} // namespace + +/// Sink operations as close to their users as possible. +void Sinker::sinkOps() { + // A numeric position assigned to ops as we encounter them. Ops at the end of + // the block get the lowest order number, ops at the beginning the highest. + DenseMap opOrder; + // A lookup table that indicates where ops should be inserted. This is used to + // maintain the original op order if multiple ops pile up before the same + // other op that blocks their move. + DenseMap insertionPoints; + // The write ops to each state/memory pointer we've seen so far. ("Next" + // because we run from the end to the beginning of the block.) + DenseMap nextWrite; + // The most recent op that has an unknown (non-read/write) side-effect. + Operation *nextSideEffect = nullptr; + + for (auto &op : llvm::make_early_inc_range(llvm::reverse(rootBlock))) { + // Assign an order to this op. + auto order = OpOrder{opOrder.size() + 1, 0}; + opOrder[&op] = order; + + // Analyze the side effects in the op. + op.walk([&](Operation *subOp) { + if (auto ptr = getPointerWrittenByOp(subOp)) + nextWrite[ptr] = &op; + else if (!isa(subOp) && + !subOp->hasTrait() && + !mlir::isMemoryEffectFree(subOp)) + nextSideEffect = &op; + }); + + // Determine how much the op can be moved. + OpAndOrder moveLimit; + if (auto ptr = getPointerReadByOp(&op)) { + // Don't move across writes to the same state/memory. + if (auto *write = nextWrite.lookup(ptr)) + moveLimit.maximize({write, opOrder.lookup(write)}); + // Don't move across general side-effecting ops. + if (nextSideEffect) + moveLimit.maximize({nextSideEffect, opOrder.lookup(nextSideEffect)}); + } else if (isa(&op) || nextSideEffect == &op) { + // Don't move writes or side-effecting ops. + continue; + } + + // Find the block that contains all uses. + Block *allUsesInBlock = nullptr; + for (auto *user : op.getUsers()) { + // If this user is directly in the root block there's no chance of sinking + // the current op anywhere. + if (user->getBlock() == &rootBlock) { + allUsesInBlock = nullptr; + break; + } + + // Find the operation in the root block that contains this user. + while (user->getParentOp()->getBlock() != &rootBlock) + user = user->getParentOp(); + assert(user); + + // Check that all users sit in the same op in the root block. + if (!allUsesInBlock) { + allUsesInBlock = user->getBlock(); + } else if (allUsesInBlock != user->getBlock()) { + allUsesInBlock = nullptr; + break; + } + } + + // If no single block exists that contains all uses, find the earliest op in + // the root block that uses the current op. + OpAndOrder earliest; + if (allUsesInBlock) { + earliest.op = allUsesInBlock->getParentOp(); + earliest.order = opOrder.lookup(earliest.op); + } else { + for (auto *user : op.getUsers()) { + while (user->getBlock() != &rootBlock) + user = user->getParentOp(); + assert(user); + earliest.maximize({user, opOrder.lookup(user)}); + } + } + + // Ensure we don't move past the move limit imposed by side effects. 
+ earliest.maximize(moveLimit); + if (!earliest) + continue; + + // Either move the op inside the single block that contains all uses, or + // move it to just before its earliest user. + if (allUsesInBlock && allUsesInBlock->getParentOp() == earliest.op) { + op.moveBefore(allUsesInBlock, allUsesInBlock->begin()); + ++pass.numOpsSunk; + anyChanges = true; + LLVM_DEBUG(llvm::dbgs() << "- Sunk " << op << "\n"); + } else { + // Insert above other ops that we have already moved to this earliest op. + // This ensures the original op order is maintained and we are not + // spuriously flipping ops around. This also works without the + // `insertionPoint` lookup, but can cause significant linear scanning to + // find the op before which we want to insert. + auto &insertionPoint = insertionPoints[earliest.op]; + if (insertionPoint) { + auto order = opOrder.lookup(insertionPoint); + assert(order.first == earliest.order.first); + assert(order.second >= earliest.order.second); + earliest.op = insertionPoint; + earliest.order = order; + } + while (auto *prevOp = earliest.op->getPrevNode()) { + auto order = opOrder.lookup(prevOp); + if (order.first != earliest.order.first) + break; + assert(order.second > earliest.order.second); + earliest.op = prevOp; + earliest.order = order; + } + insertionPoint = earliest.op; + + // Only move if the op isn't already in the right spot. + if (op.getNextNode() != earliest.op) { + LLVM_DEBUG(llvm::dbgs() << "- Moved " << op << "\n"); + op.moveBefore(earliest.op); + ++pass.numOpsMovedToUser; + anyChanges = true; + } + + // Update the current op's order to reflect where it has been inserted. + // This ensures that later moves to the same pile of moved ops do not + // reorder the operations. + order = earliest.order; + assert(order.second < unsigned(-1)); + ++order.second; + opOrder[&op] = order; + } + } +} + +void Sinker::mergeIfs() { + DenseSet writes, reads; + + scf::IfOp lastOp; + for (auto ifOp : rootBlock.getOps()) { + auto prevIfOp = std::exchange(lastOp, ifOp); + if (!prevIfOp) + continue; + + // Only handle simple cases for now. (Same condition, no results, only + // single blocks, and both ifs either have or don't have an else region.) + if (ifOp.getCondition() != prevIfOp.getCondition()) + continue; + if (ifOp.getNumResults() != 0 || prevIfOp.getNumResults() != 0) + continue; + if (!ifOp.getThenRegion().hasOneBlock() || + !prevIfOp.getThenRegion().hasOneBlock()) + continue; + if (ifOp.getElseRegion().empty() != prevIfOp.getElseRegion().empty()) + continue; + if (!ifOp.getElseRegion().empty() && + (!ifOp.getElseRegion().hasOneBlock() || + !prevIfOp.getElseRegion().hasOneBlock())) + continue; + + // Try to move ops in between the if ops above the previous if in order to + // make them immediately adjacent. + if (ifOp->getPrevNode() != prevIfOp) { + // Determine the side effects inside the previous if op. + bool hasSideEffects = false; + writes.clear(); + reads.clear(); + prevIfOp.walk([&](Operation *op) { + if (auto ptr = getPointerWrittenByOp(op)) + writes.insert(ptr); + else if (auto ptr = getPointerReadByOp(op)) + reads.insert(ptr); + else if (!hasSideEffects && + !op->hasTrait() && + !mlir::isMemoryEffectFree(op)) + hasSideEffects = true; + }); + + // Check if it is legal to throw all ops over the previous if op, given + // the side effects. We don't move the ops yet to ensure we can move *all* + // of them at once afterwards. Otherwise this optimization would race with + // the sink-to-users optimization. 
+ bool allMovable = true; + for (auto &op : llvm::make_range(Block::iterator(prevIfOp->getNextNode()), + Block::iterator(ifOp))) { + auto result = op.walk([&](Operation *subOp) { + if (auto ptr = getPointerWrittenByOp(subOp)) { + // We can't move writes over writes or reads of the same state. + if (writes.contains(ptr) || reads.contains(ptr)) + return WalkResult::interrupt(); + } else if (auto ptr = getPointerReadByOp(subOp)) { + // We can't move reads over writes to the same state. + if (writes.contains(ptr)) + return WalkResult::interrupt(); + } else if (!subOp->hasTrait() && + !mlir::isMemoryEffectFree(subOp)) { + // We can't move side-effecting ops over other side-effecting ops. + if (hasSideEffects) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (result.wasInterrupted()) { + allMovable = false; + break; + } + } + if (!allMovable) + continue; + + // At this point we know that all ops can be moved. Do so. + while (auto *op = prevIfOp->getNextNode()) { + if (op == ifOp) + break; + LLVM_DEBUG(llvm::dbgs() << "- Moved before if " << *op << "\n"); + op->moveBefore(prevIfOp); + ++pass.numOpsMovedFromBetweenIfs; + } + } + + // Merge the then-blocks. + prevIfOp.thenYield().erase(); + ifOp.thenBlock()->getOperations().splice( + ifOp.thenBlock()->begin(), prevIfOp.thenBlock()->getOperations()); + + // Merge the else-blocks if present. + if (ifOp.elseBlock()) { + prevIfOp.elseYield().erase(); + ifOp.elseBlock()->getOperations().splice( + ifOp.elseBlock()->begin(), prevIfOp.elseBlock()->getOperations()); + } + + // Clean up. + prevIfOp.erase(); + anyChanges = true; + ++pass.numIfsMerged; + LLVM_DEBUG(llvm::dbgs() << "- Merged adjacent if ops\n"); + } +} diff --git a/test/Dialect/Arc/group-resets-and-enables.mlir b/test/Dialect/Arc/group-resets-and-enables.mlir deleted file mode 100644 index a3d41e7c1eff..000000000000 --- a/test/Dialect/Arc/group-resets-and-enables.mlir +++ /dev/null @@ -1,396 +0,0 @@ -// RUN: circt-opt %s --arc-group-resets-and-enables | FileCheck %s - -// CHECK-LABEL: arc.model @BasicResetGrouping -arc.model @BasicResetGrouping io !hw.modty { -^bb0(%arg0: !arc.storage): - %c0_i4 = hw.constant 0 : i4 - %in_clock = arc.root_input "clock", %arg0 : (!arc.storage) -> !arc.state - %in_i0 = arc.root_input "i0", %arg0 : (!arc.storage) -> !arc.state - %in_i1 = arc.root_input "i1", %arg0 : (!arc.storage) -> !arc.state - %in_reset0 = arc.root_input "reset0", %arg0 : (!arc.storage) -> !arc.state - %in_reset1 = arc.root_input "reset1", %arg0 : (!arc.storage) -> !arc.state - // CHECK: [[FOO_ALLOC:%.+]] = arc.alloc_state %arg0 {name = "foo"} : (!arc.storage) -> !arc.state - // CHECK: [[BAR_ALLOC:%.+]] = arc.alloc_state %arg0 {name = "bar"} : (!arc.storage) -> !arc.state - %1 = arc.alloc_state %arg0 {name = "foo"} : (!arc.storage) -> !arc.state - %2 = arc.alloc_state %arg0 {name = "bar"} : (!arc.storage) -> !arc.state - %0 = arc.state_read %in_clock : - // Group resets: - arc.clock_tree %0 { - // CHECK: [[IN_RESET0:%.+]] = arc.state_read %in_reset0 - %3 = arc.state_read %in_reset0 : - // CHECK-NEXT: scf.if [[IN_RESET0]] { - scf.if %3 { - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = %c0_i4 - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = %c0_i4 - arc.state_write %1 = %c0_i4 : - // CHECK-NEXT: } else { - } else { - // CHECK-NEXT: [[IN_I0:%.+]] = arc.state_read %in_i0 - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = [[IN_I0]] - // CHECK-NEXT: [[IN_I1:%.+]] = arc.state_read %in_i1 - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = [[IN_I1]] - %4 = arc.state_read 
%in_i0 : - arc.state_write %1 = %4 : - // CHECK-NEXT: } - } - scf.if %3 { - arc.state_write %2 = %c0_i4 : - } else { - %5 = arc.state_read %in_i1 : - arc.state_write %2 = %5 : - } - // CHECK-NEXT: } - } - // Don't group resets that don't match: - arc.clock_tree %0 { - // CHECK: [[IN_RESET0_1:%.+]] = arc.state_read %in_reset0 - %6 = arc.state_read %in_reset0 : - // CHECK-NEXT: [[IN_RESET1_1:%.+]] = arc.state_read %in_reset1 - %7 = arc.state_read %in_reset1 : - // CHECK-NEXT: scf.if [[IN_RESET0_1]] { - scf.if %6 { - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = %c0_i4 - arc.state_write %1 = %c0_i4 : - // CHECK-NEXT: } else { - } else { - // CHECK-NEXT: [[IN_I0_1:%.+]] = arc.state_read %in_i0 - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = [[IN_I0_1]] - %8 = arc.state_read %in_i0 : - arc.state_write %1 = %8 : - // CHECK-NEXT: } - } - // CHECK-NEXT: scf.if [[IN_RESET1_1]] { - scf.if %7 { - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = %c0_i4 - arc.state_write %2 = %c0_i4 : - // CHECK-NEXT: } else { - } else { - // CHECK-NEXT: [[IN_I1_1:%.+]] = arc.state_read %in_i1 - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = [[IN_I1_1]] - %9 = arc.state_read %in_i1 : - arc.state_write %2 = %9 : - } - // CHECK-NEXT: } - // CHECK-NEXT: } - } - // Don't group IfOps with return values: - arc.clock_tree %0 { - // CHECK: [[IN_RESET0:%.+]] = arc.state_read %in_reset0 - %10 = arc.state_read %in_reset0 : - // CHECK-NEXT: scf.if [[IN_RESET0]] { - scf.if %10 { - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = %c0_i4 - arc.state_write %1 = %c0_i4 : - // CHECK-NEXT: } else { - } else { - // CHECK-NEXT: [[IN_I0:%.+]] = arc.state_read %in_i0 - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = [[IN_I0]] - %11 = arc.state_read %in_i0 : - arc.state_write %1 = %11 : - // CHECK-NEXT: } - } - // CHECK-NEXT: [[IF_RESULT:%.+]] scf.if [[IN_RESET0]] -> (i4) { - %res = scf.if %10 -> (i4) { - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = %c0_i4 - // CHECK-NEXT: scf.yield %c0_i4 : i4 - arc.state_write %2 = %c0_i4 : - scf.yield %c0_i4 : i4 - // CHECK-NEXT: } else { - } else { - // CHECK-NEXT: [[IN_I1:%.+]] = arc.state_read %in_i1 - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = [[IN_I1]] - // CHECK-NEXT: scf.yield %c0_i4 : i4 - %12 = arc.state_read %in_i1 : - arc.state_write %2 = %12 : - scf.yield %c0_i4 : i4 - } - // CHECK-NEXT: } - // CHECK-NEXT: } - } - // Group resets with no else in an early block (that has its contents moved): - arc.clock_tree %0 { - // CHECK: [[IN_RESET0:%.+]] = arc.state_read %in_reset0 - %13 = arc.state_read %in_reset0 : - // CHECK-NEXT: scf.if [[IN_RESET0]] { - scf.if %13 { - // CHECK-NEXT: arc.state_write [[FOO_ALLOC:%.+]] = %c0_i4 - // CHECK-NEXT: arc.state_write [[BAR_ALLOC:%.+]] = %c0_i4 - arc.state_write %1 = %c0_i4 : - // CHECK-NEXT: } else { - // CHECK-NEXT: [[IN_I1:%.+]] = arc.state_read %in_i1 - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = [[IN_I1]] - // CHECK-NEXT: } - } - scf.if %13 { - arc.state_write %2 = %c0_i4 : - } else { - %14 = arc.state_read %in_i1 : - arc.state_write %2 = %14 : - } - // CHECK-NEXT: } - } - // Group resets with no else in the last if (where contents are moved to): - arc.clock_tree %0 { - // CHECK: [[IN_RESET0:%.+]] = arc.state_read %in_reset0 - %15 = arc.state_read %in_reset0 : - // CHECK-NEXT: scf.if [[IN_RESET0]] { - scf.if %15 { - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = %c0_i4 - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = %c0_i4 - arc.state_write %1 = %c0_i4 : - // CHECK-NEXT: } else { - } else { - // CHECK-NEXT: [[IN_I0:%.+]] = arc.state_read %in_i0 - // 
CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = [[IN_I0]] - %16 = arc.state_read %in_i0 : - arc.state_write %1 = %16 : - // CHECK-NEXT: } - } - scf.if %15 { - arc.state_write %2 = %c0_i4 : - } - // CHECK-NEXT: } - } -} - -// CHECK-LABEL: arc.model @BasicEnableGrouping -arc.model @BasicEnableGrouping io !hw.modty { -^bb0(%arg0: !arc.storage): - %c0_i4 = hw.constant 0 : i4 - %in_clock = arc.root_input "clock", %arg0 : (!arc.storage) -> !arc.state - %in_i0 = arc.root_input "i0", %arg0 : (!arc.storage) -> !arc.state - %in_i1 = arc.root_input "i1", %arg0 : (!arc.storage) -> !arc.state - %in_en0 = arc.root_input "en0", %arg0 : (!arc.storage) -> !arc.state - %in_en1 = arc.root_input "en1", %arg0 : (!arc.storage) -> !arc.state - // CHECK: [[FOO_ALLOC:%.+]] = arc.alloc_state %arg0 {name = "foo"} : (!arc.storage) -> !arc.state - // CHECK: [[BAR_ALLOC:%.+]] = arc.alloc_state %arg0 {name = "bar"} : (!arc.storage) -> !arc.state - %1 = arc.alloc_state %arg0 {name = "foo"} : (!arc.storage) -> !arc.state - %2 = arc.alloc_state %arg0 {name = "bar"} : (!arc.storage) -> !arc.state - %0 = arc.state_read %in_clock : - // Group enables: - arc.clock_tree %0 { - // CHECK: [[IN_EN0:%.+]] = arc.state_read %in_en0 - %3 = arc.state_read %in_en0 : - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = %c0_i4 - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = %c0_i4 - arc.state_write %1 = %c0_i4 : - arc.state_write %2 = %c0_i4 : - // CHECK-NEXT: scf.if [[IN_EN0]] { - // state_reads are pulled in: - // CHECK-NEXT: [[IN_I0:%.+]] = arc.state_read %in_i0 - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = [[IN_I0]] - // CHECK-NEXT: [[IN_I1:%.+]] = arc.state_read %in_i1 - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = [[IN_I1]] - %4 = arc.state_read %in_i0 : - arc.state_write %1 = %4 if %3 : - %5 = arc.state_read %in_i1 : - arc.state_write %2 = %5 if %3 : - // CHECK-NEXT: } - // CHECK-NEXT: } - } - // Don't group non-matching enables: - arc.clock_tree %0 { - // CHECK: [[IN_EN0_1:%.+]] = arc.state_read %in_en0 - %6 = arc.state_read %in_en0 : - // CHECK-NEXT: [[IN_EN1_1:%.+]] = arc.state_read %in_en1 - %7 = arc.state_read %in_en1 : - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = %c0_i4 - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = %c0_i4 - arc.state_write %1 = %c0_i4 : - arc.state_write %2 = %c0_i4 : - // CHECK-NEXT: [[IN_I0_1:%.+]] = arc.state_read %in_i0 - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = [[IN_I0_1]] if [[IN_EN0_1]] - // CHECK-NEXT: [[IN_I1_1:%.+]] = arc.state_read %in_i1 - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = [[IN_I1_1]] if [[IN_EN1_1]] - %8 = arc.state_read %in_i0 : - arc.state_write %1 = %8 if %6 : - %9 = arc.state_read %in_i1 : - arc.state_write %2 = %9 if %7 : - // CHECK-NEXT: } - } -} - -// CHECK-LABEL: arc.model @GroupAssignmentsInIfTesting -arc.model @GroupAssignmentsInIfTesting io !hw.modty { -^bb0(%arg0: !arc.storage): - %in_clock = arc.root_input "clock", %arg0 : (!arc.storage) -> !arc.state - %in_i1 = arc.root_input "i1", %arg0 : (!arc.storage) -> !arc.state - %in_i2 = arc.root_input "i2", %arg0 : (!arc.storage) -> !arc.state - %in_cond0 = arc.root_input "cond0", %arg0 : (!arc.storage) -> !arc.state - %in_cond1 = arc.root_input "cond1", %arg0 : (!arc.storage) -> !arc.state - // CHECK: [[FOO_ALLOC:%.+]] = arc.alloc_state %arg0 {name = "foo"} : (!arc.storage) -> !arc.state - // CHECK: [[BAR_ALLOC:%.+]] = arc.alloc_state %arg0 {name = "bar"} : (!arc.storage) -> !arc.state - %1 = arc.alloc_state %arg0 {name = "foo"} : (!arc.storage) -> !arc.state - %2 = arc.alloc_state %arg0 {name = "bar"} : 
(!arc.storage) -> !arc.state - %0 = arc.state_read %in_clock : - // Do pull value in (1st and 2nd layer) - arc.clock_tree %0 { - // CHECK: [[IN_COND0:%.+]] = arc.state_read %in_cond0 - %3 = arc.state_read %in_cond0 : - %4 = arc.state_read %in_i1 : - %5 = arc.state_read %in_i2 : - // CHECK-NEXT: scf.if [[IN_COND0]] { - scf.if %3 { - // CHECK-NEXT: [[IN_I1:%.+]] = arc.state_read %in_i1 - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = [[IN_I1]] - arc.state_write %1 = %4 : - // CHECK: [[IN_COND1:%.+]] = arc.state_read %in_cond1 - %6 = arc.state_read %in_cond1 : - // CHECK-NEXT: scf.if [[IN_COND1]] { - scf.if %6 { - // CHECK-NEXT: [[IN_I2:%.+]] = arc.state_read %in_i2 - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = [[IN_I2]] - arc.state_write %2 = %5 : - // CHECK-NEXT: } - } - // CHECK-NEXT: } - } - } - // CHECK-NEXT: } - // Don't pull value in - arc.clock_tree %0 { - // CHECK: [[IN_COND0:%.+]] = arc.state_read %in_cond0 - %5 = arc.state_read %in_cond0 : - // CHECK-NEXT: [[IN_I1:%.+]] = arc.state_read %in_i1 - %6 = arc.state_read %in_i1 : - // CHECK-NEXT: scf.if [[IN_COND0]] { - scf.if %5 { - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = [[IN_I1]] - arc.state_write %1 = %6 : - // CHECK-NEXT: } - } - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = [[IN_I1]] - arc.state_write %2 = %6 : - // CHECK-NEXT: } - } - // Pull multi-use value into first if only - arc.clock_tree %0 { - // CHECK: [[IN_COND0:%.+]] = arc.state_read %in_cond0 - %5 = arc.state_read %in_cond0 : - %6 = arc.state_read %in_cond1 : - %7 = arc.state_read %in_i1 : - // CHECK-NEXT: scf.if [[IN_COND0]] { - scf.if %5 { - // CHECK-NEXT: [[IN_I1:%.+]] = arc.state_read %in_i1 - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = [[IN_I1]] - arc.state_write %1 = %7 : - // CHECK-NEXT: [[IN_COND1:%.+]] = arc.state_read %in_cond1 - // CHECK-NEXT: scf.if [[IN_COND1]] { - scf.if %6 { - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = [[IN_I1]] - arc.state_write %2 = %7 : - // CHECK-NEXT: } - } - // CHECK-NEXT: } - } - // CHECK-NEXT: } - } -} - -// CHECK-LABEL: arc.model @ResetAndEnableGrouping -arc.model @ResetAndEnableGrouping io !hw.modty { -^bb0(%arg0: !arc.storage): - %c0_i4 = hw.constant 0 : i4 - %in_clock = arc.root_input "clock", %arg0 : (!arc.storage) -> !arc.state - %in_i0 = arc.root_input "i0", %arg0 : (!arc.storage) -> !arc.state - %in_i1 = arc.root_input "i1", %arg0 : (!arc.storage) -> !arc.state - %in_reset = arc.root_input "reset", %arg0 : (!arc.storage) -> !arc.state - %in_en0 = arc.root_input "en0", %arg0 : (!arc.storage) -> !arc.state - %in_en1 = arc.root_input "en1", %arg0 : (!arc.storage) -> !arc.state - // CHECK: [[FOO_ALLOC:%.+]] = arc.alloc_state %arg0 {name = "foo"} : (!arc.storage) -> !arc.state - // CHECK: [[BAR_ALLOC:%.+]] = arc.alloc_state %arg0 {name = "bar"} : (!arc.storage) -> !arc.state - %1 = arc.alloc_state %arg0 {name = "foo"} : (!arc.storage) -> !arc.state - %2 = arc.alloc_state %arg0 {name = "bar"} : (!arc.storage) -> !arc.state - %0 = arc.state_read %in_clock : - // Group enables inside resets (and pull in reads): - arc.clock_tree %0 { - // CHECK: [[IN_RESET:%.+]] = arc.state_read %in_reset - %3 = arc.state_read %in_reset : - // CHECK-NEXT: scf.if [[IN_RESET]] { - scf.if %3 { - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = %c0_i4 - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = %c0_i4 - arc.state_write %1 = %c0_i4 : - arc.state_write %2 = %c0_i4 : - // CHECK-NEXT: } else { - } else { - // CHECK-NEXT: [[IN_EN:%.+]] = arc.state_read %in_en1 - %4 = arc.state_read %in_en1 : - // CHECK-NEXT: scf.if [[IN_EN]] { - // 
CHECK-NEXT: [[IN_I0:%.+]] = arc.state_read %in_i0 - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = [[IN_I0]] - // CHECK-NEXT: [[IN_I1:%.+]] = arc.state_read %in_i1 - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = [[IN_I1]] - %5 = arc.state_read %in_i0 : - arc.state_write %1 = %5 if %4 : - %6 = arc.state_read %in_i1 : - arc.state_write %2 = %6 if %4 : - // CHECK-NEXT: } - // CHECK-NEXT: } - } - // CHECK-NEXT: } - } - // Group both resets and enables (and pull in reads): - arc.clock_tree %0 { - // CHECK: [[IN_RESET:%.+]] = arc.state_read %in_reset - %7 = arc.state_read %in_reset : - %8 = arc.state_read %in_en0 : - // CHECK-NEXT: scf.if [[IN_RESET]] { - scf.if %7 { - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = %c0_i4 - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = %c0_i4 - arc.state_write %1 = %c0_i4 : - // CHECK-NEXT: } else { - } else { - // CHECK-NEXT: [[IN_EN0:%.+]] = arc.state_read %in_en0 - // State reads are pulled in - // CHECK-NEXT: scf.if [[IN_EN0]] { - // CHECK-NEXT: [[IN_I0:%.+]] = arc.state_read %in_i0 - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = [[IN_I0]] - // CHECK-NEXT: [[IN_I1:%.+]] = arc.state_read %in_i1 - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = [[IN_I1]] - // CHECK-NEXT: } - %9 = arc.state_read %in_i0 : - arc.state_write %1 = %9 if %8 : - // CHECK-NEXT: } - } - scf.if %7 { - arc.state_write %2 = %c0_i4 : - } else { - %10 = arc.state_read %in_i1 : - arc.state_write %2 = %10 if %8 : - } - // CHECK-NEXT: } - } - // Group resets that are separated by an enable read (and pull in reads): - arc.clock_tree %0 { - // CHECK: [[IN_RESET:%.+]] = arc.state_read %in_reset - %11 = arc.state_read %in_reset : - // CHECK-NEXT: scf.if [[IN_RESET]] { - scf.if %11 { - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = %c0_i4 - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = %c0_i4 - arc.state_write %1 = %c0_i4 : - // CHECK-NEXT: } else { - } else { - // CHECK-NEXT: [[IN_EN0:%.+]] = arc.state_read %in_en0 - // CHECK-NEXT: [[IN_I0:%.+]] = arc.state_read %in_i0 - %12 = arc.state_read %in_en0 : - // CHECK-NEXT: arc.state_write [[FOO_ALLOC]] = [[IN_I0]] if [[IN_EN0]] - // CHECK-NEXT: [[IN_I1:%.+]] = arc.state_read %in_i1 - // CHECK-NEXT: [[IN_EN1:%.+]] = arc.state_read %in_en1 - // CHECK-NEXT: arc.state_write [[BAR_ALLOC]] = [[IN_I1]] if [[IN_EN1]] - %13 = arc.state_read %in_i0 : - arc.state_write %1 = %13 if %12 : - // CHECK-NEXT: } - } - %14 = arc.state_read %in_en1 : - scf.if %11 { - arc.state_write %2 = %c0_i4 : - } else { - %15 = arc.state_read %in_i1 : - arc.state_write %2 = %15 if %14 : - } - // CHECK-NEXT: } - } -} diff --git a/test/Dialect/Arc/merge-ifs.mlir b/test/Dialect/Arc/merge-ifs.mlir new file mode 100644 index 000000000000..cc245f7554ea --- /dev/null +++ b/test/Dialect/Arc/merge-ifs.mlir @@ -0,0 +1,283 @@ +// RUN: circt-opt --arc-merge-ifs %s | FileCheck %s + +func.func private @Blocker() + +// CHECK-LABEL: func.func @DontMoveUnusedOps +func.func @DontMoveUnusedOps(%arg0: !arc.state) { + // CHECK-NEXT: arc.state_read %arg0 : + arc.state_read %arg0 : + // CHECK-NEXT: hw.constant false + hw.constant false + // CHECK-NEXT: hw.constant true + hw.constant true + return +} + +// CHECK-LABEL: func.func @SinkReads +func.func @SinkReads(%arg0: !arc.state, %arg1: !arc.memory<2 x i42, i1>, %arg2: i1) { + %0 = arc.state_read %arg0 : + %1 = arc.memory_read %arg1[%arg2] : <2 x i42, i1> + // CHECK-NEXT: hw.constant false + hw.constant false + // CHECK-NEXT: scf.if + scf.if %arg2 { + // CHECK-NEXT: hw.constant true + hw.constant true + // CHECK-NEXT: arc.state_read + // CHECK-NEXT: 
arc.memory_read + // CHECK-NEXT: comb.xor + comb.xor %0, %1 : i42 + } + return +} + +// CHECK-LABEL: func.func @MoveReads +func.func @MoveReads(%arg0: !arc.state, %arg1: !arc.memory<2 x i42, i1>, %arg2: i1) { + %0 = arc.state_read %arg0 : + %1 = arc.memory_read %arg1[%arg2] : <2 x i42, i1> + // CHECK-NEXT: hw.constant false + hw.constant false + // CHECK-NEXT: arc.state_read + // CHECK-NEXT: arc.memory_read + // CHECK-NEXT: scf.if + scf.if %arg2 { + comb.xor %0, %1 : i42 + } + comb.xor %0, %1 : i42 + return +} + +// CHECK-LABEL: func.func @SinkAndMoveReads +func.func @SinkAndMoveReads(%arg0: !arc.state, %arg1: !arc.memory<2 x i42, i1>, %arg2: i1, %arg3: i1) { + %0 = arc.state_read %arg0 {a} : + %1 = arc.state_read %arg0 {b} : + %2 = arc.state_read %arg0 {c} : + %3 = arc.memory_read %arg1[%arg2] {x} : <2 x i42, i1> + %4 = arc.memory_read %arg1[%arg2] {y} : <2 x i42, i1> + %5 = arc.memory_read %arg1[%arg2] {z} : <2 x i42, i1> + // CHECK-NEXT: scf.if + scf.if %arg2 { + // CHECK-NEXT: hw.constant false + hw.constant false + // CHECK-NEXT: arc.state_read %arg0 {a} + // CHECK-NEXT: arc.memory_read %arg1[%arg2] {x} + // CHECK-NEXT: comb.xor + comb.xor %0, %3 : i42 + } + // CHECK-NEXT: } + // CHECK-NEXT: arc.state_read %arg0 {b} + // CHECK-NEXT: arc.memory_read %arg1[%arg2] {y} + // CHECK-NEXT: scf.if + scf.if %arg3 { + // CHECK-NEXT: hw.constant false + hw.constant false + // CHECK-NEXT: comb.xor + comb.xor %1, %4 : i42 + } + // CHECK-NEXT: } + // CHECK-NEXT: arc.state_read %arg0 {c} + // CHECK-NEXT: arc.memory_read %arg1[%arg2] {z} + // CHECK-NEXT: comb.xor + comb.xor %1, %2, %4, %5 : i42 + return +} + +// CHECK-LABEL: func.func @WriteBlocksReadMove +func.func @WriteBlocksReadMove( + %arg0: !arc.state, + %arg1: !arc.state, + %arg2: !arc.memory<2 x i42, i1>, + %arg3: !arc.memory<2 x i42, i1>, + %arg4: i1, + %arg5: i42 +) { + %0 = arc.state_read %arg0 {blocked} : + %1 = arc.state_read %arg1 {free} : + %2 = arc.memory_read %arg2[%arg4] {blocked} : <2 x i42, i1> + %3 = arc.memory_read %arg3[%arg4] {free} : <2 x i42, i1> + // CHECK-NEXT: hw.constant false + hw.constant false + // CHECK-NEXT: arc.state_read %arg0 {blocked} + // CHECK-NEXT: arc.memory_read %arg2[%arg4] {blocked} + // CHECK-NEXT: scf.if + scf.if %arg4 { + // CHECK-NEXT: arc.state_write + // CHECK-NEXT: arc.memory_write + arc.state_write %arg0 = %arg5 : + arc.memory_write %arg2[%arg4], %arg5 : <2 x i42, i1> + } + // CHECK-NEXT: } + // CHECK-NEXT: arc.state_read %arg1 {free} + // CHECK-NEXT: arc.memory_read %arg3[%arg4] {free} + // CHECK-NEXT: comb.xor + comb.xor %0, %1, %2, %3 : i42 + return +} + +// CHECK-LABEL: func.func @MovedOpsRetainOrder +func.func @MovedOpsRetainOrder(%arg0: i42, %arg1: i1) { + // CHECK-NEXT: hw.constant false {ka} + // CHECK-NEXT: hw.constant false {kb} + // CHECK-NEXT: hw.constant false {kc} + // CHECK-NEXT: comb.xor {{%.+}} {a0} + // CHECK-NEXT: comb.xor {{%.+}} {a1} + // CHECK-NEXT: comb.xor {{%.+}} {b0} + // CHECK-NEXT: comb.xor {{%.+}} {b1} + // CHECK-NEXT: comb.xor {{%.+}} {c0} + // CHECK-NEXT: comb.xor {{%.+}} {c1} + %a0 = comb.xor %arg0 {a0} : i42 + %a1 = comb.xor %a0 {a1} : i42 + hw.constant false {ka} + %b0 = comb.xor %arg0 {b0} : i42 + %b1 = comb.xor %b0 {b1} : i42 + hw.constant false {kb} + %c0 = comb.xor %arg0 {c0} : i42 + %c1 = comb.xor %c0 {c1} : i42 + hw.constant false {kc} + // CHECK-NEXT: scf.if + scf.if %arg1 { + comb.xor %a1 {ia} : i42 + comb.xor %b1 {ib} : i42 + comb.xor %c1 {ic} : i42 + } + comb.xor %a1 {xa} : i42 + comb.xor %b1 {xb} : i42 + comb.xor %c1 {xc} : i42 + return +} + +// 
CHECK-LABEL: func.func @MergeAdjacentIfs +func.func @MergeAdjacentIfs(%arg0: i1, %arg1: i1) { + // CHECK-NEXT: scf.if %arg0 { + // CHECK-NEXT: hw.constant false {a} + // CHECK-NEXT: hw.constant false {b} + // CHECK-NEXT: } + scf.if %arg0 { + hw.constant false {a} + } + scf.if %arg0 { + hw.constant false {b} + } + // CHECK-NEXT: scf.if %arg1 { + // CHECK-NEXT: hw.constant false {c} + // CHECK-NEXT: } + scf.if %arg1 { + hw.constant false {c} + } + return +} + +// CHECK-LABEL: func.func @MergeIfsAcrossOps +func.func @MergeIfsAcrossOps( + %arg0: i1, + %arg1: !arc.state, + %arg2: !arc.memory<2 x i42, i1>, + %arg3: i42 +) { + // CHECK-NEXT: arc.state_read + // CHECK-NEXT: arc.state_write + // CHECK-NEXT: arc.memory_read + // CHECK-NEXT: arc.memory_write + // CHECK-NEXT: scf.if %arg0 { + // CHECK-NEXT: hw.constant false {a} + // CHECK-NEXT: hw.constant false {b} + // CHECK-NEXT: } + scf.if %arg0 { + hw.constant false {a} + } + arc.state_read %arg1 : + arc.state_write %arg1 = %arg3 : + arc.memory_read %arg2[%arg0] : <2 x i42, i1> + arc.memory_write %arg2[%arg0], %arg3 : <2 x i42, i1> + scf.if %arg0 { + hw.constant false {b} + } + return +} + +// CHECK-LABEL: func.func @DontMergeIfsAcrossSideEffects +func.func @DontMergeIfsAcrossSideEffects( + %arg0: i1, + %arg1: !arc.state, + %arg2: !arc.memory<2 x i42, i1>, + %arg3: i42 +) { + // CHECK-NEXT: scf.if %arg0 { + // CHECK-NEXT: hw.constant false {a} + // CHECK-NEXT: func.call @Blocker() {blockerA} + // CHECK-NEXT: hw.constant false {b} + // CHECK-NEXT: } + scf.if %arg0 { + hw.constant false {a} + func.call @Blocker() {blockerA} : () -> () + } + scf.if %arg0 { + hw.constant false {b} + } + // CHECK-NEXT: call @Blocker() {cantMoveAcrossA} + call @Blocker() {cantMoveAcrossA} : () -> () + // CHECK-NEXT: scf.if %arg0 { + // CHECK-NEXT: hw.constant false {c} + // CHECK-NEXT: arc.state_write %arg1 = %arg3 {blockerB} + // CHECK-NEXT: hw.constant false {d} + // CHECK-NEXT: } + scf.if %arg0 { + hw.constant false {c} + arc.state_write %arg1 = %arg3 {blockerB} : + } + scf.if %arg0 { + hw.constant false {d} + } + // CHECK-NEXT: arc.state_read %arg1 {cantMoveAcrossB} + arc.state_read %arg1 {cantMoveAcrossB} : + // CHECK-NEXT: scf.if %arg0 { + // CHECK-NEXT: hw.constant false {e} + // CHECK-NEXT: arc.memory_write %arg2[%arg0], %arg3 {blockerC} + // CHECK-NEXT: hw.constant false {f} + // CHECK-NEXT: } + scf.if %arg0 { + hw.constant false {e} + arc.memory_write %arg2[%arg0], %arg3 {blockerC} : <2 x i42, i1> + } + scf.if %arg0 { + hw.constant false {f} + } + // CHECK-NEXT: arc.memory_read %arg2[%arg0] {cantMoveAcrossC} + arc.memory_read %arg2[%arg0] {cantMoveAcrossC} : <2 x i42, i1> + // CHECK-NEXT: scf.if %arg0 { + // CHECK-NEXT: hw.constant false {g} + // CHECK-NEXT: } + scf.if %arg0 { + hw.constant false {g} + } + return +} + +// CHECK-LABEL: func.func @MergeNestedIfs +func.func @MergeNestedIfs(%arg0: i42, %arg1: i1, %arg2: i1) { + // CHECK-NEXT: scf.if %arg1 { + // CHECK-NEXT: hw.constant false {a} + // CHECK-NEXT: hw.constant false {b} + // CHECK-NEXT: hw.constant false {c} + // CHECK-NEXT: scf.if %arg2 { + // CHECK-NEXT: hw.constant false {x} + // CHECK-NEXT: hw.constant false {y} + // CHECK-NEXT: } + // CHECK-NEXT: hw.constant false {d} + // CHECK-NEXT: } + scf.if %arg1 { + hw.constant false {a} + scf.if %arg2 { + hw.constant false {x} + } + hw.constant false {b} + } + scf.if %arg1 { + hw.constant false {c} + scf.if %arg2 { + hw.constant false {y} + } + hw.constant false {d} + } + return +} diff --git a/tools/arcilator/arcilator.cpp 
b/tools/arcilator/arcilator.cpp index 71f025135162..dc5aa8403106 100644 --- a/tools/arcilator/arcilator.cpp +++ b/tools/arcilator/arcilator.cpp @@ -337,7 +337,7 @@ static void populateHwModuleToArcPipeline(PassManager &pm) { pm.addPass(createCSEPass()); } - pm.addPass(arc::createGroupResetsAndEnablesPass()); + pm.addPass(arc::createMergeIfsPass()); pm.addPass(arc::createLegalizeStateUpdatePass()); pm.addPass(createCSEPass()); pm.addPass(arc::createArcCanonicalizerPass());
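
Illustrative example (not part of the patch): a minimal sketch of the if-merging
performed by `circt-opt --arc-merge-ifs`, mirroring the `MergeAdjacentIfs` test
above. Input, with two `scf.if` ops guarded by the same condition:

  func.func @Example(%arg0: i1) {
    scf.if %arg0 {
      hw.constant false {a}
    }
    scf.if %arg0 {
      hw.constant false {b}
    }
    return
  }

Output: since both ifs share the same condition and have no results, the body of
the first if is spliced into the second, leaving a single `scf.if`:

  func.func @Example(%arg0: i1) {
    scf.if %arg0 {
      hw.constant false {a}
      hw.constant false {b}
    }
    return
  }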