From 852a624c14ef3a6a8d4461fe7d36479d93d3f008 Mon Sep 17 00:00:00 2001 From: Amelia Dobis Date: Fri, 26 Jul 2024 17:38:17 -0700 Subject: [PATCH] Converted all FModuleOp passes to FModuleLike passes --- .../circt/Dialect/FIRRTL/FIRRTLIntrinsics.h | 4 +- lib/Dialect/FIRRTL/FIRRTLIntrinsics.cpp | 19 +- .../FIRRTL/Transforms/FlattenMemory.cpp | 24 +- .../FIRRTL/Transforms/LowerCHIRRTL.cpp | 30 +- .../FIRRTL/Transforms/LowerIntrinsics.cpp | 27 +- .../Transforms/MaterializeDebugInfo.cpp | 41 +- .../FIRRTL/Transforms/MergeConnections.cpp | 400 +++++++++--------- .../Transforms/PrintFIRRTLFieldSource.cpp | 18 +- .../Transforms/RandomizeRegisterInit.cpp | 12 +- .../FIRRTL/Transforms/RegisterOptimizer.cpp | 41 +- lib/Firtool/Firtool.cpp | 39 +- 11 files changed, 380 insertions(+), 275 deletions(-) diff --git a/include/circt/Dialect/FIRRTL/FIRRTLIntrinsics.h b/include/circt/Dialect/FIRRTL/FIRRTLIntrinsics.h index 5cf38075e3e3..1727cdd01aba 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLIntrinsics.h +++ b/include/circt/Dialect/FIRRTL/FIRRTLIntrinsics.h @@ -238,7 +238,9 @@ class IntrinsicLowerings { } /// Lowers all intrinsics in a module. Returns number converted or failure. - FailureOr lower(FModuleOp mod, bool allowUnknownIntrinsics = false); + template + FailureOr lower(ModuleLikeOp mod, + bool allowUnknownIntrinsics = false); private: template diff --git a/lib/Dialect/FIRRTL/FIRRTLIntrinsics.cpp b/lib/Dialect/FIRRTL/FIRRTLIntrinsics.cpp index 29385e08ad97..b5643402c116 100644 --- a/lib/Dialect/FIRRTL/FIRRTLIntrinsics.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLIntrinsics.cpp @@ -144,7 +144,24 @@ class IntrinsicOpConversion final // IntrinsicLowerings //===----------------------------------------------------------------------===// -FailureOr IntrinsicLowerings::lower(FModuleOp mod, +// Explicit instanciation of things to avoid implementing template in header +template FailureOr +IntrinsicLowerings::lower(FModuleOp mod, bool allowUnknownIntrinsics); +template FailureOr +IntrinsicLowerings::lower(FExtModuleOp mod, bool allowUnknownIntrinsics); +template FailureOr +IntrinsicLowerings::lower(FIntModuleOp mod, bool allowUnknownIntrinsics); +template FailureOr +IntrinsicLowerings::lower(FMemModuleOp mod, bool allowUnknownIntrinsics); +template FailureOr +IntrinsicLowerings::lower(ClassOp mod, bool allowUnknownIntrinsics); +template FailureOr +IntrinsicLowerings::lower(ExtClassOp mod, bool allowUnknownIntrinsics); +template FailureOr +IntrinsicLowerings::lower(FormalOp mod, bool allowUnknownIntrinsics); + +template +FailureOr IntrinsicLowerings::lower(ModuleLikeOp mod, bool allowUnknownIntrinsics) { ConversionTarget target(*context); diff --git a/lib/Dialect/FIRRTL/Transforms/FlattenMemory.cpp b/lib/Dialect/FIRRTL/Transforms/FlattenMemory.cpp index a2fbcae31be0..76e49254dfbe 100644 --- a/lib/Dialect/FIRRTL/Transforms/FlattenMemory.cpp +++ b/lib/Dialect/FIRRTL/Transforms/FlattenMemory.cpp @@ -36,11 +36,9 @@ using namespace firrtl; namespace { struct FlattenMemoryPass : public circt::firrtl::impl::FlattenMemoryBase { - /// This pass flattens the aggregate data of memory into a UInt, and inserts - /// appropriate bitcasts to access the data. - void runOnOperation() override { - LLVM_DEBUG(llvm::dbgs() << "\n Running lower memory on module:" - << getOperation().getName()); + + template + void runOnOp(Op op) { SmallVector opsToErase; auto hasSubAnno = [&](MemOp op) -> bool { for (size_t portIdx = 0, e = op.getNumResults(); portIdx < e; ++portIdx) @@ -50,7 +48,8 @@ struct FlattenMemoryPass return false; }; - getOperation().getBodyBlock()->walk([&](MemOp memOp) { + Block *body = op.getBodyBlock(); + body->walk([&](MemOp memOp) { LLVM_DEBUG(llvm::dbgs() << "\n Memory:" << memOp); // The vector of leaf elements type after flattening the data. SmallVector flatMemType; @@ -196,6 +195,19 @@ struct FlattenMemoryPass return; }); } + /// This pass flattens the aggregate data of memory into a UInt, and inserts + /// appropriate bitcasts to access the data. + void runOnOperation() override { + LLVM_DEBUG(llvm::dbgs() << "\n Running lower memory on module:" + << getOperation().getName()); + + TypeSwitch(&(*getOperation())) + .Case([&](auto op) { runOnOp(op); }) + // All other ops are ignored -- particularly ops that don't implement + // the `getBodyBlock()` method. We don't want an error here because the + // pass wasn't designed to run on those ops. + .Default([&](auto) {}); + } private: // Convert an aggregate type into a flat list of fields. diff --git a/lib/Dialect/FIRRTL/Transforms/LowerCHIRRTL.cpp b/lib/Dialect/FIRRTL/Transforms/LowerCHIRRTL.cpp index dd2245bdf1cc..4a025b46916c 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerCHIRRTL.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerCHIRRTL.cpp @@ -66,10 +66,10 @@ struct LowerCHIRRTLPass Value getConst(unsigned c) { auto &value = constCache[c]; if (!value) { - auto module = getOperation(); - auto builder = OpBuilder::atBlockBegin(module.getBodyBlock()); + auto builder = OpBuilder::atBlockBegin(body); auto u1Type = UIntType::get(builder.getContext(), /*width*/ 1); - value = builder.create(module.getLoc(), u1Type, APInt(1, c)); + value = builder.create(getOperation().getLoc(), u1Type, + APInt(1, c)); } return value; } @@ -96,6 +96,9 @@ struct LowerCHIRRTLPass void runOnOperation() override; + template + void runOnOp(Op op); + /// Cached constants. DenseMap constCache; DenseMap invalidCache; @@ -124,6 +127,9 @@ struct LowerCHIRRTLPass Value mode; }; DenseMap wdataValues; + + // Internally used data about the operations (avoids template issues) + Block *body; }; } // end anonymous namespace @@ -159,7 +165,7 @@ void LowerCHIRRTLPass::emitInvalid(ImplicitLocOpBuilder &builder, Value value) { auto type = value.getType(); auto &invalid = invalidCache[type]; if (!invalid) { - auto builder = OpBuilder::atBlockBegin(getOperation().getBodyBlock()); + auto builder = OpBuilder::atBlockBegin(body); invalid = builder.create(getOperation().getLoc(), type); } emitConnect(builder, value, invalid); @@ -650,12 +656,13 @@ void LowerCHIRRTLPass::visitUnhandledOp(Operation *op) { } } -void LowerCHIRRTLPass::runOnOperation() { +template +void LowerCHIRRTLPass::runOnOp(Op op) { // Walk the entire body of the module and dispatch the visitor on each // function. This will replace all CHIRRTL memories and ports, and update all // uses. - getOperation().getBodyBlock()->walk( - [&](Operation *op) { dispatchCHIRRTLVisitor(op); }); + body = op.getBodyBlock(); + body->walk([&](Operation *op) { dispatchCHIRRTLVisitor(op); }); // If there are no operations to delete, then we didn't find any CHIRRTL // memories. @@ -670,6 +677,15 @@ void LowerCHIRRTLPass::runOnOperation() { clear(); } +void LowerCHIRRTLPass::runOnOperation() { + TypeSwitch(&(*getOperation())) + .Case([&](auto op) { runOnOp(op); }) + // All other ops are ignored -- particularly ops that don't implement + // the `getBodyBlock()` method. We don't want an error here because the + // pass wasn't designed to run on those ops. + .Default([&](auto) {}); +} + std::unique_ptr circt::firrtl::createLowerCHIRRTLPass() { return std::make_unique(); } diff --git a/lib/Dialect/FIRRTL/Transforms/LowerIntrinsics.cpp b/lib/Dialect/FIRRTL/Transforms/LowerIntrinsics.cpp index bbad0df1a770..8e40cd90638e 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerIntrinsics.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerIntrinsics.cpp @@ -33,6 +33,19 @@ using namespace firrtl; namespace { struct LowerIntrinsicsPass : public circt::firrtl::impl::LowerIntrinsicsBase { + + template + void runOnOp(Op op) { + auto result = lowering->lower(op); + if (failed(result)) + return signalPassFailure(); + + numConverted += *result; + + if (*result == 0) + markAllAnalysesPreserved(); + } + LogicalResult initialize(MLIRContext *context) override; void runOnOperation() override; @@ -53,14 +66,14 @@ LogicalResult LowerIntrinsicsPass::initialize(MLIRContext *context) { // This is the main entrypoint for the lowering pass. void LowerIntrinsicsPass::runOnOperation() { - auto result = lowering->lower(getOperation()); - if (failed(result)) - return signalPassFailure(); - - numConverted += *result; - if (*result == 0) - markAllAnalysesPreserved(); + TypeSwitch(&(*getOperation())) + .Case([&](auto op) { runOnOp(op); }) + // All other ops are ignored -- particularly ops that don't implement + // the `getBodyBlock()` method. We don't want an error here because the + // pass wasn't designed to run on those ops. + .Default([&](auto) {}); } /// This is the pass constructor. diff --git a/lib/Dialect/FIRRTL/Transforms/MaterializeDebugInfo.cpp b/lib/Dialect/FIRRTL/Transforms/MaterializeDebugInfo.cpp index 6259376eb07d..b754453fd05f 100644 --- a/lib/Dialect/FIRRTL/Transforms/MaterializeDebugInfo.cpp +++ b/lib/Dialect/FIRRTL/Transforms/MaterializeDebugInfo.cpp @@ -32,6 +32,25 @@ namespace { struct MaterializeDebugInfoPass : public circt::firrtl::impl::MaterializeDebugInfoBase< MaterializeDebugInfoPass> { + template + void runOnOp(Op module) { + auto builder = OpBuilder::atBlockBegin(module.getBodyBlock()); + + // Create DI variables for each port. + for (const auto &[port, value] : + llvm::zip(module.getPorts(), module.getArguments())) { + materializeVariable(builder, port.name, value); + } + + // Create DI variables for each declaration in the module body. + module.walk([&](Operation *op) { + TypeSwitch(op).Case( + [&](auto op) { + builder.setInsertionPointAfter(op); + materializeVariable(builder, op.getNameAttr(), op.getResult()); + }); + }); + } void runOnOperation() override; void materializeVariable(OpBuilder &builder, StringAttr name, Value value); Value convertToDebugAggregates(OpBuilder &builder, Value value); @@ -39,23 +58,13 @@ struct MaterializeDebugInfoPass } // namespace void MaterializeDebugInfoPass::runOnOperation() { - auto module = getOperation(); - auto builder = OpBuilder::atBlockBegin(module.getBodyBlock()); - - // Create DI variables for each port. - for (const auto &[port, value] : - llvm::zip(module.getPorts(), module.getArguments())) { - materializeVariable(builder, port.name, value); - } - // Create DI variables for each declaration in the module body. - module.walk([&](Operation *op) { - TypeSwitch(op).Case( - [&](auto op) { - builder.setInsertionPointAfter(op); - materializeVariable(builder, op.getNameAttr(), op.getResult()); - }); - }); + TypeSwitch(&(*getOperation())) + .Case([&](auto op) { runOnOp(op); }) + // All other ops are ignored -- particularly ops that don't implement + // the `getBodyBlock()` method. We don't want an error here because the + // pass wasn't designed to run on those ops. + .Default([&](auto) {}); } /// Materialize debug variable ops for a value. diff --git a/lib/Dialect/FIRRTL/Transforms/MergeConnections.cpp b/lib/Dialect/FIRRTL/Transforms/MergeConnections.cpp index b9ba73a04625..7ffaa757cdf6 100644 --- a/lib/Dialect/FIRRTL/Transforms/MergeConnections.cpp +++ b/lib/Dialect/FIRRTL/Transforms/MergeConnections.cpp @@ -61,249 +61,257 @@ namespace { //===----------------------------------------------------------------------===// // A helper struct to merge connections. +template struct MergeConnection { - MergeConnection(FModuleOp moduleOp, bool enableAggressiveMerging) + MergeConnection(ModuleLikeOp moduleOp, bool enableAggressiveMerging) : moduleOp(moduleOp), enableAggressiveMerging(enableAggressiveMerging) {} - // Return true if something is changed. - bool run(); - bool changed = false; - - // Return true if the given connect op is merged. - bool peelConnect(MatchingConnectOp connect); - // A map from a destination FieldRef to a pair of (i) the number of // connections seen so far and (ii) the vector to store subconnections. DenseMap>> connections; - FModuleOp moduleOp; + ModuleLikeOp moduleOp; ImplicitLocOpBuilder *builder = nullptr; // If true, we merge connections even when source values will not be // simplified. bool enableAggressiveMerging = false; -}; -bool MergeConnection::peelConnect(MatchingConnectOp connect) { - // Ignore connections between different types because it will produce a - // partial connect. Also ignore non-passive connections or non-integer - // connections. - LLVM_DEBUG(llvm::dbgs() << "Visiting " << connect << "\n"); - auto destTy = type_dyn_cast(connect.getDest().getType()); - if (!destTy || !destTy.isPassive() || - !firrtl::getBitWidth(destTy).has_value()) - return false; - - auto destFieldRef = getFieldRefFromValue(connect.getDest()); - auto destRoot = destFieldRef.getValue(); - - // If dest is derived from mem op or has a ground type, we cannot merge them. - // If the connect's destination is a root value, we cannot merge. - if (destRoot.getDefiningOp() || destRoot == connect.getDest()) - return false; - - Value parent; - unsigned index; - if (auto subfield = dyn_cast(connect.getDest().getDefiningOp())) - parent = subfield.getInput(), index = subfield.getFieldIndex(); - else if (auto subindex = - dyn_cast(connect.getDest().getDefiningOp())) - parent = subindex.getInput(), index = subindex.getIndex(); - else - llvm_unreachable("unexpected destination"); - - auto &countAndSubConnections = connections[getFieldRefFromValue(parent)]; - auto &count = countAndSubConnections.first; - auto &subConnections = countAndSubConnections.second; - - // If it is the first time to visit the parent op, then allocate the vector - // for subconnections. - if (count == 0) { - if (auto bundle = type_dyn_cast(parent.getType())) - subConnections.resize(bundle.getNumElements()); - if (auto vector = type_dyn_cast(parent.getType())) - subConnections.resize(vector.getNumElements()); - } - ++count; - subConnections[index] = connect; + // Return true if something is changed. + bool changed = false; + bool run() { + ImplicitLocOpBuilder theBuilder(moduleOp.getLoc(), moduleOp.getContext()); + builder = &theBuilder; + auto *body = moduleOp.getBodyBlock(); + // Merge connections by forward iterations. + for (auto it = body->begin(), e = body->end(); it != e;) { + auto connectOp = dyn_cast(*it); + if (!connectOp) { + it++; + continue; + } + builder->setInsertionPointAfter(connectOp); + builder->setLoc(connectOp.getLoc()); + bool removeOp = peelConnect(connectOp); + ++it; + if (removeOp) + connectOp.erase(); + } - // If we haven't visited all subconnections, stop at this point. - if (count != subConnections.size()) - return false; + // Clean up dead operations introduced by this pass. + for (auto &op : llvm::make_early_inc_range(llvm::reverse(*body))) + if (isa(op)) + if (op.use_empty()) { + changed = true; + op.erase(); + } + + return changed; + } - auto parentType = parent.getType(); - auto parentBaseTy = type_dyn_cast(parentType); + // Return true if the given connect op is merged. + bool peelConnect(MatchingConnectOp connect) { + // Ignore connections between different types because it will produce a + // partial connect. Also ignore non-passive connections or non-integer + // connections. + LLVM_DEBUG(llvm::dbgs() << "Visiting " << connect << "\n"); + auto destTy = type_dyn_cast(connect.getDest().getType()); + if (!destTy || !destTy.isPassive() || + !firrtl::getBitWidth(destTy).has_value()) + return false; + + auto destFieldRef = getFieldRefFromValue(connect.getDest()); + auto destRoot = destFieldRef.getValue(); + + // If dest is derived from mem op or has a ground type, we cannot merge + // them. If the connect's destination is a root value, we cannot merge. + if (destRoot.getDefiningOp() || destRoot == connect.getDest()) + return false; + + Value parent; + unsigned index; + if (auto subfield = dyn_cast(connect.getDest().getDefiningOp())) + parent = subfield.getInput(), index = subfield.getFieldIndex(); + else if (auto subindex = + dyn_cast(connect.getDest().getDefiningOp())) + parent = subindex.getInput(), index = subindex.getIndex(); + else + llvm_unreachable("unexpected destination"); + + auto &countAndSubConnections = connections[getFieldRefFromValue(parent)]; + auto &count = countAndSubConnections.first; + auto &subConnections = countAndSubConnections.second; + + // If it is the first time to visit the parent op, then allocate the vector + // for subconnections. + if (count == 0) { + if (auto bundle = type_dyn_cast(parent.getType())) + subConnections.resize(bundle.getNumElements()); + if (auto vector = type_dyn_cast(parent.getType())) + subConnections.resize(vector.getNumElements()); + } + ++count; + subConnections[index] = connect; + + // If we haven't visited all subconnections, stop at this point. + if (count != subConnections.size()) + return false; + + auto parentType = parent.getType(); + auto parentBaseTy = type_dyn_cast(parentType); + + // Reject if not passive, we don't support aggregate constants for these. + if (!parentBaseTy || !parentBaseTy.isPassive()) + return false; + + changed = true; + + auto getMergedValue = [&](auto aggregateType) { + SmallVector operands; + + // This flag tracks whether we can use the parent of source values as the + // merged value. + bool canUseSourceParent = true; + bool areOperandsAllConstants = true; + + // The value which might be used as a merged value. + Value sourceParent; + + auto checkSourceParent = [&](auto subelement, unsigned destIndex, + unsigned sourceIndex) { + // In the first iteration, register a parent value. + if (destIndex == 0) { + if (subelement.getInput().getType() == parentType) + sourceParent = subelement.getInput(); + else { + // If types are not same, it is not possible to use it. + canUseSourceParent = false; + } + } - // Reject if not passive, we don't support aggregate constants for these. - if (!parentBaseTy || !parentBaseTy.isPassive()) - return false; + // Check that input is the same as `sourceAggregate` and indexes match. + canUseSourceParent &= + subelement.getInput() == sourceParent && destIndex == sourceIndex; + }; - changed = true; + for (auto idx : llvm::seq(0u, (unsigned)aggregateType.getNumElements())) { + auto src = subConnections[idx].getSrc(); + assert(src && "all subconnections are guranteed to exist"); + operands.push_back(src); - auto getMergedValue = [&](auto aggregateType) { - SmallVector operands; + areOperandsAllConstants &= isConstantLike(src); - // This flag tracks whether we can use the parent of source values as the - // merged value. - bool canUseSourceParent = true; - bool areOperandsAllConstants = true; + // From here, check whether the value is derived from the same aggregate + // value. - // The value which might be used as a merged value. - Value sourceParent; + // If canUseSourceParent is already false, abort. + if (!canUseSourceParent) + continue; - auto checkSourceParent = [&](auto subelement, unsigned destIndex, - unsigned sourceIndex) { - // In the first iteration, register a parent value. - if (destIndex == 0) { - if (subelement.getInput().getType() == parentType) - sourceParent = subelement.getInput(); - else { - // If types are not same, it is not possible to use it. + // If the value is an argument, it is not derived from an aggregate + // value. + if (!src.getDefiningOp()) { canUseSourceParent = false; + continue; } - } - - // Check that input is the same as `sourceAggregate` and indexes match. - canUseSourceParent &= - subelement.getInput() == sourceParent && destIndex == sourceIndex; - }; - - for (auto idx : llvm::seq(0u, (unsigned)aggregateType.getNumElements())) { - auto src = subConnections[idx].getSrc(); - assert(src && "all subconnections are guranteed to exist"); - operands.push_back(src); - - areOperandsAllConstants &= isConstantLike(src); - // From here, check whether the value is derived from the same aggregate - // value. - - // If canUseSourceParent is already false, abort. - if (!canUseSourceParent) - continue; - - // If the value is an argument, it is not derived from an aggregate value. - if (!src.getDefiningOp()) { - canUseSourceParent = false; - continue; + TypeSwitch(src.getDefiningOp()) + .template Case([&](SubfieldOp subfield) { + checkSourceParent(subfield, idx, subfield.getFieldIndex()); + }) + .template Case([&](SubindexOp subindex) { + checkSourceParent(subindex, idx, subindex.getIndex()); + }) + .Default([&](auto) { canUseSourceParent = false; }); } - TypeSwitch(src.getDefiningOp()) - .template Case([&](SubfieldOp subfield) { - checkSourceParent(subfield, idx, subfield.getFieldIndex()); - }) - .template Case([&](SubindexOp subindex) { - checkSourceParent(subindex, idx, subindex.getIndex()); - }) - .Default([&](auto) { canUseSourceParent = false; }); - } + // If it is fine to use `sourceParent` as a merged value, we just + // return it. + if (canUseSourceParent) { + LLVM_DEBUG(llvm::dbgs() + << "Success to merge " << destFieldRef.getValue() + << " ,fieldID= " << destFieldRef.getFieldID() << " to " + << sourceParent << "\n";); + // Erase connections except for subConnections[index] since it must be + // erased at the top-level loop. + for (auto idx : llvm::seq(0u, static_cast(operands.size()))) + if (idx != index) + subConnections[idx].erase(); + return sourceParent; + } - // If it is fine to use `sourceParent` as a merged value, we just - // return it. - if (canUseSourceParent) { - LLVM_DEBUG(llvm::dbgs() << "Success to merge " << destFieldRef.getValue() - << " ,fieldID= " << destFieldRef.getFieldID() - << " to " << sourceParent << "\n";); - // Erase connections except for subConnections[index] since it must be - // erased at the top-level loop. - for (auto idx : llvm::seq(0u, static_cast(operands.size()))) + // If operands are not all constants, we don't merge connections unless + // "aggressive-merging" option is enabled. + if (!enableAggressiveMerging && !areOperandsAllConstants) + return Value(); + + SmallVector locs; + // Otherwise, we concat all values and cast them into the aggregate type. + for (auto idx : llvm::seq(0u, static_cast(operands.size()))) { + locs.push_back(subConnections[idx].getLoc()); + // Erase connections except for subConnections[index] since it must be + // erased at the top-level loop. if (idx != index) subConnections[idx].erase(); - return sourceParent; - } + } - // If operands are not all constants, we don't merge connections unless - // "aggressive-merging" option is enabled. - if (!enableAggressiveMerging && !areOperandsAllConstants) - return Value(); - - SmallVector locs; - // Otherwise, we concat all values and cast them into the aggregate type. - for (auto idx : llvm::seq(0u, static_cast(operands.size()))) { - locs.push_back(subConnections[idx].getLoc()); - // Erase connections except for subConnections[index] since it must be - // erased at the top-level loop. - if (idx != index) - subConnections[idx].erase(); - } + return isa(parentType) + ? builder->createOrFold( + builder->getFusedLoc(locs), parentType, operands) + : builder->createOrFold( + builder->getFusedLoc(locs), parentType, operands); + }; - return isa(parentType) - ? builder->createOrFold( - builder->getFusedLoc(locs), parentType, operands) - : builder->createOrFold( - builder->getFusedLoc(locs), parentType, operands); - }; - - Value merged; - if (auto bundle = type_dyn_cast(parentType)) - merged = getMergedValue(bundle); - if (auto vector = type_dyn_cast(parentType)) - merged = getMergedValue(vector); - if (!merged) - return false; - - // Emit strict connect if possible, fallback to normal connect. - // Don't use emitConnect(), will split the connect apart. - if (!parentBaseTy.hasUninferredWidth()) - builder->create(connect.getLoc(), parent, merged); - else - builder->create(connect.getLoc(), parent, merged); - - return true; -} + Value merged; + if (auto bundle = type_dyn_cast(parentType)) + merged = getMergedValue(bundle); + if (auto vector = type_dyn_cast(parentType)) + merged = getMergedValue(vector); + if (!merged) + return false; + + // Emit strict connect if possible, fallback to normal connect. + // Don't use emitConnect(), will split the connect apart. + if (!parentBaseTy.hasUninferredWidth()) + builder->create(connect.getLoc(), parent, merged); + else + builder->create(connect.getLoc(), parent, merged); -bool MergeConnection::run() { - ImplicitLocOpBuilder theBuilder(moduleOp.getLoc(), moduleOp.getContext()); - builder = &theBuilder; - auto *body = moduleOp.getBodyBlock(); - // Merge connections by forward iterations. - for (auto it = body->begin(), e = body->end(); it != e;) { - auto connectOp = dyn_cast(*it); - if (!connectOp) { - it++; - continue; - } - builder->setInsertionPointAfter(connectOp); - builder->setLoc(connectOp.getLoc()); - bool removeOp = peelConnect(connectOp); - ++it; - if (removeOp) - connectOp.erase(); + return true; } - - // Clean up dead operations introduced by this pass. - for (auto &op : llvm::make_early_inc_range(llvm::reverse(*body))) - if (isa(op)) - if (op.use_empty()) { - changed = true; - op.erase(); - } - - return changed; -} +}; struct MergeConnectionsPass : public circt::firrtl::impl::MergeConnectionsBase { MergeConnectionsPass(bool enableAggressiveMergingFlag) { enableAggressiveMerging = enableAggressiveMergingFlag; } + template + void runOnOp(Op op) { + LLVM_DEBUG(debugPassHeader(this) << "\n" + << "Module: '" << op.getName() << "'\n"); + + MergeConnection mergeConnection(op, enableAggressiveMerging); + bool changed = mergeConnection.run(); + + if (!changed) + return markAllAnalysesPreserved(); + } void runOnOperation() override; }; } // namespace void MergeConnectionsPass::runOnOperation() { - LLVM_DEBUG(debugPassHeader(this) - << "\n" - << "Module: '" << getOperation().getName() << "'\n"); - - MergeConnection mergeConnection(getOperation(), enableAggressiveMerging); - bool changed = mergeConnection.run(); - if (!changed) - return markAllAnalysesPreserved(); + TypeSwitch(&(*getOperation())) + .Case([&](auto op) { runOnOp(op); }) + // All other ops are ignored -- particularly ops that don't implement + // the `getBodyBlock()` method. We don't want an error here because the + // pass wasn't designed to run on those ops. + .Default([&](auto) {}); } std::unique_ptr diff --git a/lib/Dialect/FIRRTL/Transforms/PrintFIRRTLFieldSource.cpp b/lib/Dialect/FIRRTL/Transforms/PrintFIRRTLFieldSource.cpp index f5a48fbb23fe..c34c0a218467 100644 --- a/lib/Dialect/FIRRTL/Transforms/PrintFIRRTLFieldSource.cpp +++ b/lib/Dialect/FIRRTL/Transforms/PrintFIRRTLFieldSource.cpp @@ -62,17 +62,27 @@ struct PrintFIRRTLFieldSourcePass visitOp(fieldRefs, &op); } - void runOnOperation() override { - auto modOp = getOperation(); + template + void runOnOp(Op modOp) { + Block *body = modOp.getBodyBlock(); os << "** " << modOp.getName() << "\n"; auto &fieldRefs = getAnalysis(); - for (auto port : modOp.getBodyBlock()->getArguments()) + for (auto port : body->getArguments()) visitValue(fieldRefs, port); - for (auto &op : *modOp.getBodyBlock()) + for (auto &op : *body) visitOp(fieldRefs, &op); markAllAnalysesPreserved(); } + + void runOnOperation() override { + TypeSwitch(&(*getOperation())) + .Case([&](auto op) { runOnOp(op); }) + // All other ops are ignored -- particularly ops that don't implement + // the `getBodyBlock()` method. We don't want an error here because the + // pass wasn't designed to run on those ops. + .Default([&](auto) {}); + } raw_ostream &os; }; } // end anonymous namespace diff --git a/lib/Dialect/FIRRTL/Transforms/RandomizeRegisterInit.cpp b/lib/Dialect/FIRRTL/Transforms/RandomizeRegisterInit.cpp index edd2f285c55e..b3806d57fe1d 100644 --- a/lib/Dialect/FIRRTL/Transforms/RandomizeRegisterInit.cpp +++ b/lib/Dialect/FIRRTL/Transforms/RandomizeRegisterInit.cpp @@ -50,7 +50,8 @@ std::unique_ptr circt::firrtl::createRandomizeRegisterInitPass() { /// each register should consume. The goal is for registers to always read the /// same random bits for the same seed, regardless of optimizations that might /// remove registers. -static void createRandomizationAttributes(FModuleOp mod) { +template +static void createRandomizationAttributes(Op mod) { OpBuilder builder(mod); // Walk all registers. @@ -74,5 +75,12 @@ static void createRandomizationAttributes(FModuleOp mod) { } void RandomizeRegisterInitPass::runOnOperation() { - createRandomizationAttributes(getOperation()); + TypeSwitch(&(*getOperation())) + .Case( + [&](auto op) { createRandomizationAttributes(op); }) + // All other ops are ignored -- particularly ops that don't implement + // the `getBodyBlock()` method. We don't want an error here because the + // pass wasn't designed to run on those ops. + .Default([&](auto) {}); } diff --git a/lib/Dialect/FIRRTL/Transforms/RegisterOptimizer.cpp b/lib/Dialect/FIRRTL/Transforms/RegisterOptimizer.cpp index c6ac00f3e44b..84fc13f28feb 100644 --- a/lib/Dialect/FIRRTL/Transforms/RegisterOptimizer.cpp +++ b/lib/Dialect/FIRRTL/Transforms/RegisterOptimizer.cpp @@ -47,6 +47,26 @@ namespace { struct RegisterOptimizerPass : public circt::firrtl::impl::RegisterOptimizerBase { + + template + void runOnOp(Op mod) { + SmallVector toErase; + mlir::DominanceInfo dom(mod); + Block *body = mod.getBodyBlock(); + + for (auto &op : *body) { + if (auto reg = dyn_cast(&op)) + checkRegReset(dom, toErase, reg); + else if (auto reg = dyn_cast(&op)) + checkReg(dom, toErase, reg); + } + for (auto *op : toErase) + op->erase(); + + if (!toErase.empty()) + return markAllAnalysesPreserved(); + } + void runOnOperation() override; void checkRegReset(mlir::DominanceInfo &dom, SmallVector &toErase, RegResetOp reg); @@ -142,23 +162,14 @@ void RegisterOptimizerPass::checkRegReset(mlir::DominanceInfo &dom, } void RegisterOptimizerPass::runOnOperation() { - auto mod = getOperation(); LLVM_DEBUG(debugPassHeader(this) << "\n";); - SmallVector toErase; - mlir::DominanceInfo dom(mod); - - for (auto &op : *mod.getBodyBlock()) { - if (auto reg = dyn_cast(&op)) - checkRegReset(dom, toErase, reg); - else if (auto reg = dyn_cast(&op)) - checkReg(dom, toErase, reg); - } - for (auto *op : toErase) - op->erase(); - - if (!toErase.empty()) - return markAllAnalysesPreserved(); + TypeSwitch(&(*getOperation())) + .Case([&](auto op) { runOnOp(op); }) + // All other ops are ignored -- particularly ops that don't implement + // the `getBodyBlock()` method. We don't want an error here because the + // pass wasn't designed to run on those ops. + .Default([&](auto) {}); } std::unique_ptr circt::firrtl::createRegisterOptimizerPass() { diff --git a/lib/Firtool/Firtool.cpp b/lib/Firtool/Firtool.cpp index ef5f4ca0ae31..14dd28b49403 100644 --- a/lib/Firtool/Firtool.cpp +++ b/lib/Firtool/Firtool.cpp @@ -38,12 +38,12 @@ LogicalResult firtool::populatePreprocessTransforms(mlir::PassManager &pm, opt.shouldAllowAddingPortsOnPublic())); if (opt.shouldEnableDebugInfo()) - pm.nest().addNestedPass( + pm.nest().nestAny().addPass( firrtl::createMaterializeDebugInfoPass()); pm.nest().addPass( firrtl::createLowerIntmodulesPass(opt.shouldFixupEICGWrapper())); - pm.nest().nest().addPass( + pm.nest().nestAny().addPass( firrtl::createLowerIntrinsicsPass()); return success(); @@ -56,22 +56,21 @@ LogicalResult firtool::populateCHIRRTLToLowFIRRTL(mlir::PassManager &pm, pm.nest().addPass(firrtl::createInjectDUTHierarchyPass()); - pm.nest().nest().addPass( + pm.nest().nestAny().addPass( firrtl::createPassiveWiresPass()); - pm.nest().nest().addPass( + pm.nest().nestAny().addPass( firrtl::createDropNamesPass(opt.getPreserveMode())); if (!opt.shouldDisableOptimization()) - pm.nest().nest().addPass( - mlir::createCSEPass()); + pm.nest().nestAny().addPass(mlir::createCSEPass()); - pm.nest().nest().addPass( + pm.nest().nestAny().addPass( firrtl::createLowerCHIRRTLPass()); // Run LowerMatches before InferWidths, as the latter does not support the // match statement, but it does support what they lower to. - pm.nest().nest().addPass( + pm.nest().nestAny().addPass( firrtl::createLowerMatchesPass()); // Width inference creates canonicalization opportunities. @@ -107,7 +106,7 @@ LogicalResult firtool::populateCHIRRTLToLowFIRRTL(mlir::PassManager &pm, } if (!opt.shouldLowerMemories()) - pm.nest().nest().addPass( + pm.nest().nestAny().addPass( firrtl::createFlattenMemoryPass()); // The input mlir file could be firrtl dialect so we might need to clean @@ -117,7 +116,7 @@ LogicalResult firtool::populateCHIRRTLToLowFIRRTL(mlir::PassManager &pm, opt.getPreserveAggregate(), firrtl::PreserveAggregate::None)); { - auto &modulePM = pm.nest().nest(); + auto &modulePM = pm.nest().nestAny(); modulePM.addPass(firrtl::createExpandWhensPass()); modulePM.addPass(firrtl::createSFCCompatPass()); } @@ -133,7 +132,7 @@ LogicalResult firtool::populateCHIRRTLToLowFIRRTL(mlir::PassManager &pm, pm.nest().addPass(firrtl::createProbesToSignalsPass()); { - auto &modulePM = pm.nest().nest(); + auto &modulePM = pm.nest().nestAny(); modulePM.addPass(firrtl::createLayerMergePass()); modulePM.addPass(firrtl::createLayerSinkPass()); } @@ -147,18 +146,18 @@ LogicalResult firtool::populateCHIRRTLToLowFIRRTL(mlir::PassManager &pm, // currently in the final module it will be emitted in, all registers have // been created, and no registers have yet been removed. if (opt.isRandomEnabled(FirtoolOptions::RandomKind::Reg)) - pm.nest().nest().addPass( + pm.nest().nestAny().addPass( firrtl::createRandomizeRegisterInitPass()); // If we parsed a FIRRTL file and have optimizations enabled, clean it up. if (!opt.shouldDisableOptimization()) - pm.nest().nest().addPass( + pm.nest().nestAny().addPass( createSimpleCanonicalizerPass()); // Run the infer-rw pass, which merges read and write ports of a memory with // mutually exclusive enables. if (!opt.shouldDisableOptimization()) - pm.nest().nest().addPass( + pm.nest().nestAny().addPass( firrtl::createInferReadWritePass()); if (opt.shouldReplicateSequentialMemories()) @@ -167,7 +166,7 @@ LogicalResult firtool::populateCHIRRTLToLowFIRRTL(mlir::PassManager &pm, pm.nest().addPass(firrtl::createPrefixModulesPass()); if (opt.shouldAddCompanionAssume()) - pm.nest().nest().addPass( + pm.nest().nestAny().addPass( circt::firrtl::createCreateCompanionAssume()); if (!opt.shouldDisableOptimization()) @@ -206,9 +205,9 @@ LogicalResult firtool::populateCHIRRTLToLowFIRRTL(mlir::PassManager &pm, // canonicalization opportunities that we should pick up here before we // proceed to output-specific pipelines. if (!opt.shouldDisableOptimization()) { - pm.nest().nest().addPass( + pm.nest().nestAny().addPass( createSimpleCanonicalizerPass()); - pm.nest().nest().addPass( + pm.nest().nestAny().addPass( circt::firrtl::createRegisterOptimizerPass()); // Re-run IMConstProp to propagate constants produced by register // optimizations. @@ -221,12 +220,12 @@ LogicalResult firtool::populateCHIRRTLToLowFIRRTL(mlir::PassManager &pm, firrtl::createEmitOMIRPass(opt.getOmirOutputFile())); // Always run this, required for legalization. - pm.nest().nest().addPass( + pm.nest().nestAny().addPass( firrtl::createMergeConnectionsPass( !opt.shouldDisableAggressiveMergeConnections())); if (!opt.shouldDisableOptimization()) - pm.nest().nest().addPass( + pm.nest().nestAny().addPass( firrtl::createVectorizationPass()); auto outputFilename = opt.getOutputFilename(); @@ -255,7 +254,7 @@ LogicalResult firtool::populateLowFIRRTLToHW(mlir::PassManager &pm, pm.nest().addPass(om::createVerifyObjectFieldsPass()); // Check for static asserts. - pm.nest().nest().addPass( + pm.nest().nestAny().addPass( circt::firrtl::createLintingPass()); pm.addPass(createLowerFIRRTLToHWPass(opt.shouldEnableAnnotationWarning(),