From a63a3b1bfdb83d93ce34600dd7b6af53e2cf1995 Mon Sep 17 00:00:00 2001 From: Robert Young Date: Wed, 21 Aug 2024 16:12:58 -0400 Subject: [PATCH] Advanced LayerSink --- include/circt/Dialect/FIRRTL/Passes.td | 2 +- lib/Dialect/FIRRTL/Transforms/LayerSink.cpp | 412 +++++++++++++++++++- lib/Firtool/Firtool.cpp | 2 +- test/Dialect/FIRRTL/layer-sink.mlir | 11 +- 4 files changed, 400 insertions(+), 27 deletions(-) diff --git a/include/circt/Dialect/FIRRTL/Passes.td b/include/circt/Dialect/FIRRTL/Passes.td index ec7043fb6a97..874c94b00673 100644 --- a/include/circt/Dialect/FIRRTL/Passes.td +++ b/include/circt/Dialect/FIRRTL/Passes.td @@ -846,7 +846,7 @@ def LayerMerge : Pass<"firrtl-layer-merge", "firrtl::FModuleOp"> { ]; } -def LayerSink : Pass<"firrtl-layer-sink", "firrtl::FModuleOp"> { +def LayerSink : Pass<"firrtl-layer-sink", "firrtl::CircuitOp"> { let summary = "Sink operations into layer blocks"; let description = [{ diff --git a/lib/Dialect/FIRRTL/Transforms/LayerSink.cpp b/lib/Dialect/FIRRTL/Transforms/LayerSink.cpp index ff99c0ca1f19..e838628ce78e 100644 --- a/lib/Dialect/FIRRTL/Transforms/LayerSink.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LayerSink.cpp @@ -10,15 +10,18 @@ // //===----------------------------------------------------------------------===// +#include "circt/Dialect/FIRRTL/FIRRTLInstanceGraph.h" #include "circt/Dialect/FIRRTL/FIRRTLOps.h" #include "circt/Dialect/FIRRTL/Passes.h" -#include "mlir/Pass/Pass.h" - #include "circt/Support/Debug.h" #include "mlir/IR/Dominance.h" +#include "mlir/IR/Iterators.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Pass/Pass.h" #include "mlir/Transforms/ControlFlowSinkUtils.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/PostOrderIterator.h" #define DEBUG_TYPE "firrtl-layer-sink" @@ -31,33 +34,404 @@ namespace firrtl { using namespace circt; using namespace firrtl; +using namespace mlir; + 
+//===----------------------------------------------------------------------===//
+// Helpers
+//===----------------------------------------------------------------------===//
+
+// NOLINTBEGIN(misc-no-recursion)
+namespace {
+/// Walk the ops in `block` bottom-up, back-to-front order.
+template <typename T>
+void walkBwd(Block *block, T &&f) {
+  for (auto &op :
+       llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
+    for (auto &region : op.getRegions()) {
+      for (auto &block : region.getBlocks()) {
+        walkBwd(&block, f);
+      }
+    }
+    f(&op);
+  }
+}
+} // namespace
+// NOLINTEND(misc-no-recursion)
+
+static bool isAncestor(Block *block, Block *other) {
+  while (other) {
+    if (block == other)
+      return true;
+    other = other->getParentOp()->getBlock();
+  }
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Effect Information.
+//===----------------------------------------------------------------------===//
+
+/// Return true if this op must be kept, i.e. we are NOT allowed to delete it.
+static bool mustKeep(Operation *op) {
+  if (!AnnotationSet(op).canBeDeleted() || hasDontTouch(op))
+    return true;
+  if (auto name = dyn_cast<FNamableOp>(op))
+    return !name.hasDroppableName();
+  return false;
+}
+
+/// A table that can determine whether an operation is effectful.
+namespace {
+struct EffectInfo {
+  /// True if the given operation is NOT moveable due to some effect.
+  bool effectful(Operation *op) const {
+    if (isa(op))
+      return false;
+    if (mustKeep(op))
+      return true;
+    if (op->getNumRegions() != 0)
+      return true;
+    if (auto instance = dyn_cast<InstanceOp>(op))
+      return effectfulModules.contains(instance.getModuleNameAttr().getAttr());
+    if (isa(op))
+      return false;
+    return !(mlir::isMemoryEffectFree(op) ||
+             mlir::hasSingleEffect(op) ||
+             mlir::hasSingleEffect(op));
+  }
+
+  /// Record whether the module contains any effectful ops.
+  void update(FModuleOp module) {
+    module.getBodyBlock()->walk([&](Operation *op) {
+      if (effectful(op)) {
+        markEffectful(module);
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+  }
+
+  void update(FModuleLike module) {
+    auto *op = module.getOperation();
+    // Regular modules may be pure.
+    if (auto m = dyn_cast<FModuleOp>(op))
+      return update(m);
+    // Memory modules are pure.
+    if (isa<FMemModuleOp>(op))
+      return;
+    // All other kinds of modules are effectful.
+    // intmodules, extmodules, classes.
+    return markEffectful(module);
+  }
+
+  void update(Operation *op) {
+    if (auto module = dyn_cast<FModuleLike>(op))
+      update(module);
+  }
+
+  /// Record that the given module contains an effectful operation.
+  void markEffectful(FModuleLike module) {
+    effectfulModules.insert(module.getModuleNameAttr());
+  }
+
+  DenseSet<StringAttr> effectfulModules;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Demands Analysis.
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// The LCA of the blocks in which a value is used/demanded. Lattice value.
+struct Demand {
+  constexpr Demand() : Demand(nullptr) {}
+  constexpr Demand(Block *block) : block(block) {}
+
+  constexpr Demand merge(Demand other) const {
+    if (block == other.block)
+      return Demand(block);
+    if (other.block == nullptr)
+      return Demand(block);
+    if (block == nullptr)
+      return Demand(other.block);
+
+    auto *b = block;
+    while (b && !isAncestor(b, other.block)) {
+      b = b->getParentOp()->getBlock(); // Climb from b, not from block.
+    }
+    return Demand(b);
+  }
+
+  bool mergeIn(Demand other) {
+    auto prev = *this;
+    auto next = merge(other);
+    *this = next;
+    return prev != next;
+  }
+
+  constexpr bool operator==(Demand rhs) const { return block == rhs.block; }
+
+  constexpr bool operator!=(Demand rhs) const { return block != rhs.block; }
+
+  constexpr operator bool() const { return block; }
+
+  // nullptr means "not demanded."
+  Block *block;
+};
+} // namespace
+
+/// True if this operation is a good site to sink operations.
+static bool isValidDest(Operation *op) {
+  return op && (isa(op) || isa(op));
+}
+
+/// True if we are prevented from sinking operations into the regions of the op.
+static bool isBarrier(Operation *op) { return !isValidDest(op); }
+
+/// Adjust the demand based on the location of the op being demanded. Ideally,
+/// we can sink an op directly to its site of use. However, there are two
+/// issues.
+///
+/// 1) Typically, an op will dominate every demand, but for hardware
+/// declarations such as wires, the declaration will demand any connections
+/// driving it. In this case, the relationship is reversed: the demander
+/// dominates the demandee. This can cause us to pull connect-like ops up and
+/// out of their enclosing block. To avoid this, we set an upper bound on
+/// the demand: the enclosing block of the demandee.
+///
+/// 2) Not all regions are valid sink targets. If there is a sink-barrier
+/// between the operation and its demand, we adjust the demand upwards so that
+/// there is no sink barrier between the demandee and the demand site.
+static Demand clamp(Operation *op, Demand demand) {
+  if (!demand)
+    return nullptr;
+
+  auto *upper = op->getBlock();
+  if (!upper)
+    return nullptr;
+
+  if (!isAncestor(upper, demand.block))
+    demand = upper;
+
+  for (auto *i = demand.block; i != upper; i = i->getParentOp()->getBlock())
+    if (isBarrier(i->getParentOp()))
+      demand = i->getParentOp()->getBlock();
+
+  return demand;
+}
+
+namespace {
+struct DemandInfo {
+  using WorkStack = std::vector<Operation *>;
+
+  DemandInfo(const EffectInfo &, FModuleOp);
+
+  Demand getDemandFor(Operation *op) const { return table.lookup(op); }
+
+  /// Phase 1: Starting at effectful ops and output ports, propagate the demand
+  /// of values through the design, running until fixpoint. At the end, we have
+  /// an accurate picture of where every operation can be sunk, while preserving
+  /// effects in the program.
+  void phase1(const EffectInfo &, FModuleOp, WorkStack &);
+
+  /// Phase 2: in order to avoid deleting undemanded operations, pretend they
+  /// are demanded "wherever they are", and run again until fixpoint.
+  void phase2(const EffectInfo &, FModuleOp, WorkStack &);
+
+  /// Run to fixpoint.
+  void run(const EffectInfo &, FModuleOp, WorkStack &);
+
+  /// Update the demand for the given op. If the demand changes, place the op
+  /// onto the worklist.
+  void update(WorkStack &work, Operation *op, Demand demand) {
+    auto d = clamp(op, demand);
+    auto &entry = table[op];
+    if (entry.mergeIn(d))
+      work.push_back(op);
+  }
+
+  void update(WorkStack &work, Value value, Demand demand) {
+    if (auto result = dyn_cast<OpResult>(value))
+      update(work, result.getOwner(), demand);
+  }
+
+  void updateConnects(WorkStack &, Value, Demand);
+  void updateConnects(WorkStack &, Operation *, Demand);
+
+  llvm::DenseMap<Operation *, Demand> table;
+};
+} // namespace
+
+DemandInfo::DemandInfo(const EffectInfo &effectInfo, FModuleOp module) {
+  WorkStack work;
+  phase1(effectInfo, module, work);
+  phase2(effectInfo, module, work);
+}
+
+void DemandInfo::run(const EffectInfo &effectInfo, FModuleOp, WorkStack &work) {
+  while (!work.empty()) {
+    auto *op = work.back();
+    work.pop_back();
+    auto demand = getDemandFor(op);
+    for (auto rand : op->getOperands())
+      update(work, rand, demand);
+    updateConnects(work, op, demand);
+  }
+}
+
+void DemandInfo::phase1(const EffectInfo &effectInfo, FModuleOp module,
+                        WorkStack &work) {
+  Block *body = module.getBodyBlock();
+  auto dirs = module.getPortDirections();
+  for (unsigned i = 0, e = module.getNumPorts(); i < e; ++i) {
+    if (direction::get(dirs[i]) == Direction::Out)
+      updateConnects(work, body->getArgument(i), body);
+  }
+  body->walk([&](Operation *op) {
+    if (effectInfo.effectful(op)) {
+      update(work, op, op->getBlock());
+      return;
+    }
+  });
+  run(effectInfo, module, work);
+}
+
+void DemandInfo::phase2(const EffectInfo &effectInfo, FModuleOp module,
+                        WorkStack &work) {
+  module.getBodyBlock()->walk([&](Operation *op) -> void {
+    auto demand = getDemandFor(op);
+    if (!demand)
+      update(work, op, op->getBlock());
+  });
+  run(effectInfo, module, work);
+}
+
+// The value represents a hardware declaration, such as a wire or port. Search
+// backwards through uses, looking through aliasing operations such as
+// subfields, to find connects that drive the given value. All driving
+// connects of a value are demanded by the same region as the value. If the
+// demand of the connect op is updated, then the demand will propagate
+// forwards to its operands through the normal forward-propagation of demand.
+void DemandInfo::updateConnects(WorkStack &work, Value value, Demand demand) {
+  struct StackElement {
+    Value value;
+    Value::user_iterator it;
+  };
+
+  SmallVector<StackElement> stack;
+  stack.push_back({value, value.user_begin()});
+  while (!stack.empty()) {
+    auto &top = stack.back();
+    auto end = top.value.user_end();
+    while (true) {
+      if (top.it == end) {
+        stack.pop_back();
+        break;
+      }
+      auto *user = *(top.it++);
+      if (auto connect = dyn_cast<FConnectLike>(user)) {
+        if (connect.getDest() == top.value) {
+          update(work, connect, demand);
+        }
+        continue;
+      }
+      if (isa(user)) {
+        for (auto result : user->getResults())
+          stack.push_back({result, result.user_begin()});
+        break;
+      }
+    }
+  }
+}
+
+void DemandInfo::updateConnects(WorkStack &work, Operation *op, Demand demand) {
+  if (isa(op)) {
+    for (auto result : op->getResults())
+      updateConnects(work, result, demand);
+  } else if (auto inst = dyn_cast<InstanceOp>(op)) {
+    auto dirs = inst.getPortDirections();
+    for (unsigned i = 0, e = inst->getNumResults(); i < e; ++i) {
+      if (direction::get(dirs[i]) == Direction::In)
+        updateConnects(work, inst.getResult(i), demand);
+    }
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Sink Layers.
+//===----------------------------------------------------------------------===//
+
+static bool run(const EffectInfo &effectInfo, FModuleOp module) {
+  DemandInfo demandInfo(effectInfo, module);
+  DenseSet<Operation *> seen;
+  bool changed = false;
+  walkBwd(module.getBodyBlock(), [&](Operation *op) {
+    seen.insert(op);
+    auto demand = demandInfo.getDemandFor(op);
+    if (!demand.block) {
+      op->erase();
+      changed = true;
+      return;
+    }
+
+    if (demand.block == op->getBlock())
+      return;
+
+    auto *destination = demand.block->getParentOp();
+    if (!seen.contains(destination)) {
+      destination->moveBefore(op);
+      seen.insert(destination);
+    }
+
+    op->moveBefore(demand.block, demand.block->begin());
+    changed = true;
+  });
+
+  return changed;
+}
+
+static bool run(InstanceGraph &instanceGraph, CircuitOp circuit) {
+  bool changed = false;
+  DenseSet<InstanceGraphNode *> visited;
+  EffectInfo effectInfo;
+  for (auto *root : instanceGraph) {
+    for (auto *node : llvm::post_order_ext(root, visited)) {
+      auto *op = node->getModule().getOperation();
+      effectInfo.update(op);
+      if (auto module = dyn_cast<FModuleOp>(op))
+        changed |= run(effectInfo, module);
+    }
+  }
+  return changed;
+}
+
+//===----------------------------------------------------------------------===//
+// LayerSinkPass
+//===----------------------------------------------------------------------===//
 
 namespace {
 /// A control-flow sink pass.
-struct LayerSink : public circt::firrtl::impl::LayerSinkBase<LayerSink> {
+struct LayerSinkPass final
+    : public circt::firrtl::impl::LayerSinkBase<LayerSinkPass> {
   void runOnOperation() override;
 };
-} // end anonymous namespace
+} // namespace
 
-void LayerSink::runOnOperation() {
+void LayerSinkPass::runOnOperation() {
   LLVM_DEBUG(debugPassHeader(this)
                  << "\n"
                  << "Module: '" << getOperation().getName() << "'\n";);
-  auto &domInfo = getAnalysis<mlir::DominanceInfo>();
-  getOperation()->walk([&](LayerBlockOp layerBlock) {
-    SmallVector<Region *> regionsToSink({&layerBlock.getRegion()});
-    numSunk = controlFlowSink(
-        regionsToSink, domInfo,
-        [](Operation *op, Region *) { return !hasDontTouch(op); },
-        [](Operation *op, Region *region) {
-          // Move the operation to the beginning of the region's entry block.
-          // This guarantees the preservation of SSA dominance of all of the
-          // operation's uses are in the region.
-          op->moveBefore(&region->front(), region->front().begin());
-        });
-  });
+  auto &instanceGraph = getAnalysis<InstanceGraph>();
+  auto changed = run(instanceGraph, getOperation());
+  if (!changed)
+    markAllAnalysesPreserved();
 }
 
+//===----------------------------------------------------------------------===//
+// Pass Constructor
+//===----------------------------------------------------------------------===//
+
 std::unique_ptr<mlir::Pass> circt::firrtl::createLayerSinkPass() {
-  return std::make_unique<LayerSink>();
+  return std::make_unique<LayerSinkPass>();
 }
diff --git a/lib/Firtool/Firtool.cpp b/lib/Firtool/Firtool.cpp
index fe93765ce7e3..c1724793dd62 100644
--- a/lib/Firtool/Firtool.cpp
+++ b/lib/Firtool/Firtool.cpp
@@ -137,9 +137,9 @@ LogicalResult firtool::populateCHIRRTLToLowFIRRTL(mlir::PassManager &pm,
   {
     auto &modulePM = pm.nest<firrtl::CircuitOp>().nest<firrtl::FModuleOp>();
     modulePM.addPass(firrtl::createLayerMergePass());
-    modulePM.addPass(firrtl::createLayerSinkPass());
   }
 
+  pm.nest<firrtl::CircuitOp>().addPass(firrtl::createLayerSinkPass());
   pm.nest<firrtl::CircuitOp>().addPass(firrtl::createLowerLayersPass());
 
   pm.nest<firrtl::CircuitOp>().addPass(firrtl::createInlinerPass());
diff --git a/test/Dialect/FIRRTL/layer-sink.mlir b/test/Dialect/FIRRTL/layer-sink.mlir
index fb60935c733a..db698358592c 100644
--- a/test/Dialect/FIRRTL/layer-sink.mlir
+++ b/test/Dialect/FIRRTL/layer-sink.mlir
@@ -1,4 +1,4 @@
-// RUN: circt-opt -pass-pipeline="builtin.module(firrtl.circuit(firrtl.module(firrtl-layer-sink)))" %s | FileCheck %s
+// RUN: circt-opt -pass-pipeline="builtin.module(firrtl.circuit(firrtl-layer-sink))" -allow-unregistered-dialect %s | FileCheck %s
 
 // Test that simple things are sunk:
 // - nodes
@@ -7,8 +7,7 @@
 //
 // CHECK-LABEL: firrtl.circuit "SimpleSink"
 firrtl.circuit "SimpleSink" {
-  firrtl.layer @A bind {
-  }
+  firrtl.layer @A bind {}
   // CHECK: firrtl.module @SimpleSink
   firrtl.module @SimpleSink(in %a: !firrtl.uint<1>) {
     %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1>
@@ -17,11 +16,11 @@ firrtl.circuit "SimpleSink" {
     // CHECK-NEXT: firrtl.layerblock @A
     firrtl.layerblock @A {
       // CHECK: %c0_ui1 = firrtl.constant
-      %constant_layer = firrtl.node %c0_ui1 : !firrtl.uint<1>
       // CHECK: %node = firrtl.node
-      %node_layer = firrtl.node %node : !firrtl.uint<1>
       // CHECK: %0 = firrtl.not
-      %primop_layer = firrtl.node %0 : !firrtl.uint<1>
+      "unknown"(%c0_ui1) : (!firrtl.uint<1>) -> !firrtl.uint<1>
+      "unknown"(%node) : (!firrtl.uint<1>) -> !firrtl.uint<1>
+      "unknown"(%0) : (!firrtl.uint<1>) -> !firrtl.uint<1>
     }
   }
 }