From 821d12fef47caa280f1d0bae1d8d80cc7f40fe6f Mon Sep 17 00:00:00 2001 From: Fabian Schuiki Date: Sun, 13 Oct 2024 18:57:38 -0700 Subject: [PATCH] [Arc] Improve LowerState to never produce read-after-write conflicts This is a complete rewrite of the `LowerState` pass that makes the `LegalizeStateUpdate` pass obsolete. The old implementation of `LowerState` produces `arc.model`s that still contain read-after-write conflicts. This primarily happens because the pass simply emits `arc.state_write` operations that write updated values to simulation memory for each `arc.state`, and any user of `arc.state` would use an `arc.state_read` operation to retrieve the original value of the state before any writes occurred. Memories are similar. The Arc dialect considers `arc.state_write` and `arc.memory_write` operations to be _deferred_ writes until the `LegalizeStateUpdate` pass runs, at which point they become _immediate_ since the legalization inserts the necessary temporaries to resolve the read-after-write conflicts. The previous implementation would also not handle state-to-output and state-to-side-effecting-op propagation paths correctly. When a model's eval function is called, registers are updated to their new value, and any outputs that combinatorially depend on those new values should also immediately update. Similarly, operations such as asserts or debug trackers should observe new values for states immediately after they have been written. However, since all writes are considered deferred, there is no way for `LowerState` to produce a mixture of operations that depend on a register's _old_ state (because they are used to compute a register's new state), and on a _new_ state because they are combinatorially derived values. This new implementation of `LowerState` completely avoids read-after-write conflicts. It does this by changing the way modules are lowered in two ways: **Phases:** The pass tracks in which _phase_ of the simulation lifecycle a value is needed and allows for operations to have different lowerings in different phases. An `arc.state` operation for example requires its inputs, enable, and reset to be computed based on the _old_ value they had, i.e. the value the end of the previous call to the model's eval function. The clock however has to be computed based on the _new_ value it has in the current call to eval. Therefore, the ops defining the inputs, enable, and reset are lowered in the _old_ phase, while the ops defining the clock are lowered in the _new_ phase. The `arc.state` op lowering will then write its _new_ value to simulation storage. This phase tracking allows registers to be used as the clock for other registers: since the clocks require _new_ values, registers serving as clock to others are lowered first, such that the dependent registers can immediately react to the updated clock. It also allows for module outputs and side-effecting ops based on `arc.state`s to be scheduled after the states have been updated, since they depend on the state's _new_ value. The pass also covers the situation where an operation depends on a module input and a state, and feeds into a module output as well as another state. In this situation that operation has to be lowered twice: once for the _old_ phase to serve as input to the subsequent state, and once for the _new_ phase to compute the new module output. In addition to the _old_ and _new_ phases representing the previous and current call to eval, the pass also models an _initial_ and _final_ phase. These are used for `seq.initial` and `llhd.final` ops, and in order to compute the initial values for states. If an `arc.state` op has an initial value operand it is lowered in the _initial_ phase. Similarly for the ops in `llhd.final`. The pass places all ops lowered in the initial and final phases into corresponding `arc.initial` and `arc.final` ops. At a later point we may want to generate the `*_initial`, `*_eval`, and `*_final` functions directly. **No more clock trees:** The new implementation also no longer generates `arc.clock_tree` and `arc.passthrough` operations. These were a holdover from the early days of the Arc dialect, where no eval function would be generated. Instead, the user was required to directly call clock functions. This was never able to model clocks changing at the exact same moment, or having clocks derived from registers and other combinatorial operations. Since Arc has since switched to generating an eval function that can accurately interleave the effects of different clocks, grouping ops by clock tree is no longer needed. In fact, removing the clock tree ops allows for the pass to more efficiently interleave the operations from different clock domains. The Rocket core in the circt/arc-tests repository still works with this new implementation of LowerState. In combination with the MergeIfs pass the performance stays the same. I have renamed the implementation and test files to make the git diffs easier to read. The names will be changed back in a follow-up commit. --- include/circt/Dialect/Arc/ArcPasses.h | 2 - include/circt/Dialect/Arc/ArcPasses.td | 17 +- integration_test/arcilator/JIT/dpi.mlir | 9 +- .../arcilator/JIT/initial-shift-reg.mlir | 1 + integration_test/arcilator/JIT/reg.mlir | 3 +- .../ConvertToArcs/ConvertToArcs.cpp | 5 +- lib/Dialect/Arc/ArcTypes.cpp | 8 +- lib/Dialect/Arc/Transforms/CMakeLists.txt | 4 +- .../Arc/Transforms/LegalizeStateUpdate.cpp | 597 -------- lib/Dialect/Arc/Transforms/LowerState.cpp | 959 ------------- .../Arc/Transforms/LowerStateRewrite.cpp | 1206 +++++++++++++++++ .../Arc/legalize-state-update-error.mlir | 22 - test/Dialect/Arc/legalize-state-update.mlir | 253 ---- test/Dialect/Arc/lower-state-errors.mlir | 39 - test/Dialect/Arc/lower-state.mlir | 949 +++++++------ test/arcilator/arcilator.mlir | 42 +- tools/arcilator/arcilator.cpp | 9 +- 17 files changed, 1809 insertions(+), 2316 deletions(-) delete mode 100644 lib/Dialect/Arc/Transforms/LegalizeStateUpdate.cpp delete mode 100644 lib/Dialect/Arc/Transforms/LowerState.cpp create mode 100644 lib/Dialect/Arc/Transforms/LowerStateRewrite.cpp delete mode 100644 test/Dialect/Arc/legalize-state-update-error.mlir delete mode 100644 test/Dialect/Arc/legalize-state-update.mlir delete mode 100644 test/Dialect/Arc/lower-state-errors.mlir diff --git a/include/circt/Dialect/Arc/ArcPasses.h b/include/circt/Dialect/Arc/ArcPasses.h index b31398d5898a..281a8a0be6b2 100644 --- a/include/circt/Dialect/Arc/ArcPasses.h +++ b/include/circt/Dialect/Arc/ArcPasses.h @@ -36,11 +36,9 @@ createInferMemoriesPass(const InferMemoriesOptions &options = {}); std::unique_ptr createInlineArcsPass(); std::unique_ptr createIsolateClocksPass(); std::unique_ptr createLatencyRetimingPass(); -std::unique_ptr createLegalizeStateUpdatePass(); std::unique_ptr createLowerArcsToFuncsPass(); std::unique_ptr createLowerClocksToFuncsPass(); std::unique_ptr createLowerLUTPass(); -std::unique_ptr createLowerStatePass(); std::unique_ptr createLowerVectorizationsPass( LowerVectorizationsModeEnum mode = LowerVectorizationsModeEnum::Full); std::unique_ptr createMakeTablesPass(); diff --git a/include/circt/Dialect/Arc/ArcPasses.td b/include/circt/Dialect/Arc/ArcPasses.td index 8b3005b32e39..7eefeab1bac2 100644 --- a/include/circt/Dialect/Arc/ArcPasses.td +++ b/include/circt/Dialect/Arc/ArcPasses.td @@ -163,12 +163,6 @@ def LatencyRetiming : Pass<"arc-latency-retiming", "mlir::ModuleOp"> { ]; } -def LegalizeStateUpdate : Pass<"arc-legalize-state-update", "mlir::ModuleOp"> { - let summary = "Insert temporaries such that state reads don't see writes"; - let constructor = "circt::arc::createLegalizeStateUpdatePass()"; - let dependentDialects = ["arc::ArcDialect"]; -} - def LowerArcsToFuncs : Pass<"arc-lower-arcs-to-funcs", "mlir::ModuleOp"> { let summary = "Lower arc definitions into functions"; let constructor = "circt::arc::createLowerArcsToFuncsPass()"; @@ -187,12 +181,15 @@ def LowerLUT : Pass<"arc-lower-lut", "arc::DefineOp"> { let dependentDialects = ["hw::HWDialect", "comb::CombDialect"]; } -def LowerState : Pass<"arc-lower-state", "mlir::ModuleOp"> { +def LowerStatePass : Pass<"arc-lower-state", "mlir::ModuleOp"> { let summary = "Split state into read and write ops grouped by clock tree"; - let constructor = "circt::arc::createLowerStatePass()"; let dependentDialects = [ - "arc::ArcDialect", "mlir::scf::SCFDialect", "mlir::func::FuncDialect", - "mlir::LLVM::LLVMDialect", "comb::CombDialect", "seq::SeqDialect" + "arc::ArcDialect", + "comb::CombDialect", + "mlir::LLVM::LLVMDialect", + "mlir::func::FuncDialect", + "mlir::scf::SCFDialect", + "seq::SeqDialect", ]; } diff --git a/integration_test/arcilator/JIT/dpi.mlir b/integration_test/arcilator/JIT/dpi.mlir index bdc3b32d80dc..93daf8294542 100644 --- a/integration_test/arcilator/JIT/dpi.mlir +++ b/integration_test/arcilator/JIT/dpi.mlir @@ -19,13 +19,14 @@ func.func @add_mlir_impl(%arg0: i32, %arg1: i32, %arg2: !llvm.ptr) { llvm.store %0, %arg2 : i32, !llvm.ptr return } + hw.module @arith(in %clock : i1, in %a : i32, in %b : i32, out c : i32, out d : i32) { %seq_clk = seq.to_clock %clock - %0 = sim.func.dpi.call @add_mlir(%a, %b) clock %seq_clk : (i32, i32) -> i32 %1 = sim.func.dpi.call @mul_shared(%a, %b) clock %seq_clk : (i32, i32) -> i32 hw.output %0, %1 : i32, i32 } + func.func @main() { %c2_i32 = arith.constant 2 : i32 %c3_i32 = arith.constant 3 : i32 @@ -34,18 +35,16 @@ func.func @main() { arc.sim.instantiate @arith as %arg0 { arc.sim.set_input %arg0, "a" = %c2_i32 : i32, !arc.sim.instance<@arith> arc.sim.set_input %arg0, "b" = %c3_i32 : i32, !arc.sim.instance<@arith> - arc.sim.set_input %arg0, "clock" = %one : i1, !arc.sim.instance<@arith> - arc.sim.step %arg0 : !arc.sim.instance<@arith> arc.sim.set_input %arg0, "clock" = %zero : i1, !arc.sim.instance<@arith> + arc.sim.step %arg0 : !arc.sim.instance<@arith> %0 = arc.sim.get_port %arg0, "c" : i32, !arc.sim.instance<@arith> %1 = arc.sim.get_port %arg0, "d" : i32, !arc.sim.instance<@arith> - arc.sim.emit "c", %0 : i32 arc.sim.emit "d", %1 : i32 - arc.sim.step %arg0 : !arc.sim.instance<@arith> arc.sim.set_input %arg0, "clock" = %one : i1, !arc.sim.instance<@arith> + arc.sim.step %arg0 : !arc.sim.instance<@arith> %2 = arc.sim.get_port %arg0, "c" : i32, !arc.sim.instance<@arith> %3 = arc.sim.get_port %arg0, "d" : i32, !arc.sim.instance<@arith> arc.sim.emit "c", %2 : i32 diff --git a/integration_test/arcilator/JIT/initial-shift-reg.mlir b/integration_test/arcilator/JIT/initial-shift-reg.mlir index 3724962d8a7f..7f3793fed052 100644 --- a/integration_test/arcilator/JIT/initial-shift-reg.mlir +++ b/integration_test/arcilator/JIT/initial-shift-reg.mlir @@ -26,6 +26,7 @@ module { %true = arith.constant 1 : i1 arc.sim.instantiate @shiftreg as %model { + arc.sim.step %model : !arc.sim.instance<@shiftreg> arc.sim.set_input %model, "en" = %false : i1, !arc.sim.instance<@shiftreg> arc.sim.set_input %model, "reset" = %false : i1, !arc.sim.instance<@shiftreg> arc.sim.set_input %model, "din" = %ff : i8, !arc.sim.instance<@shiftreg> diff --git a/integration_test/arcilator/JIT/reg.mlir b/integration_test/arcilator/JIT/reg.mlir index ea610276ae0f..e0845ebb2829 100644 --- a/integration_test/arcilator/JIT/reg.mlir +++ b/integration_test/arcilator/JIT/reg.mlir @@ -1,7 +1,7 @@ // RUN: arcilator %s --run --jit-entry=main | FileCheck %s // REQUIRES: arcilator-jit -// CHECK: o1 = 2 +// CHECK: o1 = 2 // CHECK-NEXT: o2 = 5 // CHECK-NEXT: o1 = 3 // CHECK-NEXT: o2 = 6 @@ -41,6 +41,7 @@ func.func @main() { %step = arith.constant 1 : index arc.sim.instantiate @counter as %model { + arc.sim.step %model : !arc.sim.instance<@counter> %init_val1 = arc.sim.get_port %model, "o1" : i8, !arc.sim.instance<@counter> %init_val2 = arc.sim.get_port %model, "o2" : i8, !arc.sim.instance<@counter> diff --git a/lib/Conversion/ConvertToArcs/ConvertToArcs.cpp b/lib/Conversion/ConvertToArcs/ConvertToArcs.cpp index 4eeffe054388..bf0ae7728634 100644 --- a/lib/Conversion/ConvertToArcs/ConvertToArcs.cpp +++ b/lib/Conversion/ConvertToArcs/ConvertToArcs.cpp @@ -25,8 +25,9 @@ using llvm::MapVector; static bool isArcBreakingOp(Operation *op) { return op->hasTrait() || - isa(op) || + isa(op) || op->getNumResults() > 1; } diff --git a/lib/Dialect/Arc/ArcTypes.cpp b/lib/Dialect/Arc/ArcTypes.cpp index 5cd342284f2d..55d4564920e3 100644 --- a/lib/Dialect/Arc/ArcTypes.cpp +++ b/lib/Dialect/Arc/ArcTypes.cpp @@ -21,11 +21,17 @@ using namespace mlir; #define GET_TYPEDEF_CLASSES #include "circt/Dialect/Arc/ArcTypes.cpp.inc" -unsigned StateType::getBitWidth() { return hw::getBitWidth(getType()); } +unsigned StateType::getBitWidth() { + if (llvm::isa(getType())) + return 1; + return hw::getBitWidth(getType()); +} LogicalResult StateType::verify(llvm::function_ref emitError, Type innerType) { + if (llvm::isa(innerType)) + return success(); if (hw::getBitWidth(innerType) < 0) return emitError() << "state type must have a known bit width; got " << innerType; diff --git a/lib/Dialect/Arc/Transforms/CMakeLists.txt b/lib/Dialect/Arc/Transforms/CMakeLists.txt index b9362e2f1ff9..5b89b0c8b2bb 100644 --- a/lib/Dialect/Arc/Transforms/CMakeLists.txt +++ b/lib/Dialect/Arc/Transforms/CMakeLists.txt @@ -9,11 +9,10 @@ add_circt_dialect_library(CIRCTArcTransforms InlineArcs.cpp IsolateClocks.cpp LatencyRetiming.cpp - LegalizeStateUpdate.cpp LowerArcsToFuncs.cpp LowerClocksToFuncs.cpp LowerLUT.cpp - LowerState.cpp + LowerStateRewrite.cpp LowerVectorizations.cpp MakeTables.cpp MergeIfs.cpp @@ -33,6 +32,7 @@ add_circt_dialect_library(CIRCTArcTransforms CIRCTComb CIRCTEmit CIRCTHW + CIRCTLLHD CIRCTOM CIRCTSV CIRCTSeq diff --git a/lib/Dialect/Arc/Transforms/LegalizeStateUpdate.cpp b/lib/Dialect/Arc/Transforms/LegalizeStateUpdate.cpp deleted file mode 100644 index b63bd5424149..000000000000 --- a/lib/Dialect/Arc/Transforms/LegalizeStateUpdate.cpp +++ /dev/null @@ -1,597 +0,0 @@ -//===- LegalizeStateUpdate.cpp --------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "circt/Dialect/Arc/ArcOps.h" -#include "circt/Dialect/Arc/ArcPasses.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/Dominance.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" -#include "llvm/ADT/PointerIntPair.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "arc-legalize-state-update" - -namespace circt { -namespace arc { -#define GEN_PASS_DEF_LEGALIZESTATEUPDATE -#include "circt/Dialect/Arc/ArcPasses.h.inc" -} // namespace arc -} // namespace circt - -using namespace mlir; -using namespace circt; -using namespace arc; - -/// Check if an operation partakes in state accesses. -static bool isOpInteresting(Operation *op) { - if (isa(op)) - return false; - if (isa(op)) - return true; - if (op->getNumRegions() > 0) - return true; - return false; -} - -//===----------------------------------------------------------------------===// -// Access Analysis -//===----------------------------------------------------------------------===// - -namespace { - -enum class AccessType { Read = 0, Write = 1 }; - -/// A read or write access to a state value. -using Access = llvm::PointerIntPair; - -struct BlockAccesses; -struct OpAccesses; - -/// A block's access analysis information and graph edges. -struct BlockAccesses { - BlockAccesses(Block *block) : block(block) {} - - /// The block. - Block *const block; - /// The parent op lattice node. - OpAccesses *parent = nullptr; - /// The accesses from ops within this block to the block arguments. - SmallPtrSet argAccesses; - /// The accesses from ops within this block to values defined outside the - /// block. - SmallPtrSet aboveAccesses; -}; - -/// An operation's access analysis information and graph edges. -struct OpAccesses { - OpAccesses(Operation *op) : op(op) {} - - /// The operation. - Operation *const op; - /// The parent block lattice node. - BlockAccesses *parent = nullptr; - /// If this is a callable op, `callers` is the set of ops calling it. - SmallPtrSet callers; - /// The accesses performed by this op. - SmallPtrSet accesses; -}; - -/// An analysis that determines states read and written by operations and -/// blocks. Looks through calls and handles nested operations properly. Does not -/// follow state values returned from functions and modified by operations. -struct AccessAnalysis { - LogicalResult analyze(Operation *op); - OpAccesses *lookup(Operation *op); - BlockAccesses *lookup(Block *block); - - /// A global order assigned to state values. These allow us to not care about - /// ordering during the access analysis and only establish a determinstic - /// order once we insert additional operations later on. - DenseMap stateOrder; - - /// A symbol table cache. - SymbolTableCollection symbolTable; - -private: - llvm::SpecificBumpPtrAllocator opAlloc; - llvm::SpecificBumpPtrAllocator blockAlloc; - - DenseMap opAccesses; - DenseMap blockAccesses; - - SetVector opWorklist; - bool anyInvalidStateAccesses = false; - - // Get the node for an operation, creating one if necessary. - OpAccesses &get(Operation *op) { - auto &slot = opAccesses[op]; - if (!slot) - slot = new (opAlloc.Allocate()) OpAccesses(op); - return *slot; - } - - // Get the node for a block, creating one if necessary. - BlockAccesses &get(Block *block) { - auto &slot = blockAccesses[block]; - if (!slot) - slot = new (blockAlloc.Allocate()) BlockAccesses(block); - return *slot; - } - - // NOLINTBEGIN(misc-no-recursion) - void addOpAccess(OpAccesses &op, Access access); - void addBlockAccess(BlockAccesses &block, Access access); - // NOLINTEND(misc-no-recursion) -}; -} // namespace - -LogicalResult AccessAnalysis::analyze(Operation *op) { - LLVM_DEBUG(llvm::dbgs() << "Analyzing accesses in " << op->getName() << "\n"); - - // Create the lattice nodes for all blocks and operations. - llvm::SmallSetVector initWorklist; - initWorklist.insert(&get(op)); - while (!initWorklist.empty()) { - OpAccesses &opNode = *initWorklist.pop_back_val(); - - // First create lattice nodes for all nested blocks and operations. - for (auto ®ion : opNode.op->getRegions()) { - for (auto &block : region) { - BlockAccesses &blockNode = get(&block); - blockNode.parent = &opNode; - for (auto &subOp : block) { - if (!isOpInteresting(&subOp)) - continue; - OpAccesses &subOpNode = get(&subOp); - if (!subOp.hasTrait()) { - subOpNode.parent = &blockNode; - } - initWorklist.insert(&subOpNode); - } - } - } - - // Track the relationship between callers and callees. - if (auto callOp = dyn_cast(opNode.op)) - if (auto *calleeOp = callOp.resolveCallableInTable(&symbolTable)) - get(calleeOp).callers.insert(&opNode); - - // Create the seed accesses. - if (auto readOp = dyn_cast(opNode.op)) - addOpAccess(opNode, Access(readOp.getState(), AccessType::Read)); - else if (auto writeOp = dyn_cast(opNode.op)) - addOpAccess(opNode, Access(writeOp.getState(), AccessType::Write)); - } - LLVM_DEBUG(llvm::dbgs() << "- Prepared " << blockAccesses.size() - << " block and " << opAccesses.size() - << " op lattice nodes\n"); - LLVM_DEBUG(llvm::dbgs() << "- Worklist has " << opWorklist.size() - << " initial ops\n"); - - // Propagate accesses through calls. - while (!opWorklist.empty()) { - if (anyInvalidStateAccesses) - return failure(); - auto &opNode = *opWorklist.pop_back_val(); - if (opNode.callers.empty()) - continue; - auto calleeOp = dyn_cast(opNode.op); - if (!calleeOp) - return opNode.op->emitOpError( - "does not implement CallableOpInterface but has callers"); - LLVM_DEBUG(llvm::dbgs() << "- Updating callable " << opNode.op->getName() - << " " << opNode.op->getAttr("sym_name") << "\n"); - - auto &calleeRegion = *calleeOp.getCallableRegion(); - auto *blockNode = lookup(&calleeRegion.front()); - if (!blockNode) - continue; - auto calleeArgs = blockNode->block->getArguments(); - - for (auto *callOpNode : opNode.callers) { - LLVM_DEBUG(llvm::dbgs() << " - Updating " << *callOpNode->op << "\n"); - auto callArgs = cast(callOpNode->op).getArgOperands(); - for (auto [calleeArg, callArg] : llvm::zip(calleeArgs, callArgs)) { - if (blockNode->argAccesses.contains({calleeArg, AccessType::Read})) - addOpAccess(*callOpNode, {callArg, AccessType::Read}); - if (blockNode->argAccesses.contains({calleeArg, AccessType::Write})) - addOpAccess(*callOpNode, {callArg, AccessType::Write}); - } - } - } - - return failure(anyInvalidStateAccesses); -} - -OpAccesses *AccessAnalysis::lookup(Operation *op) { - return opAccesses.lookup(op); -} - -BlockAccesses *AccessAnalysis::lookup(Block *block) { - return blockAccesses.lookup(block); -} - -// NOLINTBEGIN(misc-no-recursion) -void AccessAnalysis::addOpAccess(OpAccesses &op, Access access) { - // We don't support state pointers flowing among ops and blocks. Check that - // the accessed state is either directly passed down through a block argument - // (no defining op), or is trivially a local state allocation. - auto *defOp = access.getPointer().getDefiningOp(); - if (defOp && !isa(defOp)) { - auto d = op.op->emitOpError("accesses non-trivial state value defined by `") - << defOp->getName() - << "`; only block arguments and `arc.alloc_state` results are " - "supported"; - d.attachNote(defOp->getLoc()) << "state defined here"; - anyInvalidStateAccesses = true; - } - - // HACK: Do not propagate accesses outside of `arc.passthrough` to prevent - // reads from being legalized. Ideally we'd be able to more precisely specify - // on read ops whether they should read the initial or the final value. - if (isa(op.op)) - return; - - // Propagate to the parent block and operation if the access escapes the block - // or targets a block argument. - if (op.accesses.insert(access).second && op.parent) { - stateOrder.insert({access.getPointer(), stateOrder.size()}); - addBlockAccess(*op.parent, access); - } -} - -void AccessAnalysis::addBlockAccess(BlockAccesses &block, Access access) { - Value value = access.getPointer(); - - // If the accessed value is defined outside the block, add it to the set of - // outside accesses. - if (value.getParentBlock() != block.block) { - if (block.aboveAccesses.insert(access).second) - addOpAccess(*block.parent, access); - return; - } - - // If the accessed value is defined within the block, and it is a block - // argument, add it to the list of block argument accesses. - if (auto blockArg = dyn_cast(value)) { - assert(blockArg.getOwner() == block.block); - if (!block.argAccesses.insert(access).second) - return; - - // Adding block argument accesses affects calls to the surrounding ops. Add - // the op to the worklist such that the access can propagate to callers. - opWorklist.insert(block.parent); - } -} -// NOLINTEND(misc-no-recursion) - -//===----------------------------------------------------------------------===// -// Legalization -//===----------------------------------------------------------------------===// - -namespace { -struct Legalizer { - Legalizer(AccessAnalysis &analysis) : analysis(analysis) {} - LogicalResult run(MutableArrayRef regions); - LogicalResult visitBlock(Block *block); - - AccessAnalysis &analysis; - - unsigned numLegalizedWrites = 0; - unsigned numUpdatedReads = 0; - - /// A mapping from pre-existing states to temporary states for read - /// operations, created during legalization to remove read-after-write - /// hazards. - DenseMap legalizedStates; -}; -} // namespace - -LogicalResult Legalizer::run(MutableArrayRef regions) { - for (auto ®ion : regions) - for (auto &block : region) - if (failed(visitBlock(&block))) - return failure(); - assert(legalizedStates.empty() && "should be balanced within block"); - return success(); -} - -LogicalResult Legalizer::visitBlock(Block *block) { - // In a first reverse pass over the block, find the first write that occurs - // before the last read of a state, if any. - SmallPtrSet readStates; - DenseMap illegallyWrittenStates; - for (Operation &op : llvm::reverse(*block)) { - const auto *accesses = analysis.lookup(&op); - if (!accesses) - continue; - - // Determine the states written by this op for which we have already seen a - // read earlier. These writes need to be legalized. - SmallVector affectedStates; - for (auto access : accesses->accesses) - if (access.getInt() == AccessType::Write) - if (readStates.contains(access.getPointer())) - illegallyWrittenStates[access.getPointer()] = &op; - - // Determine the states read by this op. This comes after handling of the - // writes, such that a block that contains both reads and writes to a state - // doesn't mark itself as illegal. Instead, we will descend into that block - // further down and do a more fine-grained legalization. - for (auto access : accesses->accesses) - if (access.getInt() == AccessType::Read) - readStates.insert(access.getPointer()); - } - - // Create a mapping from operations that create a read-after-write hazard to - // the states that they modify. Don't consider states that have already been - // legalized. This is important since we may have already created a temporary - // in a parent block which we can just reuse. - DenseMap> illegalWrites; - for (auto [state, op] : illegallyWrittenStates) - if (!legalizedStates.count(state)) - illegalWrites[op].push_back(state); - - // In a second forward pass over the block, insert the necessary temporary - // state to legalize the writes and recur into subblocks while providing the - // necessary rewrites. - SmallVector locallyLegalizedStates; - - auto handleIllegalWrites = - [&](Operation *op, SmallVector &states) -> LogicalResult { - LLVM_DEBUG(llvm::dbgs() << "Visiting illegal " << op->getName() << "\n"); - - // Sort the states we need to legalize by a determinstic order established - // during the access analysis. Without this the exact order in which states - // were moved into a temporary would be non-deterministic. - llvm::sort(states, [&](Value a, Value b) { - return analysis.stateOrder.lookup(a) < analysis.stateOrder.lookup(b); - }); - - // Legalize each state individually. - for (auto state : states) { - LLVM_DEBUG(llvm::dbgs() << "- Legalizing " << state << "\n"); - - // HACK: This is ugly, but we need a storage reference to allocate a state - // into. Ideally we'd materialize this later on, but the current impl of - // the alloc op requires a storage immediately. So try to find one. - auto storage = TypeSwitch(state.getDefiningOp()) - .Case( - [&](auto allocOp) { return allocOp.getStorage(); }) - .Default([](auto) { return Value{}; }); - if (!storage) { - mlir::emitError( - state.getLoc(), - "cannot find storage pointer to allocate temporary into"); - return failure(); - } - - // Allocate a temporary state, read the current value of the state we are - // legalizing, and write it to the temporary. - ++numLegalizedWrites; - ImplicitLocOpBuilder builder(state.getLoc(), op); - auto tmpState = - builder.create(state.getType(), storage, nullptr); - auto stateValue = builder.create(state); - builder.create(tmpState, stateValue, Value{}); - locallyLegalizedStates.push_back(state); - legalizedStates.insert({state, tmpState}); - } - return success(); - }; - - for (Operation &op : *block) { - if (isOpInteresting(&op)) { - if (auto it = illegalWrites.find(&op); it != illegalWrites.end()) - if (failed(handleIllegalWrites(&op, it->second))) - return failure(); - } - // BUG: This is insufficient. Actually only reads should have their state - // updated, since we want writes to still affect the original state. This - // works for `state_read`, but in the case of a function that both reads and - // writes a state we only have a single operand to change but we would need - // one for reads and one for writes instead. - // HACKY FIX: Assume that there is ever only a single write to a state. In - // that case it is safe to assume that when an op is marked as writing a - // state it wants the original state, not the temporary one for reads. - const auto *accesses = analysis.lookup(&op); - for (auto &operand : op.getOpOperands()) { - if (accesses && - accesses->accesses.contains({operand.get(), AccessType::Read}) && - accesses->accesses.contains({operand.get(), AccessType::Write})) { - auto d = op.emitWarning("operation reads and writes state; " - "legalization may be insufficient"); - d.attachNote() - << "state update legalization does not properly handle operations " - "that both read and write states at the same time; runtime data " - "races between the read and write behavior are possible"; - d.attachNote(operand.get().getLoc()) << "state defined here:"; - } - if (!accesses || - !accesses->accesses.contains({operand.get(), AccessType::Write})) { - if (auto tmpState = legalizedStates.lookup(operand.get())) { - operand.set(tmpState); - ++numUpdatedReads; - } - } - } - for (auto ®ion : op.getRegions()) - for (auto &block : region) - if (failed(visitBlock(&block))) - return failure(); - } - - // Since we're leaving this block's scope, remove all the locally-legalized - // states which are no longer accessible outside. - for (auto state : locallyLegalizedStates) - legalizedStates.erase(state); - return success(); -} - -static LogicalResult getAncestorOpsInCommonDominatorBlock( - Operation *write, Operation **writeAncestor, Operation *read, - Operation **readAncestor, DominanceInfo *domInfo) { - Block *commonDominator = - domInfo->findNearestCommonDominator(write->getBlock(), read->getBlock()); - if (!commonDominator) - return write->emitOpError( - "cannot find a common dominator block with all read operations"); - - // Path from writeOp to commmon dominator must only contain IfOps with no - // return values - Operation *writeParent = write; - while (writeParent->getBlock() != commonDominator) { - if (!isa(writeParent->getParentOp())) - return write->emitOpError("memory write operations in arbitrarily nested " - "regions not supported"); - writeParent = writeParent->getParentOp(); - } - Operation *readParent = read; - while (readParent->getBlock() != commonDominator) - readParent = readParent->getParentOp(); - - *writeAncestor = writeParent; - *readAncestor = readParent; - return success(); -} - -static LogicalResult -moveMemoryWritesAfterLastRead(Region ®ion, const DenseSet &memories, - DominanceInfo *domInfo) { - // Collect memory values and their reads - DenseMap> readOps; - auto result = region.walk([&](Operation *op) { - if (isa(op)) - return WalkResult::advance(); - SmallVector memoriesReadFrom; - if (auto readOp = dyn_cast(op)) { - memoriesReadFrom.push_back(readOp.getMemory()); - } else { - for (auto operand : op->getOperands()) - if (isa(operand.getType())) - memoriesReadFrom.push_back(operand); - } - for (auto memVal : memoriesReadFrom) { - if (!memories.contains(memVal)) - return op->emitOpError("uses memory value not directly defined by a " - "arc.alloc_memory operation"), - WalkResult::interrupt(); - readOps[memVal].insert(op); - } - - return WalkResult::advance(); - }); - - if (result.wasInterrupted()) - return failure(); - - // Collect all writes - SmallVector writes; - region.walk([&](MemoryWriteOp writeOp) { writes.push_back(writeOp); }); - - // Move the writes - for (auto writeOp : writes) { - if (!memories.contains(writeOp.getMemory())) - return writeOp->emitOpError("uses memory value not directly defined by a " - "arc.alloc_memory operation"); - for (auto *readOp : readOps[writeOp.getMemory()]) { - // (1) If the last read and the write are in the same block, just move the - // write after the read. - // (2) If the write is directly in the clock tree region and the last read - // in some nested region, move the write after the operation with the - // nested region. (3) If the write is nested in if-statements (arbitrarily - // deep) without return value, move the whole if operation after the last - // read or the operation that defines the region if the read is inside a - // nested region. (4) Number (3) may move more memory operations with the - // write op, thus messing up the order of previously moved memory writes, - // we check in a second walk-through if that is the case and just emit an - // error for now. We could instead move reads in a parent region, split if - // operations such that the memory write has its own, etc. Alternatively, - // rewrite this to insert temporaries which is more difficult for memories - // than simple states because the memory addresses have to be considered - // (we cannot just copy the whole memory each time). - Operation *readAncestor, *writeAncestor; - if (failed(getAncestorOpsInCommonDominatorBlock( - writeOp, &writeAncestor, readOp, &readAncestor, domInfo))) - return failure(); - // FIXME: the 'isBeforeInBlock` + 'moveAfter' compination can be - // computationally very expensive. - if (writeAncestor->isBeforeInBlock(readAncestor)) - writeAncestor->moveAfter(readAncestor); - } - } - - // Double check that all writes happen after all reads to the same memory. - for (auto writeOp : writes) { - for (auto *readOp : readOps[writeOp.getMemory()]) { - Operation *readAncestor, *writeAncestor; - if (failed(getAncestorOpsInCommonDominatorBlock( - writeOp, &writeAncestor, readOp, &readAncestor, domInfo))) - return failure(); - - if (writeAncestor->isBeforeInBlock(readAncestor)) - return writeOp - ->emitOpError("could not be moved to be after all reads to " - "the same memory") - .attachNote(readOp->getLoc()) - << "could not be moved after this read"; - } - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// Pass Infrastructure -//===----------------------------------------------------------------------===// - -namespace { -struct LegalizeStateUpdatePass - : public arc::impl::LegalizeStateUpdateBase { - LegalizeStateUpdatePass() = default; - LegalizeStateUpdatePass(const LegalizeStateUpdatePass &pass) - : LegalizeStateUpdatePass() {} - - void runOnOperation() override; - - Statistic numLegalizedWrites{ - this, "legalized-writes", - "Writes that required temporary state for later reads"}; - Statistic numUpdatedReads{this, "updated-reads", "Reads that were updated"}; -}; -} // namespace - -void LegalizeStateUpdatePass::runOnOperation() { - auto module = getOperation(); - auto *domInfo = &getAnalysis(); - - for (auto model : module.getOps()) { - DenseSet memories; - for (auto memOp : model.getOps()) - memories.insert(memOp.getResult()); - for (auto ct : model.getOps()) - if (failed( - moveMemoryWritesAfterLastRead(ct.getBody(), memories, domInfo))) - return signalPassFailure(); - } - - AccessAnalysis analysis; - if (failed(analysis.analyze(module))) - return signalPassFailure(); - - Legalizer legalizer(analysis); - if (failed(legalizer.run(module->getRegions()))) - return signalPassFailure(); - numLegalizedWrites += legalizer.numLegalizedWrites; - numUpdatedReads += legalizer.numUpdatedReads; -} - -std::unique_ptr arc::createLegalizeStateUpdatePass() { - return std::make_unique(); -} diff --git a/lib/Dialect/Arc/Transforms/LowerState.cpp b/lib/Dialect/Arc/Transforms/LowerState.cpp deleted file mode 100644 index 8784c55ab2c8..000000000000 --- a/lib/Dialect/Arc/Transforms/LowerState.cpp +++ /dev/null @@ -1,959 +0,0 @@ -//===- LowerState.cpp ---------------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "circt/Dialect/Arc/ArcOps.h" -#include "circt/Dialect/Arc/ArcPasses.h" -#include "circt/Dialect/Comb/CombDialect.h" -#include "circt/Dialect/Comb/CombOps.h" -#include "circt/Dialect/HW/HWOps.h" -#include "circt/Dialect/Seq/SeqOps.h" -#include "circt/Dialect/Sim/SimOps.h" -#include "circt/Support/BackedgeBuilder.h" -#include "mlir/Analysis/TopologicalSortUtils.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/Pass/Pass.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "arc-lower-state" - -namespace circt { -namespace arc { -#define GEN_PASS_DEF_LOWERSTATE -#include "circt/Dialect/Arc/ArcPasses.h.inc" -} // namespace arc -} // namespace circt - -using namespace circt; -using namespace arc; -using namespace hw; -using namespace mlir; -using llvm::SmallDenseSet; - -//===----------------------------------------------------------------------===// -// Data Structures -//===----------------------------------------------------------------------===// - -namespace { - -/// Statistics gathered throughout the execution of this pass. -struct Statistics { - Pass *parent; - Statistics(Pass *parent) : parent(parent) {} - using Statistic = Pass::Statistic; - - Statistic matOpsMoved{parent, "mat-ops-moved", - "Ops moved during value materialization"}; - Statistic matOpsCloned{parent, "mat-ops-cloned", - "Ops cloned during value materialization"}; - Statistic opsPruned{parent, "ops-pruned", "Ops removed as dead code"}; -}; - -/// Lowering info associated with a single primary clock. -struct ClockLowering { - /// The root clock this lowering is for. - Value clock; - /// A `ClockTreeOp` or `PassThroughOp` or `InitialOp`. - Operation *treeOp; - /// Pass statistics. - Statistics &stats; - OpBuilder builder; - /// A mapping from values outside the clock tree to their materialize form - /// inside the clock tree. - IRMapping materializedValues; - /// A cache of AND gates created for aggregating enable conditions. - DenseMap, Value> andCache; - /// A cache of OR gates created for aggregating enable conditions. - DenseMap, Value> orCache; - - // Prevent accidental construction and copying - ClockLowering() = delete; - ClockLowering(const ClockLowering &other) = delete; - - ClockLowering(Value clock, Operation *treeOp, Statistics &stats) - : clock(clock), treeOp(treeOp), stats(stats), builder(treeOp) { - assert((isa(treeOp))); - builder.setInsertionPointToStart(&treeOp->getRegion(0).front()); - } - - Value materializeValue(Value value); - Value getOrCreateAnd(Value lhs, Value rhs, Location loc); - Value getOrCreateOr(Value lhs, Value rhs, Location loc); - - bool isInitialTree() const { return isa(treeOp); } -}; - -struct GatedClockLowering { - /// Lowering info of the primary clock. - ClockLowering &clock; - /// An optional enable condition of the primary clock. May be null. - Value enable; -}; - -/// State lowering for a single `HWModuleOp`. -struct ModuleLowering { - HWModuleOp moduleOp; - /// Pass statistics. - Statistics &stats; - MLIRContext *context; - DenseMap> clockLowerings; - DenseMap gatedClockLowerings; - std::unique_ptr initialLowering; - Value storageArg; - OpBuilder clockBuilder; - OpBuilder stateBuilder; - - ModuleLowering(HWModuleOp moduleOp, Statistics &stats) - : moduleOp(moduleOp), stats(stats), context(moduleOp.getContext()), - clockBuilder(moduleOp), stateBuilder(moduleOp) {} - - GatedClockLowering getOrCreateClockLowering(Value clock); - ClockLowering &getOrCreatePassThrough(); - ClockLowering &getInitial(); - Value replaceValueWithStateRead(Value value, Value state); - - void addStorageArg(); - LogicalResult lowerPrimaryInputs(); - LogicalResult lowerPrimaryOutputs(); - LogicalResult lowerStates(); - LogicalResult lowerInitials(); - template - LogicalResult lowerStateLike(Operation *op, Value clock, Value enable, - Value reset, ArrayRef inputs, - FlatSymbolRefAttr callee, - ArrayRef initialValues = {}); - LogicalResult lowerState(StateOp stateOp); - LogicalResult lowerState(sim::DPICallOp dpiCallOp); - LogicalResult lowerState(MemoryOp memOp); - LogicalResult lowerState(MemoryWritePortOp memWriteOp); - LogicalResult lowerState(TapOp tapOp); - LogicalResult lowerExtModules(SymbolTable &symtbl); - LogicalResult lowerExtModule(InstanceOp instOp); - - LogicalResult cleanup(); -}; -} // namespace - -//===----------------------------------------------------------------------===// -// Clock Lowering -//===----------------------------------------------------------------------===// - -static bool shouldMaterialize(Operation *op) { - // Don't materialize arc uses with latency >0, since we handle these in a - // second pass once all other operations have been moved to their respective - // clock trees. - return !isa(op); -} - -static bool shouldMaterialize(Value value) { - assert(value); - - // Block arguments are just used as they are. - auto *op = value.getDefiningOp(); - if (!op) - return false; - - return shouldMaterialize(op); -} - -static bool canBeMaterializedInInitializer(Operation *op) { - if (!op) - return false; - if (op->hasTrait()) - return true; - if (isa(op->getDialect())) - return true; - if (isa(op)) - return true; - // TODO: There are some other ops we probably want to allow - return false; -} - -/// Materialize a value within this clock tree. This clones or moves all -/// operations required to produce this value inside the clock tree. -Value ClockLowering::materializeValue(Value value) { - if (!value) - return {}; - if (auto mapped = materializedValues.lookupOrNull(value)) - return mapped; - if (auto fromImmutable = value.getDefiningOp()) - // Immutable value is pre-materialized so directly lookup the input. - return materializedValues.lookup(fromImmutable.getInput()); - - if (!shouldMaterialize(value)) - return value; - - struct WorkItem { - Operation *op; - SmallVector operands; - WorkItem(Operation *op) : op(op) {} - }; - - SmallPtrSet seen; - SmallVector worklist; - - auto addToWorklist = [&](Operation *outerOp) { - SmallDenseSet seenOperands; - auto &workItem = worklist.emplace_back(outerOp); - outerOp->walk([&](Operation *innerOp) { - for (auto operand : innerOp->getOperands()) { - // Skip operands that are defined within the operation itself. - if (!operand.getParentBlock()->getParentOp()->isProperAncestor(outerOp)) - continue; - - // Skip operands that we have already seen. - if (!seenOperands.insert(operand).second) - continue; - - // Skip operands that we have already materialized or that should not - // be materialized at all. - if (materializedValues.contains(operand) || !shouldMaterialize(operand)) - continue; - - workItem.operands.push_back(operand); - } - }); - }; - - seen.insert(value.getDefiningOp()); - addToWorklist(value.getDefiningOp()); - - while (!worklist.empty()) { - auto &workItem = worklist.back(); - if (isInitialTree() && !canBeMaterializedInInitializer(workItem.op)) { - workItem.op->emitError("Value cannot be used in initializer."); - return {}; - } - if (!workItem.operands.empty()) { - auto operand = workItem.operands.pop_back_val(); - if (materializedValues.contains(operand) || !shouldMaterialize(operand)) - continue; - auto *defOp = operand.getDefiningOp(); - if (!seen.insert(defOp).second) { - defOp->emitError("combinational loop detected"); - return {}; - } - addToWorklist(defOp); - } else { - builder.clone(*workItem.op, materializedValues); - seen.erase(workItem.op); - worklist.pop_back(); - } - } - - return materializedValues.lookup(value); -} - -/// Create an AND gate if none with the given operands already exists. Note that -/// the operands may be null, in which case the function will return the -/// non-null operand, or null if both operands are null. -Value ClockLowering::getOrCreateAnd(Value lhs, Value rhs, Location loc) { - if (!lhs) - return rhs; - if (!rhs) - return lhs; - auto &slot = andCache[std::make_pair(lhs, rhs)]; - if (!slot) - slot = builder.create(loc, lhs, rhs); - return slot; -} - -/// Create an OR gate if none with the given operands already exists. Note that -/// the operands may be null, in which case the function will return the -/// non-null operand, or null if both operands are null. -Value ClockLowering::getOrCreateOr(Value lhs, Value rhs, Location loc) { - if (!lhs) - return rhs; - if (!rhs) - return lhs; - auto &slot = orCache[std::make_pair(lhs, rhs)]; - if (!slot) - slot = builder.create(loc, lhs, rhs); - return slot; -} - -//===----------------------------------------------------------------------===// -// Module Lowering -//===----------------------------------------------------------------------===// - -GatedClockLowering ModuleLowering::getOrCreateClockLowering(Value clock) { - // Look through clock gates. - if (auto ckgOp = clock.getDefiningOp()) { - // Reuse the existing lowering for this clock gate if possible. - if (auto it = gatedClockLowerings.find(clock); - it != gatedClockLowerings.end()) - return it->second; - - // Get the lowering for the parent clock gate's input clock. This will give - // us the clock tree to emit things into, alongside the compound enable - // condition of all the clock gates along the way to the primary clock. All - // we have to do is to add this clock gate's condition to that list. - auto info = getOrCreateClockLowering(ckgOp.getInput()); - auto ckgEnable = info.clock.materializeValue(ckgOp.getEnable()); - auto ckgTestEnable = info.clock.materializeValue(ckgOp.getTestEnable()); - info.enable = info.clock.getOrCreateAnd( - info.enable, - info.clock.getOrCreateOr(ckgEnable, ckgTestEnable, ckgOp.getLoc()), - ckgOp.getLoc()); - gatedClockLowerings.insert({clock, info}); - return info; - } - - // Create the `ClockTreeOp` that corresponds to this ungated clock. - auto &slot = clockLowerings[clock]; - if (!slot) { - auto newClock = - clockBuilder.createOrFold(clock.getLoc(), clock); - - // Detect a rising edge on the clock, as `(old != new) & new`. - auto oldClockStorage = stateBuilder.create( - clock.getLoc(), StateType::get(stateBuilder.getI1Type()), storageArg); - auto oldClock = - clockBuilder.create(clock.getLoc(), oldClockStorage); - clockBuilder.create(clock.getLoc(), oldClockStorage, newClock, - Value{}); - Value trigger = clockBuilder.create( - clock.getLoc(), comb::ICmpPredicate::ne, oldClock, newClock); - trigger = - clockBuilder.create(clock.getLoc(), trigger, newClock); - - // Create the tree op. - auto treeOp = clockBuilder.create(clock.getLoc(), trigger); - treeOp.getBody().emplaceBlock(); - slot = std::make_unique(clock, treeOp, stats); - } - return GatedClockLowering{*slot, Value{}}; -} - -ClockLowering &ModuleLowering::getOrCreatePassThrough() { - auto &slot = clockLowerings[Value{}]; - if (!slot) { - auto treeOp = clockBuilder.create(moduleOp.getLoc()); - treeOp.getBody().emplaceBlock(); - slot = std::make_unique(Value{}, treeOp, stats); - } - return *slot; -} - -ClockLowering &ModuleLowering::getInitial() { - assert(!!initialLowering && "Initial tree op should have been constructed"); - return *initialLowering; -} - -/// Replace all uses of a value with a `StateReadOp` on a state. -Value ModuleLowering::replaceValueWithStateRead(Value value, Value state) { - OpBuilder builder(state.getContext()); - builder.setInsertionPointAfterValue(state); - Value readOp = builder.create(value.getLoc(), state); - if (isa(value.getType())) - readOp = builder.createOrFold(value.getLoc(), readOp); - value.replaceAllUsesWith(readOp); - return readOp; -} - -/// Add the global state as an argument to the module's body block. -void ModuleLowering::addStorageArg() { - assert(!storageArg); - storageArg = moduleOp.getBodyBlock()->addArgument( - StorageType::get(context, {}), moduleOp.getLoc()); -} - -/// Lower the primary inputs of the module to dedicated ops that allocate the -/// inputs in the model's storage. -LogicalResult ModuleLowering::lowerPrimaryInputs() { - for (auto blockArg : moduleOp.getBodyBlock()->getArguments()) { - if (blockArg == storageArg) - continue; - auto name = moduleOp.getArgName(blockArg.getArgNumber()); - auto argTy = blockArg.getType(); - IntegerType innerTy; - if (isa(argTy)) { - innerTy = IntegerType::get(context, 1); - } else if (auto intType = dyn_cast(argTy)) { - innerTy = intType; - } else { - return mlir::emitError(blockArg.getLoc(), "input ") - << name << " is of non-integer type " << blockArg.getType(); - } - auto state = stateBuilder.create( - blockArg.getLoc(), StateType::get(innerTy), name, storageArg); - replaceValueWithStateRead(blockArg, state); - } - return success(); -} - -/// Lower the primary outputs of the module to dedicated ops that allocate the -/// outputs in the model's storage. -LogicalResult ModuleLowering::lowerPrimaryOutputs() { - auto outputOp = cast(moduleOp.getBodyBlock()->getTerminator()); - if (outputOp.getNumOperands() > 0) { - auto outputOperands = SmallVector(outputOp.getOperands()); - outputOp->dropAllReferences(); - auto &passThrough = getOrCreatePassThrough(); - for (auto [outputArg, name] : - llvm::zip(outputOperands, moduleOp.getOutputNames())) { - IntegerType innerTy; - if (isa(outputArg.getType())) { - innerTy = IntegerType::get(context, 1); - } else if (auto intType = dyn_cast(outputArg.getType())) { - innerTy = intType; - } else { - return mlir::emitError(outputOp.getLoc(), "output ") - << name << " is of non-integer type " << outputArg.getType(); - } - auto value = passThrough.materializeValue(outputArg); - auto state = stateBuilder.create( - outputOp.getLoc(), StateType::get(innerTy), cast(name), - storageArg); - if (isa(value.getType())) - value = passThrough.builder.createOrFold( - outputOp.getLoc(), value); - passThrough.builder.create(outputOp.getLoc(), state, value, - Value{}); - } - } - outputOp.erase(); - return success(); -} - -LogicalResult ModuleLowering::lowerInitials() { - // Merge all seq.initial ops into a single seq.initial op. - auto result = circt::seq::mergeInitialOps(moduleOp.getBodyBlock()); - if (failed(result)) - return moduleOp.emitError() << "initial ops cannot be topologically sorted"; - - auto initialOp = *result; - if (!initialOp) // There is no seq.initial op. - return success(); - - // Move the operations of the merged initial op into the builder's block. - auto terminator = - cast(initialOp.getBodyBlock()->getTerminator()); - getInitial().builder.getBlock()->getOperations().splice( - getInitial().builder.getBlock()->begin(), - initialOp.getBodyBlock()->getOperations()); - - // Map seq.initial results to their corresponding operands. - for (auto [result, operand] : - llvm::zip(initialOp.getResults(), terminator.getOperands())) - getInitial().materializedValues.map(result, operand); - terminator.erase(); - - return success(); -} - -LogicalResult ModuleLowering::lowerStates() { - SmallVector opsToLower; - for (auto &op : *moduleOp.getBodyBlock()) - if (isa(&op)) - opsToLower.push_back(&op); - - for (auto *op : opsToLower) { - LLVM_DEBUG(llvm::dbgs() << "- Lowering " << *op << "\n"); - auto result = - TypeSwitch(op) - .Case( - [&](auto op) { return lowerState(op); }) - .Default(success()); - if (failed(result)) - return failure(); - } - return success(); -} - -template -LogicalResult ModuleLowering::lowerStateLike( - Operation *stateOp, Value stateClock, Value stateEnable, Value stateReset, - ArrayRef stateInputs, FlatSymbolRefAttr callee, - ArrayRef initialValues) { - // Grab all operands from the state op at the callsite and make it drop all - // its references. This allows `materializeValue` to move an operation if this - // state was the last user. - - // Get the clock tree and enable condition for this state's clock. If this arc - // carries an explicit enable condition, fold that into the enable provided by - // the clock gates in the arc's clock tree. - auto info = getOrCreateClockLowering(stateClock); - info.enable = info.clock.getOrCreateAnd( - info.enable, info.clock.materializeValue(stateEnable), stateOp->getLoc()); - - // Allocate the necessary state within the model. - SmallVector allocatedStates; - for (unsigned stateIdx = 0; stateIdx < stateOp->getNumResults(); ++stateIdx) { - auto type = stateOp->getResult(stateIdx).getType(); - auto intType = dyn_cast(type); - if (!intType) - return stateOp->emitOpError("result ") - << stateIdx << " has non-integer type " << type - << "; only integer types are supported"; - auto stateType = StateType::get(intType); - auto state = stateBuilder.create(stateOp->getLoc(), stateType, - storageArg); - if (auto names = stateOp->getAttrOfType("names")) - state->setAttr("name", names[stateIdx]); - allocatedStates.push_back(state); - } - - // Create a copy of the arc use with latency zero. This will effectively be - // the computation of the arc's transfer function, while the latency is - // implemented through read and write functions. - SmallVector materializedOperands; - materializedOperands.reserve(stateInputs.size()); - - for (auto input : stateInputs) - materializedOperands.push_back(info.clock.materializeValue(input)); - - OpBuilder nonResetBuilder = info.clock.builder; - if (stateReset) { - auto materializedReset = info.clock.materializeValue(stateReset); - auto ifOp = info.clock.builder.create(stateOp->getLoc(), - materializedReset, true); - - for (auto [alloc, resTy] : - llvm::zip(allocatedStates, stateOp->getResultTypes())) { - if (!isa(resTy)) - stateOp->emitOpError("Non-integer result not supported yet!"); - - auto thenBuilder = ifOp.getThenBodyBuilder(); - Value constZero = - thenBuilder.create(stateOp->getLoc(), resTy, 0); - thenBuilder.create(stateOp->getLoc(), alloc, constZero, - Value()); - } - nonResetBuilder = ifOp.getElseBodyBuilder(); - } - - if (!initialValues.empty()) { - assert(initialValues.size() == allocatedStates.size() && - "Unexpected number of initializers"); - auto &initialTree = getInitial(); - for (auto [alloc, init] : llvm::zip(allocatedStates, initialValues)) { - // TODO: Can we get away without materialization? - auto matierializedInit = initialTree.materializeValue(init); - if (!matierializedInit) - return failure(); - initialTree.builder.create(stateOp->getLoc(), alloc, - matierializedInit, Value()); - } - } - - stateOp->dropAllReferences(); - - auto newStateOp = nonResetBuilder.create( - stateOp->getLoc(), stateOp->getResultTypes(), callee, - materializedOperands); - - // Create the write ops that write the result of the transfer function to the - // allocated state storage. - for (auto [alloc, result] : - llvm::zip(allocatedStates, newStateOp.getResults())) - nonResetBuilder.create(stateOp->getLoc(), alloc, result, - info.enable); - - // Replace all uses of the arc with reads from the allocated state. - for (auto [alloc, result] : llvm::zip(allocatedStates, stateOp->getResults())) - replaceValueWithStateRead(result, alloc); - stateOp->erase(); - return success(); -} - -LogicalResult ModuleLowering::lowerState(StateOp stateOp) { - // We don't support arcs beyond latency 1 yet. These should be easy to add in - // the future though. - if (stateOp.getLatency() > 1) - return stateOp.emitError("state with latency > 1 not supported"); - - auto stateInputs = SmallVector(stateOp.getInputs()); - auto stateInitializers = SmallVector(stateOp.getInitials()); - - return lowerStateLike( - stateOp, stateOp.getClock(), stateOp.getEnable(), stateOp.getReset(), - stateInputs, stateOp.getArcAttr(), stateInitializers); -} - -LogicalResult ModuleLowering::lowerState(sim::DPICallOp callOp) { - // Clocked call op can be considered as arc state with single latency. - auto stateClock = callOp.getClock(); - if (!stateClock) - return callOp.emitError("unclocked DPI call not implemented yet"); - - auto stateInputs = SmallVector(callOp.getInputs()); - - return lowerStateLike(callOp, stateClock, callOp.getEnable(), - Value(), stateInputs, - callOp.getCalleeAttr()); -} - -LogicalResult ModuleLowering::lowerState(MemoryOp memOp) { - auto allocMemOp = stateBuilder.create( - memOp.getLoc(), memOp.getType(), storageArg, memOp->getAttrs()); - memOp.replaceAllUsesWith(allocMemOp.getResult()); - memOp.erase(); - return success(); -} - -LogicalResult ModuleLowering::lowerState(MemoryWritePortOp memWriteOp) { - if (memWriteOp.getLatency() > 1) - return memWriteOp->emitOpError("latencies > 1 not supported yet"); - - // Get the clock tree and enable condition for this write port's clock. If the - // port carries an explicit enable condition, fold that into the enable - // provided by the clock gates in the port's clock tree. - auto info = getOrCreateClockLowering(memWriteOp.getClock()); - - // Grab all operands from the op and make it drop all its references. This - // allows `materializeValue` to move an operation if this op was the last - // user. - auto writeMemory = memWriteOp.getMemory(); - auto writeInputs = SmallVector(memWriteOp.getInputs()); - auto arcResultTypes = memWriteOp.getArcResultTypes(); - memWriteOp->dropAllReferences(); - - SmallVector materializedInputs; - for (auto input : writeInputs) - materializedInputs.push_back(info.clock.materializeValue(input)); - ValueRange results = - info.clock.builder - .create(memWriteOp.getLoc(), arcResultTypes, - memWriteOp.getArc(), materializedInputs) - ->getResults(); - - auto enable = - memWriteOp.getEnable() ? results[memWriteOp.getEnableIdx()] : Value(); - info.enable = - info.clock.getOrCreateAnd(info.enable, enable, memWriteOp.getLoc()); - - // Materialize the operands for the write op within the surrounding clock - // tree. - auto address = results[memWriteOp.getAddressIdx()]; - auto data = results[memWriteOp.getDataIdx()]; - if (memWriteOp.getMask()) { - Value mask = results[memWriteOp.getMaskIdx(static_cast(enable))]; - Value oldData = info.clock.builder.create( - mask.getLoc(), data.getType(), writeMemory, address); - Value allOnes = info.clock.builder.create( - mask.getLoc(), oldData.getType(), -1); - Value negatedMask = info.clock.builder.create( - mask.getLoc(), mask, allOnes, true); - Value maskedOldData = info.clock.builder.create( - mask.getLoc(), negatedMask, oldData, true); - Value maskedNewData = - info.clock.builder.create(mask.getLoc(), mask, data, true); - data = info.clock.builder.create(mask.getLoc(), maskedOldData, - maskedNewData, true); - } - info.clock.builder.create(memWriteOp.getLoc(), writeMemory, - address, info.enable, data); - memWriteOp.erase(); - return success(); -} - -// Add state for taps into the passthrough block. -LogicalResult ModuleLowering::lowerState(TapOp tapOp) { - auto intType = dyn_cast(tapOp.getValue().getType()); - if (!intType) - return mlir::emitError(tapOp.getLoc(), "tapped value ") - << tapOp.getNameAttr() << " is of non-integer type " - << tapOp.getValue().getType(); - - // Grab what we need from the tap op and then make it drop all its references. - // This will allow `materializeValue` to move ops instead of cloning them. - auto tapValue = tapOp.getValue(); - tapOp->dropAllReferences(); - - auto &passThrough = getOrCreatePassThrough(); - auto materializedValue = passThrough.materializeValue(tapValue); - auto state = stateBuilder.create( - tapOp.getLoc(), StateType::get(intType), storageArg, true); - state->setAttr("name", tapOp.getNameAttr()); - passThrough.builder.create(tapOp.getLoc(), state, - materializedValue, Value{}); - tapOp.erase(); - return success(); -} - -/// Lower all instances of external modules to internal inputs/outputs to be -/// driven from outside of the design. -LogicalResult ModuleLowering::lowerExtModules(SymbolTable &symtbl) { - auto instOps = SmallVector(moduleOp.getOps()); - for (auto op : instOps) - if (isa(symtbl.lookup(op.getModuleNameAttr().getAttr()))) - if (failed(lowerExtModule(op))) - return failure(); - return success(); -} - -LogicalResult ModuleLowering::lowerExtModule(InstanceOp instOp) { - LLVM_DEBUG(llvm::dbgs() << "- Lowering extmodule " - << instOp.getInstanceNameAttr() << "\n"); - - SmallString<32> baseName(instOp.getInstanceName()); - auto baseNameLen = baseName.size(); - - // Lower the inputs of the extmodule as state that is only written. - for (auto [operand, name] : - llvm::zip(instOp.getOperands(), instOp.getArgNames())) { - LLVM_DEBUG(llvm::dbgs() - << " - Input " << name << " : " << operand.getType() << "\n"); - auto intType = dyn_cast(operand.getType()); - if (!intType) - return mlir::emitError(operand.getLoc(), "input ") - << name << " of extern module " << instOp.getModuleNameAttr() - << " instance " << instOp.getInstanceNameAttr() - << " is of non-integer type " << operand.getType(); - baseName.resize(baseNameLen); - baseName += '/'; - baseName += cast(name).getValue(); - auto &passThrough = getOrCreatePassThrough(); - auto state = stateBuilder.create( - instOp.getLoc(), StateType::get(intType), storageArg); - state->setAttr("name", stateBuilder.getStringAttr(baseName)); - passThrough.builder.create( - instOp.getLoc(), state, passThrough.materializeValue(operand), Value{}); - } - - // Lower the outputs of the extmodule as state that is only read. - for (auto [result, name] : - llvm::zip(instOp.getResults(), instOp.getResultNames())) { - LLVM_DEBUG(llvm::dbgs() - << " - Output " << name << " : " << result.getType() << "\n"); - auto intType = dyn_cast(result.getType()); - if (!intType) - return mlir::emitError(result.getLoc(), "output ") - << name << " of extern module " << instOp.getModuleNameAttr() - << " instance " << instOp.getInstanceNameAttr() - << " is of non-integer type " << result.getType(); - baseName.resize(baseNameLen); - baseName += '/'; - baseName += cast(name).getValue(); - auto state = stateBuilder.create( - result.getLoc(), StateType::get(intType), storageArg); - state->setAttr("name", stateBuilder.getStringAttr(baseName)); - replaceValueWithStateRead(result, state); - } - - instOp.erase(); - return success(); -} - -LogicalResult ModuleLowering::cleanup() { - // Clean up dead ops in the model. - SetVector erasureWorklist; - auto isDead = [](Operation *op) { - if (isOpTriviallyDead(op)) - return true; - if (!op->use_empty()) - return false; - return false; - }; - for (auto &op : *moduleOp.getBodyBlock()) - if (isDead(&op)) - erasureWorklist.insert(&op); - while (!erasureWorklist.empty()) { - auto *op = erasureWorklist.pop_back_val(); - if (!isDead(op)) - continue; - op->walk([&](Operation *innerOp) { - for (auto operand : innerOp->getOperands()) - if (auto *defOp = operand.getDefiningOp()) - if (!op->isProperAncestor(defOp)) - erasureWorklist.insert(defOp); - }); - op->erase(); - } - - // Establish an order among all operations (to avoid an O(n²) pathological - // pattern with `moveBefore`) and replicate read operations into the blocks - // where they have uses. The established order is used to create the read - // operation as late in the block as possible, just before the first use. - DenseMap opOrder; - SmallVector readsToSink; - moduleOp.walk([&](Operation *op) { - opOrder.insert({op, opOrder.size()}); - if (auto readOp = dyn_cast(op)) - readsToSink.push_back(readOp); - }); - for (auto readToSink : readsToSink) { - SmallDenseMap> readsByBlock; - for (auto &use : llvm::make_early_inc_range(readToSink->getUses())) { - auto *user = use.getOwner(); - auto userOrder = opOrder.lookup(user); - auto &localRead = readsByBlock[user->getBlock()]; - if (!localRead.first) { - if (user->getBlock() == readToSink->getBlock()) { - localRead.first = readToSink; - readToSink->moveBefore(user); - } else { - localRead.first = OpBuilder(user).cloneWithoutRegions(readToSink); - } - localRead.second = userOrder; - } else if (userOrder < localRead.second) { - localRead.first->moveBefore(user); - localRead.second = userOrder; - } - use.set(localRead.first); - } - if (readToSink.use_empty()) - readToSink.erase(); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// Pass Infrastructure -//===----------------------------------------------------------------------===// - -namespace { -struct LowerStatePass : public arc::impl::LowerStateBase { - LowerStatePass() = default; - LowerStatePass(const LowerStatePass &pass) : LowerStatePass() {} - - void runOnOperation() override; - LogicalResult runOnModule(HWModuleOp moduleOp, SymbolTable &symtbl); - - Statistics stats{this}; -}; -} // namespace - -void LowerStatePass::runOnOperation() { - auto &symtbl = getAnalysis(); - SmallVector extModules; - for (auto &op : llvm::make_early_inc_range(getOperation().getOps())) { - if (auto moduleOp = dyn_cast(&op)) { - if (failed(runOnModule(moduleOp, symtbl))) - return signalPassFailure(); - } else if (auto extModuleOp = dyn_cast(&op)) { - extModules.push_back(extModuleOp); - } - } - for (auto op : extModules) - op.erase(); - - // Lower remaining MemoryReadPort ops to MemoryRead ops. This can occur when - // the fan-in of a MemoryReadPortOp contains another such operation and is - // materialized before the one in the fan-in as the MemoryReadPortOp is not - // marked as a fan-in blocking/termination operation in `shouldMaterialize`. - // Adding it there can lead to dominance issues which would then have to be - // resolved instead. - SetVector arcsToLower; - OpBuilder builder(getOperation()); - getOperation()->walk([&](MemoryReadPortOp memReadOp) { - if (auto defOp = memReadOp->getParentOfType()) - arcsToLower.insert(defOp); - - builder.setInsertionPoint(memReadOp); - Value newRead = builder.create( - memReadOp.getLoc(), memReadOp.getMemory(), memReadOp.getAddress()); - memReadOp.replaceAllUsesWith(newRead); - memReadOp.erase(); - }); - - SymbolTableCollection symbolTable; - mlir::SymbolUserMap userMap(symbolTable, getOperation()); - for (auto defOp : arcsToLower) { - auto *terminator = defOp.getBodyBlock().getTerminator(); - builder.setInsertionPoint(terminator); - builder.create(terminator->getLoc(), - terminator->getOperands()); - terminator->erase(); - builder.setInsertionPoint(defOp); - auto funcOp = builder.create(defOp.getLoc(), defOp.getName(), - defOp.getFunctionType()); - funcOp->setAttr("llvm.linkage", - LLVM::LinkageAttr::get(builder.getContext(), - LLVM::linkage::Linkage::Internal)); - funcOp.getBody().takeBody(defOp.getBody()); - - for (auto *user : userMap.getUsers(defOp)) { - builder.setInsertionPoint(user); - ValueRange results = builder - .create( - user->getLoc(), funcOp, - cast(user).getArgOperands()) - ->getResults(); - user->replaceAllUsesWith(results); - user->erase(); - } - - defOp.erase(); - } -} - -LogicalResult LowerStatePass::runOnModule(HWModuleOp moduleOp, - SymbolTable &symtbl) { - LLVM_DEBUG(llvm::dbgs() << "Lowering state in `" << moduleOp.getModuleName() - << "`\n"); - ModuleLowering lowering(moduleOp, stats); - - // Add sentinel ops to separate state allocations from clock trees. - lowering.stateBuilder.setInsertionPointToStart(moduleOp.getBodyBlock()); - - Operation *stateSentinel = - lowering.stateBuilder.create(moduleOp.getLoc()); - Operation *clockSentinel = - lowering.stateBuilder.create(moduleOp.getLoc()); - - // Create the 'initial' pseudo clock tree. - auto initialTreeOp = - lowering.stateBuilder.create(moduleOp.getLoc()); - initialTreeOp.getBody().emplaceBlock(); - lowering.initialLowering = - std::make_unique(Value{}, initialTreeOp, stats); - - lowering.stateBuilder.setInsertionPoint(stateSentinel); - lowering.clockBuilder.setInsertionPoint(clockSentinel); - - lowering.addStorageArg(); - if (failed(lowering.lowerInitials())) - return failure(); - if (failed(lowering.lowerPrimaryInputs())) - return failure(); - if (failed(lowering.lowerPrimaryOutputs())) - return failure(); - if (failed(lowering.lowerStates())) - return failure(); - if (failed(lowering.lowerExtModules(symtbl))) - return failure(); - - // Clean up the module body which contains a lot of operations that the - // pessimistic value materialization has left behind because it couldn't - // reliably determine that the ops were no longer needed. - if (failed(lowering.cleanup())) - return failure(); - - // Erase the sentinel ops. - stateSentinel->erase(); - clockSentinel->erase(); - - // Replace the `HWModuleOp` with a `ModelOp`. - moduleOp.getBodyBlock()->eraseArguments( - [&](auto arg) { return arg != lowering.storageArg; }); - ImplicitLocOpBuilder builder(moduleOp.getLoc(), moduleOp); - auto modelOp = - builder.create(moduleOp.getLoc(), moduleOp.getModuleNameAttr(), - TypeAttr::get(moduleOp.getModuleType()), - FlatSymbolRefAttr{}, FlatSymbolRefAttr{}); - modelOp.getBody().takeBody(moduleOp.getBody()); - moduleOp->erase(); - sortTopologically(&modelOp.getBodyBlock()); - - return success(); -} - -std::unique_ptr arc::createLowerStatePass() { - return std::make_unique(); -} diff --git a/lib/Dialect/Arc/Transforms/LowerStateRewrite.cpp b/lib/Dialect/Arc/Transforms/LowerStateRewrite.cpp new file mode 100644 index 000000000000..e1d4a78ad2af --- /dev/null +++ b/lib/Dialect/Arc/Transforms/LowerStateRewrite.cpp @@ -0,0 +1,1206 @@ +//===- LowerState.cpp -----------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/Arc/ArcOps.h" +#include "circt/Dialect/Arc/ArcPasses.h" +#include "circt/Dialect/Comb/CombDialect.h" +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/HW/HWOps.h" +#include "circt/Dialect/LLHD/IR/LLHDOps.h" +#include "circt/Dialect/Seq/SeqOps.h" +#include "circt/Dialect/Sim/SimOps.h" +#include "circt/Support/BackedgeBuilder.h" +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "arc-lower-state" + +namespace circt { +namespace arc { +#define GEN_PASS_DEF_LOWERSTATEPASS +#include "circt/Dialect/Arc/ArcPasses.h.inc" +} // namespace arc +} // namespace circt + +using namespace circt; +using namespace arc; +using namespace hw; +using namespace mlir; +using llvm::SmallDenseSet; + +namespace { +enum class Phase { Initial, Old, New, Final }; + +template +OS &operator<<(OS &os, Phase phase) { + switch (phase) { + case Phase::Initial: + return os << "initial"; + case Phase::Old: + return os << "old"; + case Phase::New: + return os << "new"; + case Phase::Final: + return os << "final"; + } +} + +struct ModuleLowering; + +/// All state associated with lowering a single operation. Instances of this +/// struct are kept on a worklist to perform a depth-first traversal of the +/// module being lowered. +/// +/// The actual lowering occurs in `lower()`. This function is called exactly +/// twice. A first time with `initial` being true, where other values and +/// operations that have to be lowered first may be marked with `addPending`. No +/// actual lowering or error reporting should occur when `initial` is true. The +/// worklist then ensures that all `pending` ops are lowered before `lower()` is +/// called a second time with `initial` being false. At this point the actual +/// lowering and error reporting should occur. +/// +/// The `initial` variable is used to allow for a single block of code to mark +/// values and ops as dependencies and actually do the lowering based on them. +struct OpLowering { + Operation *op; + Phase phase; + ModuleLowering &module; + + bool initial = true; + SmallVector, 2> pending; + + OpLowering(Operation *op, Phase phase, ModuleLowering &module) + : op(op), phase(phase), module(module) {} + + // Operation Lowering. + LogicalResult lower(); + LogicalResult lowerDefault(); + LogicalResult lower(StateOp op); + LogicalResult lower(sim::DPICallOp op); + LogicalResult + lowerStateful(Value clock, Value enable, Value reset, ValueRange inputs, + ResultRange results, + llvm::function_ref createMapping); + LogicalResult lower(MemoryOp op); + LogicalResult lower(TapOp op); + LogicalResult lower(InstanceOp op); + LogicalResult lower(hw::OutputOp op); + LogicalResult lower(seq::InitialOp op); + LogicalResult lower(llhd::FinalOp op); + + scf::IfOp createIfClockOp(Value clock); + + // Value Lowering. These functions are called from the `lower()` functions + // above. They handle values used by the `op`. This can generate reads from + // state and memory storage on-the-fly, or mark other ops as dependencies to + // be lowered first. + Value lowerValue(Value value, Phase phase); + Value lowerValue(InstanceOp op, OpResult result, Phase phase); + Value lowerValue(StateOp op, OpResult result, Phase phase); + Value lowerValue(sim::DPICallOp op, OpResult result, Phase phase); + Value lowerValue(MemoryReadPortOp op, OpResult result, Phase phase); + Value lowerValue(seq::InitialOp op, OpResult result, Phase phase); + Value lowerValue(seq::FromImmutableOp op, OpResult result, Phase phase); + + void addPending(Value value, Phase phase); + void addPending(Operation *op, Phase phase); +}; + +/// All state associated with lowering a single module. +struct ModuleLowering { + /// The module being lowered. + HWModuleOp moduleOp; + /// The builder for the main body of the model. + OpBuilder builder; + /// The builder for state allocation ops. + OpBuilder allocBuilder; + /// The builder for the initial phase. + OpBuilder initialBuilder; + /// The builder for the final phase. + OpBuilder finalBuilder; + + /// The storage value that can be used for `arc.alloc_state` and friends. + Value storageArg; + + /// A worklist of pending op lowerings. + SmallVector opsWorklist; + /// The set of ops currently in the worklist. Used to detect cycles. + SmallDenseSet> opsSeen; + /// The ops that have already been lowered. + DenseSet> loweredOps; + /// The values that have already been lowered. + DenseMap, Value> loweredValues; + + /// The allocated input ports. + SmallVector allocatedInputs; + /// The allocated states as a mapping from op results to `arc.alloc_state` + /// results. + DenseMap allocatedStates; + /// The allocated storage for instance inputs and top module outputs. + DenseMap allocatedOutputs; + /// The allocated storage for values computed during the initial phase. + DenseMap allocatedInitials; + /// The allocated storage for taps. + DenseMap allocatedTaps; + + /// A mapping from unlowered clocks to a value indicating a posedge. This is + /// used to not create an excessive number of posedge detectors. + DenseMap loweredPosedges; + /// The previous enable and the value it was lowered to. This is used to reuse + /// previous if ops for the same enable value. + std::pair prevEnable; + /// The previous reset and the value it was lowered to. This is used to reuse + /// previous uf ops for the same reset value. + std::pair prevReset; + + ModuleLowering(HWModuleOp moduleOp) + : moduleOp(moduleOp), builder(moduleOp), allocBuilder(moduleOp), + initialBuilder(moduleOp), finalBuilder(moduleOp) {} + LogicalResult run(); + LogicalResult lowerOp(Operation *op); + Value getAllocatedState(OpResult result); + Value detectPosedge(Value clock); + OpBuilder &getBuilder(Phase phase); + Value requireLoweredValue(Value value, Phase phase, Location useLoc); +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Module Lowering +//===----------------------------------------------------------------------===// + +LogicalResult ModuleLowering::run() { + LLVM_DEBUG(llvm::dbgs() << "Lowering module `" << moduleOp.getModuleName() + << "`\n"); + + // Create the replacement `ModelOp`. + auto modelOp = + builder.create(moduleOp.getLoc(), moduleOp.getModuleNameAttr(), + TypeAttr::get(moduleOp.getModuleType()), + FlatSymbolRefAttr{}, FlatSymbolRefAttr{}); + auto &modelBlock = modelOp.getBody().emplaceBlock(); + storageArg = modelBlock.addArgument( + StorageType::get(builder.getContext(), {}), modelOp.getLoc()); + builder.setInsertionPointToStart(&modelBlock); + + // Create the `arc.initial` op to contain the ops for the initialization + // phase. + auto initialOp = builder.create(moduleOp.getLoc()); + initialBuilder.setInsertionPointToStart(&initialOp.getBody().emplaceBlock()); + + // Create the `arc.final` op to contain the ops for the finalization phase. + auto finalOp = builder.create(moduleOp.getLoc()); + finalBuilder.setInsertionPointToStart(&finalOp.getBody().emplaceBlock()); + + // Position the alloc builder such that allocation ops get inserted above the + // initial op. + allocBuilder.setInsertionPoint(initialOp); + + // Allocate storage for the inputs. + for (auto arg : moduleOp.getBodyBlock()->getArguments()) { + auto name = moduleOp.getArgName(arg.getArgNumber()); + auto state = allocBuilder.create( + arg.getLoc(), StateType::get(arg.getType()), name, storageArg); + allocatedInputs.push_back(state); + } + + // Lower the ops. + for (auto &op : moduleOp.getOps()) { + if (mlir::isMemoryEffectFree(&op) && !isa(op)) + continue; + if (isa(op)) + continue; // handled as part of `MemoryOp` + if (failed(lowerOp(&op))) + return failure(); + } + + // Clean up any dead ops. The lowering inserts a few defensive + // `arc.state_read` ops that may remain unused. This cleans them up. + for (auto &op : llvm::make_early_inc_range(llvm::reverse(modelBlock))) + if (mlir::isOpTriviallyDead(&op)) + op.erase(); + + return success(); +} + +/// Lower an op and its entire fan-in cone. +LogicalResult ModuleLowering::lowerOp(Operation *op) { + LLVM_DEBUG(llvm::dbgs() << "- Handling " << *op << "\n"); + + // Pick in which phases the given operation has to perform some work. + SmallVector phases = {Phase::New}; + if (isa(op)) + phases = {Phase::Initial}; + if (isa(op)) + phases = {Phase::Final}; + if (isa(op)) + phases = {Phase::Initial, Phase::New}; + + for (auto phase : phases) { + if (loweredOps.contains({op, phase})) + return success(); + opsWorklist.push_back(OpLowering(op, phase, *this)); + opsSeen.insert({op, phase}); + } + + auto dumpWorklist = [&] { + for (auto &opLowering : llvm::reverse(opsWorklist)) + opLowering.op->emitRemark() + << "computing " << opLowering.phase << " phase here"; + }; + + while (!opsWorklist.empty()) { + auto &opLowering = opsWorklist.back(); + + // Collect an initial list of operands that need to be lowered. + if (opLowering.initial) { + if (failed(opLowering.lower())) { + dumpWorklist(); + return failure(); + } + std::reverse(opLowering.pending.begin(), opLowering.pending.end()); + opLowering.initial = false; + } + + // Push operands onto the worklist. + if (!opLowering.pending.empty()) { + auto [defOp, phase] = opLowering.pending.pop_back_val(); + if (loweredOps.contains({defOp, phase})) + continue; + if (!opsSeen.insert({defOp, phase}).second) { + defOp->emitOpError("is on a combinational loop"); + dumpWorklist(); + return failure(); + } + opsWorklist.push_back(OpLowering(defOp, phase, *this)); + continue; + } + + // At this point all operands are available and the op itself can be + // lowered. + LLVM_DEBUG(llvm::dbgs() << " - Lowering " << opLowering.phase << " " + << *opLowering.op << "\n"); + if (failed(opLowering.lower())) { + dumpWorklist(); + return failure(); + } + loweredOps.insert({opLowering.op, opLowering.phase}); + opsSeen.erase({opLowering.op, opLowering.phase}); + opsWorklist.pop_back(); + } + + return success(); +} + +/// Return the `arc.alloc_state` associated with the given state op result. +/// Creates the allocation op if it does not yet exist. +Value ModuleLowering::getAllocatedState(OpResult result) { + if (auto alloc = allocatedStates.lookup(result)) + return alloc; + + // Handle memories. + if (auto memOp = dyn_cast(result.getOwner())) { + auto alloc = allocBuilder.create( + memOp.getLoc(), memOp.getType(), storageArg, memOp->getAttrs()); + allocatedStates.insert({result, alloc}); + return alloc; + } + + // Create the allocation op. + auto alloc = allocBuilder.create( + result.getLoc(), StateType::get(result.getType()), storageArg); + allocatedStates.insert({result, alloc}); + + // HACK: If the result comes from an instance op, add the instance and port + // name as an attribute to the allocation. This will make it show up in the C + // headers later. Get rid of this once we have proper debug dialect support. + if (auto instOp = dyn_cast(result.getOwner())) + alloc->setAttr( + "name", builder.getStringAttr( + instOp.getInstanceName() + "/" + + instOp.getResultName(result.getResultNumber()).getValue())); + + // HACK: If the result comes from an op that has a "names" attribute, use that + // as a name for the allocation. This should no longer be necessary once we + // properly support the Debug dialect. + if (isa(result.getOwner())) + if (auto names = result.getOwner()->getAttrOfType("names")) + if (result.getResultNumber() < names.size()) + alloc->setAttr("name", names[result.getResultNumber()]); + + return alloc; +} + +/// Allocate the necessary storage, reads, writes, and comparisons to detect a +/// rising edge on a clock value. +Value ModuleLowering::detectPosedge(Value clock) { + auto loc = clock.getLoc(); + if (isa(clock.getType())) + clock = builder.createOrFold(loc, clock); + + // Allocate storage to store the previous clock value. + auto oldStorage = allocBuilder.create( + loc, StateType::get(builder.getI1Type()), storageArg); + + // Read the old clock value from storage and write the new clock value to + // storage. + auto oldClock = builder.create(loc, oldStorage); + builder.create(loc, oldStorage, clock, Value{}); + + // Detect a rising edge. + Value edge = builder.create(loc, oldClock, clock); + edge = builder.create(loc, edge, clock); + return edge; +} + +/// Get the builder appropriate for the given phase. +OpBuilder &ModuleLowering::getBuilder(Phase phase) { + switch (phase) { + case Phase::Initial: + return initialBuilder; + case Phase::Old: + case Phase::New: + return builder; + case Phase::Final: + return finalBuilder; + } +} + +/// Get the lowered value, or emit a diagnostic and return null. +Value ModuleLowering::requireLoweredValue(Value value, Phase phase, + Location useLoc) { + if (auto lowered = loweredValues.lookup({value, phase})) + return lowered; + auto d = emitError(value.getLoc()) << "value has not been lowered"; + d.attachNote(useLoc) << "value used here"; + return {}; +} + +//===----------------------------------------------------------------------===// +// Operation Lowering +//===----------------------------------------------------------------------===// + +/// Create a new `scf.if` operation with the given builder, or reuse a previous +/// `scf.if` if the builder's insertion point is located right after it. +static scf::IfOp createOrReuseIf(OpBuilder &builder, Value condition, + bool withElse) { + scf::IfOp ifClockOp; + if (auto ip = builder.getInsertionPoint(); ip != builder.getBlock()->begin()) + if (auto ifOp = dyn_cast(*std::prev(ip))) + if (ifOp.getCondition() == condition) + return ifOp; + return builder.create(condition.getLoc(), condition, withElse); +} + +/// This function is called from the lowering worklist in order to perform a +/// depth-first traversal of the surrounding module. These functions call +/// `lowerValue` to mark their operands as dependencies in the depth-first +/// traversal, and to map them to the lowered value in one go. +LogicalResult OpLowering::lower() { + return TypeSwitch(op) + // Operations with special lowering. + .Case([&](auto op) { return lower(op); }) + + // Operations that should be skipped entirely and never land on the + // worklist to be lowered. + .Case([&](auto op) { + op.emitOpError() << "is handled by memory op and must be skipped"; + return success(); + }) + + // All other ops are simply cloned into the lowered model. + .Default([&](auto) { return lowerDefault(); }); +} + +/// Called for all operations for which there is no special lowering. Simply +/// clones the operation. +LogicalResult OpLowering::lowerDefault() { + // Make sure that all operand values are lowered first. + IRMapping mapping; + auto anyFailed = false; + op->walk([&](Operation *nestedOp) { + for (auto operand : nestedOp->getOperands()) { + if (op->isAncestor(operand.getParentBlock()->getParentOp())) + continue; + auto lowered = lowerValue(operand, phase); + if (!lowered) + anyFailed = true; + mapping.map(operand, lowered); + } + }); + if (initial) + return success(); + if (anyFailed) + return failure(); + + // Clone the operation. + auto *clonedOp = module.getBuilder(phase).clone(*op, mapping); + + // Keep track of the results. + for (auto [oldResult, newResult] : + llvm::zip(op->getResults(), clonedOp->getResults())) + module.loweredValues[{oldResult, phase}] = newResult; + + return success(); +} + +/// Lower a state to a corresponding storage allocation and write of the state's +/// new value to it. This function uses the `Old` phase to get the values at the +/// state input before the current update, and then uses them to compute the +/// `New` value. +LogicalResult OpLowering::lower(StateOp op) { + // Handle initialization. + if (phase == Phase::Initial) { + // Ensure the initial values of the register have been lowered before. + if (initial) { + for (auto initial : op.getInitials()) + lowerValue(initial, Phase::Initial); + return success(); + } + + // Write the initial values to the allocated storage in the initial block. + if (op.getInitials().empty()) + return success(); + for (auto [initial, result] : + llvm::zip(op.getInitials(), op.getResults())) { + auto value = lowerValue(initial, Phase::Initial); + if (!value) + return failure(); + auto state = module.getAllocatedState(result); + if (!state) + return failure(); + module.initialBuilder.create(value.getLoc(), state, value, + Value{}); + } + return success(); + } + + if (phase != Phase::New) + return op.emitOpError() << "cannot be lowered in the " << phase << " phase"; + + if (!initial) { + if (!op.getClock()) + return op.emitOpError() << "must have a clock"; + if (op.getLatency() > 1) + return op.emitOpError("latencies > 1 not supported yet"); + } + + return lowerStateful(op.getClock(), op.getEnable(), op.getReset(), + op.getInputs(), op.getResults(), [&](ValueRange inputs) { + return module.builder + .create(op.getLoc(), op.getResultTypes(), + op.getArc(), inputs) + .getResults(); + }); +} + +/// Lower a state to a corresponding storage allocation and write of the state's +/// new value to it. This function uses the `Old` phase to get the values at the +/// state input before the current update, and then uses them to compute the +/// `New` value. +LogicalResult OpLowering::lower(sim::DPICallOp op) { + // Handle unclocked DPI calls. + if (!op.getClock()) { + // Make sure that all operands have been lowered. + SmallVector inputs; + for (auto operand : op.getInputs()) + inputs.push_back(lowerValue(operand, phase)); + if (initial) + return success(); + if (llvm::is_contained(inputs, Value{})) + return failure(); + if (op.getEnable()) + return op.emitOpError() << "without clock cannot have an enable"; + + // Lower the op to a regular function call. + auto callOp = module.getBuilder(phase).create( + op.getLoc(), op.getCalleeAttr(), op.getResultTypes(), inputs); + for (auto [oldResult, newResult] : + llvm::zip(op.getResults(), callOp.getResults())) + module.loweredValues[{oldResult, phase}] = newResult; + return success(); + } + + if (phase != Phase::New) + return op.emitOpError() << "cannot be lowered in the " << phase << " phase"; + + return lowerStateful(op.getClock(), op.getEnable(), /*reset=*/{}, + op.getInputs(), op.getResults(), [&](ValueRange inputs) { + return module.builder + .create(op.getLoc(), + op.getCalleeAttr(), + op.getResultTypes(), inputs) + .getResults(); + }); +} + +/// Lower a state to a corresponding storage allocation and write of the state's +/// new value to it. This function uses the `Old` phase to get the values at the +/// state input before the current update, and then uses them to compute the +/// `New` value. +LogicalResult OpLowering::lowerStateful( + Value clock, Value enable, Value reset, ValueRange inputs, + ResultRange results, + llvm::function_ref createMapping) { + // Ensure all operands are lowered before we lower the op itself. State ops + // are special in that they require the "old" value of their inputs and + // enable, in order to compute the updated "new" value. The clock needs to be + // the "new" value though, such that other states can act as a clock source. + if (initial) { + lowerValue(clock, Phase::New); + if (enable) + lowerValue(enable, Phase::Old); + if (reset) + lowerValue(reset, Phase::Old); + for (auto input : inputs) + lowerValue(input, Phase::Old); + return success(); + } + + // Check if we're inserting right after an if op for the same clock edge, in + // which case we can reuse that op. Otherwise create the new if op. + auto ifClockOp = createIfClockOp(clock); + if (!ifClockOp) + return failure(); + OpBuilder::InsertionGuard guard(module.builder); + module.builder.setInsertionPoint(ifClockOp.thenYield()); + + // Make sure we have the state storage available such that we can read and + // write from and to them. + SmallVector states; + for (auto result : results) { + auto state = module.getAllocatedState(result); + if (!state) + return failure(); + states.push_back(state); + } + + // Handle the reset. + if (reset) { + // Check if we can reuse a previous reset value. + auto &[unloweredReset, loweredReset] = module.prevReset; + if (unloweredReset != reset || + loweredReset.getParentBlock() != module.builder.getBlock()) { + unloweredReset = reset; + loweredReset = lowerValue(reset, Phase::Old); + if (!loweredReset) + return failure(); + } + + // Check if we're inserting right after an if op for the same reset, in + // which case we can reuse that op. Otherwise create the new if op. + auto ifResetOp = createOrReuseIf(module.builder, loweredReset, true); + module.builder.setInsertionPoint(ifResetOp.thenYield()); + + // Generate the zero value writes. + for (auto state : states) { + auto type = cast(state.getType()).getType(); + Value value = module.builder.create( + loweredReset.getLoc(), + module.builder.getIntegerType(hw::getBitWidth(type)), 0); + if (value.getType() != type) + value = module.builder.create(loweredReset.getLoc(), type, + value); + module.builder.create(loweredReset.getLoc(), state, value, + Value{}); + } + module.builder.setInsertionPoint(ifResetOp.elseYield()); + } + + // Handle the enable. + if (enable) { + // Check if we can reuse a previous enable value. + auto &[unloweredEnable, loweredEnable] = module.prevEnable; + if (unloweredEnable != enable || + loweredEnable.getParentBlock() != module.builder.getBlock()) { + unloweredEnable = enable; + loweredEnable = lowerValue(enable, Phase::Old); + if (!loweredEnable) + return failure(); + } + + // Check if we're inserting right after an if op for the same enable, in + // which case we can reuse that op. Otherwise create the new if op. + auto ifEnableOp = createOrReuseIf(module.builder, loweredEnable, false); + module.builder.setInsertionPoint(ifEnableOp.thenYield()); + } + + // Get the transfer function inputs. This potentially inserts read ops. + SmallVector loweredInputs; + for (auto input : inputs) { + auto lowered = lowerValue(input, Phase::Old); + if (!lowered) + return failure(); + loweredInputs.push_back(lowered); + } + + // Compute the transfer function and write its results to the state's storage. + auto loweredResults = createMapping(loweredInputs); + for (auto [state, value] : llvm::zip(states, loweredResults)) + module.builder.create(value.getLoc(), state, value, Value{}); + + // Since we just wrote the new state value to storage, insert read ops just + // before the if op that keep the old value around for any later ops that + // still need it. + module.builder.setInsertionPoint(ifClockOp); + for (auto [state, result] : llvm::zip(states, results)) { + auto oldValue = module.builder.create(result.getLoc(), state); + module.loweredValues[{result, Phase::Old}] = oldValue; + } + + return success(); +} + +/// Lower a memory and its read and write ports to corresponding +/// `arc.memory_write` operations. Reads are also executed at this point and +/// stored in `loweredValues` for later operations to pick up. +LogicalResult OpLowering::lower(MemoryOp op) { + if (phase != Phase::New) + return op.emitOpError() << "cannot be lowered in the " << phase << " phase"; + + // Collect all the reads and writes. + SmallVector reads; + SmallVector writes; + + for (auto *user : op->getUsers()) { + if (auto read = dyn_cast(user)) { + reads.push_back(read); + } else if (auto write = dyn_cast(user)) { + writes.push_back(write); + } else { + auto d = op.emitOpError() + << "users must all be memory read or write port ops"; + d.attachNote(user->getLoc()) + << "but found " << user->getName() << " user here"; + return d; + } + } + + // Ensure all operands are lowered before we lower the memory itself. + if (initial) { + for (auto read : reads) + lowerValue(read, Phase::Old); + for (auto write : writes) { + if (write.getClock()) + lowerValue(write.getClock(), Phase::New); + for (auto input : write.getInputs()) + lowerValue(input, Phase::Old); + } + return success(); + } + + // Get the allocated storage for the memory. + auto state = module.getAllocatedState(op->getResult(0)); + + // Since we are going to write new values into storage, insert read ops that + // keep the old values around for any later ops that still need them. + for (auto read : reads) { + auto oldValue = lowerValue(read, Phase::Old); + if (!oldValue) + return failure(); + module.loweredValues[{read, Phase::Old}] = oldValue; + } + + // Lower the writes. + for (auto write : writes) { + if (!write.getClock()) + return write.emitOpError() << "must have a clock"; + if (write.getLatency() > 1) + return write.emitOpError("latencies > 1 not supported yet"); + + // Create the if op for the clock edge. + auto ifClockOp = createIfClockOp(write.getClock()); + if (!ifClockOp) + return failure(); + OpBuilder::InsertionGuard guard(module.builder); + module.builder.setInsertionPoint(ifClockOp.thenYield()); + + // Call the arc that computes the address, data, and enable. + SmallVector inputs; + for (auto input : write.getInputs()) { + auto lowered = lowerValue(input, Phase::Old); + if (!lowered) + return failure(); + inputs.push_back(lowered); + } + auto callOp = module.builder.create( + write.getLoc(), write.getArcResultTypes(), write.getArc(), inputs); + + // If the write has an enable, wrap the remaining logic in an if op. + if (write.getEnable()) { + auto ifEnableOp = createOrReuseIf( + module.builder, callOp.getResult(write.getEnableIdx()), false); + module.builder.setInsertionPoint(ifEnableOp.thenYield()); + } + + // If the write is masked, read the current + // value in the memory and merge it with the updated value. + auto address = callOp.getResult(write.getAddressIdx()); + auto data = callOp.getResult(write.getDataIdx()); + if (write.getMask()) { + auto mask = callOp.getResult(write.getMaskIdx(write.getEnable())); + auto maskInv = module.builder.createOrFold( + write.getLoc(), mask, + module.builder.create(write.getLoc(), mask.getType(), -1), + true); + auto oldData = + module.builder.create(write.getLoc(), state, address); + auto oldMasked = module.builder.create( + write.getLoc(), maskInv, oldData, true); + auto newMasked = + module.builder.create(write.getLoc(), mask, data, true); + data = module.builder.create(write.getLoc(), oldMasked, + newMasked, true); + } + + // Actually write to the memory. + module.builder.create(write.getLoc(), state, address, + Value{}, data); + } + + return success(); +} + +/// Lower a tap by allocating state storage for it and writing the current value +/// observed by the tap to it. +LogicalResult OpLowering::lower(TapOp op) { + if (phase != Phase::New) + return op.emitOpError() << "cannot be lowered in the " << phase << " phase"; + + auto value = lowerValue(op.getValue(), phase); + if (initial) + return success(); + if (!value) + return failure(); + + auto &state = module.allocatedTaps[op]; + if (!state) { + auto alloc = module.allocBuilder.create( + op.getLoc(), StateType::get(value.getType()), module.storageArg, true); + alloc->setAttr("name", op.getNameAttr()); + state = alloc; + } + module.builder.create(op.getLoc(), state, value, Value{}); + return success(); +} + +/// Lower an instance by allocating state storage for each of its inputs and +/// writing the current value into that storage. This makes instance inputs +/// behave like outputs of the top-level module. +LogicalResult OpLowering::lower(InstanceOp op) { + if (phase != Phase::New) + return op.emitOpError() << "cannot be lowered in the " << phase << " phase"; + + // Get the current values flowing into the instance's inputs. + SmallVector values; + for (auto operand : op.getOperands()) + values.push_back(lowerValue(operand, Phase::New)); + if (initial) + return success(); + if (llvm::is_contained(values, Value{})) + return failure(); + + // Then allocate storage for each instance input and assign the corresponding + // value. + for (auto [value, name] : llvm::zip(values, op.getArgNames())) { + auto state = module.allocBuilder.create( + value.getLoc(), StateType::get(value.getType()), module.storageArg); + state->setAttr("name", module.builder.getStringAttr( + op.getInstanceName() + "/" + + cast(name).getValue())); + module.builder.create(value.getLoc(), state, value, Value{}); + } + + // HACK: Also ensure that storage has been allocated for all outputs. + // Otherwise only the actually used instance outputs would be allocated, which + // would make the optimization user-visible. Remove this once we use the debug + // dialect. + for (auto result : op.getResults()) + module.getAllocatedState(result); + + return success(); +} + +/// Lower the main module's outputs by allocating storage for each and then +/// writing the current value into that storage. +LogicalResult OpLowering::lower(hw::OutputOp op) { + if (phase != Phase::New) + return op.emitOpError() << "cannot be lowered in the " << phase << " phase"; + + // First get the current value of all outputs. + SmallVector values; + for (auto operand : op.getOperands()) + values.push_back(lowerValue(operand, Phase::New)); + if (initial) + return success(); + if (llvm::is_contained(values, Value{})) + return failure(); + + // Then allocate storage for each output and assign the corresponding value. + for (auto [value, name] : + llvm::zip(values, module.moduleOp.getOutputNames())) { + auto state = module.allocBuilder.create( + value.getLoc(), StateType::get(value.getType()), cast(name), + module.storageArg); + module.builder.create(value.getLoc(), state, value, Value{}); + } + return success(); +} + +/// Lower `seq.initial` ops by inlining them into the `arc.initial` op. +LogicalResult OpLowering::lower(seq::InitialOp op) { + if (phase != Phase::Initial) + return op.emitOpError() << "cannot be lowered in the " << phase << " phase"; + + // First get the initial value of all operands. + SmallVector operands; + for (auto operand : op.getOperands()) + operands.push_back(lowerValue(operand, Phase::Initial)); + if (initial) + return success(); + if (llvm::is_contained(operands, Value{})) + return failure(); + + // Expose the `seq.initial` operands as values for the block arguments. + for (auto [arg, operand] : llvm::zip(op.getBody().getArguments(), operands)) + module.loweredValues[{arg, Phase::Initial}] = operand; + + // Lower each op in the body. + for (auto &bodyOp : op.getOps()) { + if (isa(bodyOp)) + continue; + + // Clone the operation. + auto *clonedOp = module.initialBuilder.clone(bodyOp); + auto result = clonedOp->walk([&](Operation *nestedClonedOp) { + for (auto &operand : nestedClonedOp->getOpOperands()) { + if (clonedOp->isAncestor(operand.get().getParentBlock()->getParentOp())) + continue; + auto value = module.requireLoweredValue(operand.get(), Phase::Initial, + nestedClonedOp->getLoc()); + if (!value) + return WalkResult::interrupt(); + operand.set(value); + } + return WalkResult::advance(); + }); + if (result.wasInterrupted()) + return failure(); + + // Keep track of the results. + for (auto [result, lowered] : + llvm::zip(bodyOp.getResults(), clonedOp->getResults())) + module.loweredValues[{result, Phase::Initial}] = lowered; + } + + // Expose the operands of `seq.yield` as results from the initial op. + auto *terminator = op.getBodyBlock()->getTerminator(); + for (auto [result, operand] : + llvm::zip(op.getResults(), terminator->getOperands())) { + auto value = module.requireLoweredValue(operand, Phase::Initial, + terminator->getLoc()); + if (!value) + return failure(); + module.loweredValues[{result, Phase::Initial}] = value; + } + + return success(); +} + +/// Lower `llhd.final` ops into `scf.execute_region` ops in the `arc.final` op. +LogicalResult OpLowering::lower(llhd::FinalOp op) { + if (phase != Phase::Final) + return op.emitOpError() << "cannot be lowered in the " << phase << " phase"; + + // Determine the uses of values defined outside the op. + SmallVector externalOperands; + op.walk([&](Operation *nestedOp) { + for (auto value : nestedOp->getOperands()) + if (!op->isAncestor(value.getParentBlock()->getParentOp())) + externalOperands.push_back(value); + }); + + // Make sure that all uses of external values are lowered first. + IRMapping mapping; + for (auto operand : externalOperands) { + auto lowered = lowerValue(operand, Phase::Final); + if (!initial && !lowered) + return failure(); + mapping.map(operand, lowered); + } + if (initial) + return success(); + + // Handle the simple case where the final op contains only one block, which we + // can inline directly. + if (op.getBody().hasOneBlock()) { + for (auto &bodyOp : op.getBody().front().without_terminator()) + module.finalBuilder.clone(bodyOp, mapping); + return success(); + } + + // Create a new `scf.execute_region` op and clone the entire `llhd.final` body + // region into it. Replace `llhd.halt` ops with `scf.yield`. + auto executeOp = module.finalBuilder.create( + op.getLoc(), TypeRange{}); + module.finalBuilder.cloneRegionBefore(op.getBody(), executeOp.getRegion(), + executeOp.getRegion().begin(), mapping); + executeOp.walk([&](llhd::HaltOp op) { + OpBuilder(op).create(op.getLoc()); + op.erase(); + }); + + return success(); +} + +/// Create the operations necessary to detect a posedge on the given clock, +/// potentially reusing a previous posedge detection, and create an `scf.if` +/// operation for that posedge. This also tries to reuse an `scf.if` operation +/// immediately before the builder's insertion point if possible. +scf::IfOp OpLowering::createIfClockOp(Value clock) { + auto &posedge = module.loweredPosedges[clock]; + if (!posedge) { + auto loweredClock = lowerValue(clock, Phase::New); + if (!loweredClock) + return {}; + posedge = module.detectPosedge(loweredClock); + } + return createOrReuseIf(module.builder, posedge, false); +} + +//===----------------------------------------------------------------------===// +// Value Lowering +//===----------------------------------------------------------------------===// + +/// Lower a value being used by the current operation. This will mark the +/// defining operation as to be lowered first (through `addPending`) in most +/// cases. Some operations and values have special handling though. For example, +/// states and memory reads are immediately materialized as a new read op. +Value OpLowering::lowerValue(Value value, Phase phase) { + // Handle module inputs. They read the same in all phases. + if (auto arg = dyn_cast(value)) { + if (initial) + return {}; + auto state = module.allocatedInputs[arg.getArgNumber()]; + return module.getBuilder(phase).create(arg.getLoc(), state); + } + + // Check if the value has already been lowered. + if (auto lowered = module.loweredValues.lookup({value, phase})) + return lowered; + + // At this point the value is the result of an op. (Block arguments are + // handled above.) + auto result = cast(value); + auto *op = result.getOwner(); + + // Special handling for some ops. + if (auto instOp = dyn_cast(op)) + return lowerValue(instOp, result, phase); + if (auto stateOp = dyn_cast(op)) + return lowerValue(stateOp, result, phase); + if (auto dpiOp = dyn_cast(op); dpiOp && dpiOp.getClock()) + return lowerValue(dpiOp, result, phase); + if (auto readOp = dyn_cast(op)) + return lowerValue(readOp, result, phase); + if (auto initialOp = dyn_cast(op)) + return lowerValue(initialOp, result, phase); + if (auto castOp = dyn_cast(op)) + return lowerValue(castOp, result, phase); + + // Otherwise we mark the defining operation as to be lowered first. This will + // cause the lookup in `loweredValues` above to return a value the next time + // (i.e. when initial is false). + if (initial) { + addPending(op, phase); + return {}; + } + emitError(result.getLoc()) << "value has not been lowered"; + return {}; +} + +/// Handle instance outputs. They behave essentially like a top-level module +/// input, and read the same in all phases. +Value OpLowering::lowerValue(InstanceOp op, OpResult result, Phase phase) { + if (initial) + return {}; + auto state = module.getAllocatedState(result); + return module.getBuilder(phase).create(result.getLoc(), state); +} + +/// Handle uses of a state. This creates a `arc.state_read` op to read from the +/// state's storage. If the new value after all updates is requested, marks the +/// state as to be lowered first (which will perform the writes). If the old +/// value is requested, asserts that no new values have been written. +Value OpLowering::lowerValue(StateOp op, OpResult result, Phase phase) { + if (initial) { + // Ensure that the new or initial value has been written by the lowering of + // the state op before we attempt to read it. + if (phase == Phase::New || phase == Phase::Initial) + addPending(op, phase); + return {}; + } + + // If we want to read the old value, no writes must have been lowered yet. + if (phase == Phase::Old) + assert(!module.loweredOps.contains({op, Phase::New}) && + "need old value but new value already written"); + + auto state = module.getAllocatedState(result); + return module.getBuilder(phase).create(result.getLoc(), state); +} + +/// Handle uses of a state. This creates a `arc.state_read` op to read from the +/// state's storage. If the new value after all updates is requested, marks the +/// state as to be lowered first (which will perform the writes). If the old +/// value is requested, asserts that no new values have been written. +Value OpLowering::lowerValue(sim::DPICallOp op, OpResult result, Phase phase) { + if (initial) { + // Ensure that the new or initial value has been written by the lowering of + // the state op before we attempt to read it. + if (phase == Phase::New || phase == Phase::Initial) + addPending(op, phase); + return {}; + } + + // If we want to read the old value, no writes must have been lowered yet. + if (phase == Phase::Old) + assert(!module.loweredOps.contains({op, Phase::New}) && + "need old value but new value already written"); + + auto state = module.getAllocatedState(result); + return module.getBuilder(phase).create(result.getLoc(), state); +} + +/// Handle uses of a memory read operation. This creates an `arc.memory_read` op +/// to read from the memory's storage. Similar to the `StateOp` handling +/// otherwise. +Value OpLowering::lowerValue(MemoryReadPortOp op, OpResult result, + Phase phase) { + auto memOp = op.getMemory().getDefiningOp(); + if (!memOp) { + if (!initial) + op->emitOpError() << "memory must be defined locally"; + return {}; + } + + auto address = lowerValue(op.getAddress(), phase); + if (initial) { + // Ensure that all new values are written before we attempt to read them. + if (phase == Phase::New) + addPending(memOp.getOperation(), Phase::New); + return {}; + } + if (!address) + return {}; + + if (phase == Phase::Old) { + // If we want to read the old value, no writes must have been lowered yet. + assert(!module.loweredOps.contains({memOp, Phase::New}) && + "need old memory value but new value already written"); + } else if (phase != Phase::New) { + op.emitOpError() << "result cannot be used in " << phase << " phase\n"; + return {}; + } + + auto state = module.getAllocatedState(memOp->getResult(0)); + return module.getBuilder(phase).create(result.getLoc(), state, + address); +} + +/// Handle uses of `seq.initial` values computed during the initial phase. This +/// ensures that the interesting value is stored into storage during the initial +/// phase, and then reads it back using an `arc.state_read` op. +Value OpLowering::lowerValue(seq::InitialOp op, OpResult result, Phase phase) { + // Ensure the op has been lowered first. + if (initial) { + addPending(op, Phase::Initial); + return {}; + } + auto value = module.loweredValues.lookup({result, Phase::Initial}); + if (!value) { + emitError(result.getLoc()) << "value has not been lowered"; + return {}; + } + + // If we are using the value of `seq.initial` in the initial phase directly, + // there is no need to write it so any temporary storage. + if (phase == Phase::Initial) + return value; + + // If necessary, allocate storage for the computed value and store it in the + // initial phase. + auto &state = module.allocatedInitials[result]; + if (!state) { + state = module.allocBuilder.create( + value.getLoc(), StateType::get(value.getType()), module.storageArg); + OpBuilder::InsertionGuard guard(module.initialBuilder); + module.initialBuilder.setInsertionPointAfterValue(value); + module.initialBuilder.create(value.getLoc(), state, value, + Value{}); + } + + // Read back the value computed during the initial phase. + return module.getBuilder(phase).create(state.getLoc(), state); +} + +/// The `seq.from_immutable` cast is just a passthrough. +Value OpLowering::lowerValue(seq::FromImmutableOp op, OpResult result, + Phase phase) { + return lowerValue(op.getInput(), phase); +} + +/// Mark a value as to be lowered before the current op. +void OpLowering::addPending(Value value, Phase phase) { + auto *defOp = value.getDefiningOp(); + assert(defOp && "block args should never be marked as a dependency"); + addPending(defOp, phase); +} + +/// Mark an operation as to be lowered before the current op. This adds that +/// operation to the `pending` list if the operation has not yet been lowered. +void OpLowering::addPending(Operation *op, Phase phase) { + auto pair = std::make_pair(op, phase); + if (!module.loweredOps.contains(pair)) + if (!llvm::is_contained(pending, pair)) + pending.push_back(pair); +} + +//===----------------------------------------------------------------------===// +// Pass Infrastructure +//===----------------------------------------------------------------------===// + +namespace { +struct LowerStatePass : public arc::impl::LowerStatePassBase { + using LowerStatePassBase::LowerStatePassBase; + void runOnOperation() override; +}; +} // namespace + +void LowerStatePass::runOnOperation() { + auto op = getOperation(); + for (auto moduleOp : llvm::make_early_inc_range(op.getOps())) { + if (failed(ModuleLowering(moduleOp).run())) + return signalPassFailure(); + moduleOp.erase(); + } + for (auto extModuleOp : + llvm::make_early_inc_range(op.getOps())) + extModuleOp.erase(); +} diff --git a/test/Dialect/Arc/legalize-state-update-error.mlir b/test/Dialect/Arc/legalize-state-update-error.mlir deleted file mode 100644 index b5fb061786a7..000000000000 --- a/test/Dialect/Arc/legalize-state-update-error.mlir +++ /dev/null @@ -1,22 +0,0 @@ -// RUN: circt-opt %s --arc-legalize-state-update --split-input-file --verify-diagnostics - -arc.model @Memory io !hw.modty<> { -^bb0(%arg0: !arc.storage): - %false = hw.constant false - %mem1 = arc.alloc_memory %arg0 : (!arc.storage) -> !arc.memory<2 x i32, i1> - %mem2 = arc.alloc_memory %arg0 : (!arc.storage) -> !arc.memory<2 x i32, i1> - %s1 = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - arc.clock_tree %false attributes {ct4} { - %r1 = arc.state_read %s1 : - scf.if %false { - // expected-error @+1 {{could not be moved to be after all reads to the same memory}} - arc.memory_write %mem2[%false], %r1 : <2 x i32, i1> - %mr1 = arc.memory_read %mem1[%false] : <2 x i32, i1> - } - scf.if %false { - arc.memory_write %mem1[%false], %r1 : <2 x i32, i1> - // expected-note @+1 {{could not be moved after this read}} - %mr1 = arc.memory_read %mem2[%false] : <2 x i32, i1> - } - } -} diff --git a/test/Dialect/Arc/legalize-state-update.mlir b/test/Dialect/Arc/legalize-state-update.mlir deleted file mode 100644 index 7dc390ccb681..000000000000 --- a/test/Dialect/Arc/legalize-state-update.mlir +++ /dev/null @@ -1,253 +0,0 @@ -// RUN: circt-opt %s --arc-legalize-state-update | FileCheck %s - -// CHECK-LABEL: func.func @Unaffected -func.func @Unaffected(%arg0: !arc.storage, %arg1: i4) -> i4 { - %0 = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - %1 = arc.state_read %0 : - arc.state_write %0 = %arg1 : - return %1 : i4 - // CHECK-NEXT: arc.alloc_state - // CHECK-NEXT: arc.state_read - // CHECK-NEXT: arc.state_write - // CHECK-NEXT: return -} -// CHECK-NEXT: } - -// CHECK-LABEL: func.func @SameBlock -func.func @SameBlock(%arg0: !arc.storage, %arg1: i4) -> i4 { - %0 = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - %1 = arc.state_read %0 : - // CHECK-NEXT: [[STATE:%.+]] = arc.alloc_state - // CHECK-NEXT: arc.state_read [[STATE]] - - arc.state_write %0 = %arg1 : - // CHECK-NEXT: [[TMP:%.+]] = arc.alloc_state - // CHECK-NEXT: [[CURRENT:%.+]] = arc.state_read [[STATE]] - // CHECK-NEXT: arc.state_write [[TMP]] = [[CURRENT]] - // CHECK-NEXT: arc.state_write [[STATE]] = %arg1 - - %2 = arc.state_read %0 : - %3 = arc.state_read %0 : - %4 = comb.xor %1, %2, %3 : i4 - return %4 : i4 - // CHECK-NEXT: arc.state_read [[TMP]] - // CHECK-NEXT: arc.state_read [[TMP]] - // CHECK-NEXT: comb.xor - // CHECK-NEXT: return -} -// CHECK-NEXT: } - -// CHECK-LABEL: func.func @FuncLegal -func.func @FuncLegal(%arg0: !arc.storage, %arg1: i4) -> i4 { - %0 = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - %1 = call @ReadFunc(%0) : (!arc.state) -> i4 - call @WriteFunc(%0, %arg1) : (!arc.state, i4) -> () - return %1 : i4 - // CHECK-NEXT: arc.alloc_state - // CHECK-NEXT: call @ReadFunc - // CHECK-NEXT: call @WriteFunc - // CHECK-NEXT: return -} -// CHECK-NEXT: } - -// CHECK-LABEL: func.func @FuncIllegal -func.func @FuncIllegal(%arg0: !arc.storage, %arg1: i4) -> i4 { - %0 = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - %1 = call @ReadFunc(%0) : (!arc.state) -> i4 - // CHECK-NEXT: [[STATE:%.+]] = arc.alloc_state - // CHECK-NEXT: call @ReadFunc - - call @WriteFunc(%0, %arg1) : (!arc.state, i4) -> () - // CHECK-NEXT: [[TMP:%.+]] = arc.alloc_state - // CHECK-NEXT: [[CURRENT:%.+]] = arc.state_read [[STATE]] - // CHECK-NEXT: arc.state_write [[TMP]] = [[CURRENT]] - // CHECK-NEXT: call @WriteFunc - - %2 = call @ReadFunc(%0) : (!arc.state) -> i4 - %3 = call @ReadFunc(%0) : (!arc.state) -> i4 - %4 = comb.xor %1, %2, %3 : i4 - return %4 : i4 - // CHECK-NEXT: call @ReadFunc([[TMP]]) - // CHECK-NEXT: call @ReadFunc([[TMP]]) - // CHECK-NEXT: comb.xor - // CHECK-NEXT: return -} -// CHECK-NEXT: } - -// CHECK-LABEL: func.func @NestedBlocks -func.func @NestedBlocks(%arg0: !arc.storage, %arg1: i4) -> i4 { - %0 = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - %11 = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - // CHECK-NEXT: [[S0:%.+]] = arc.alloc_state - // CHECK-NEXT: [[S1:%.+]] = arc.alloc_state - - // CHECK-NEXT: scf.execute_region - %10 = scf.execute_region -> i4 { - // CHECK-NEXT: [[TMP1:%.+]] = arc.alloc_state - // CHECK-NEXT: [[CURRENT:%.+]] = arc.state_read [[S1]] - // CHECK-NEXT: arc.state_write [[TMP1]] = [[CURRENT]] - // CHECK-NEXT: [[TMP0:%.+]] = arc.alloc_state - // CHECK-NEXT: [[CURRENT:%.+]] = arc.state_read [[S0]] - // CHECK-NEXT: arc.state_write [[TMP0]] = [[CURRENT]] - // CHECK-NEXT: scf.execute_region - %3 = scf.execute_region -> i4 { - // CHECK-NEXT: scf.execute_region - %1 = scf.execute_region -> i4 { - %2 = arc.state_read %0 : - scf.yield %2 : i4 - // CHECK-NEXT: arc.state_read [[TMP0]] - // CHECK-NEXT: scf.yield - } - // CHECK-NEXT: } - // CHECK-NEXT: scf.execute_region - scf.execute_region { - arc.state_write %0 = %arg1 : - arc.state_write %11 = %arg1 : - scf.yield - // CHECK-NEXT: arc.state_write [[S0]] - // CHECK-NEXT: arc.state_write [[S1]] - // CHECK-NEXT: scf.yield - } - // CHECK-NEXT: } - scf.yield %1 : i4 - // CHECK-NEXT: scf.yield - } - // CHECK-NEXT: } - func.call @WriteFunc(%0, %arg1) : (!arc.state, i4) -> () - // CHECK-NEXT: func.call @WriteFunc([[S0]], %arg1) - // CHECK-NEXT: scf.execute_region - %7, %8 = scf.execute_region -> (i4, i4) { - // CHECK-NEXT: scf.execute_region - %4 = scf.execute_region -> i4 { - %5 = func.call @ReadFunc(%0) : (!arc.state) -> i4 - scf.yield %5 : i4 - // CHECK-NEXT: func.call @ReadFunc([[TMP0]]) - // CHECK-NEXT: scf.yield - } - // CHECK-NEXT: } - %6 = arc.state_read %0 : - %12 = arc.state_read %11 : - scf.yield %4, %6 : i4, i4 - // CHECK-NEXT: arc.state_read [[TMP0]] - // CHECK-NEXT: arc.state_read [[TMP1]] - // CHECK-NEXT: scf.yield - } - // CHECK-NEXT: } - %9 = comb.xor %3, %7, %8 : i4 - scf.yield %9 : i4 - // CHECK-NEXT: comb.xor - // CHECK-NEXT: scf.yield - } - // CHECK-NEXT: } - return %10 : i4 - // CHECK-NEXT: return -} - -func.func @ReadFunc(%arg0: !arc.state) -> i4 { - %0 = func.call @InnerReadFunc(%arg0) : (!arc.state) -> i4 - return %0 : i4 -} - -func.func @WriteFunc(%arg0: !arc.state, %arg1: i4) { - func.call @InnerWriteFunc(%arg0, %arg1) : (!arc.state, i4) -> () - return -} - -func.func @InnerReadFunc(%arg0: !arc.state) -> i4 { - %0 = arc.state_read %arg0 : - return %0 : i4 -} - -func.func @InnerWriteFunc(%arg0: !arc.state, %arg1: i4) { - arc.state_write %arg0 = %arg1 : - return -} - -// State legalization should not happen across clock trees and passthrough ops. -// CHECK-LABEL: arc.model @DontLeakThroughClockTreeOrPassthrough -arc.model @DontLeakThroughClockTreeOrPassthrough io !hw.modty { -^bb0(%arg0: !arc.storage): - %false = hw.constant false - %in_a = arc.root_input "a", %arg0 : (!arc.storage) -> !arc.state - %out_b = arc.root_output "b", %arg0 : (!arc.storage) -> !arc.state - // CHECK: arc.alloc_state %arg0 {foo} - %0 = arc.alloc_state %arg0 {foo} : (!arc.storage) -> !arc.state - // CHECK-NOT: arc.alloc_state - // CHECK-NOT: arc.state_read - // CHECK-NOT: arc.state_write - // CHECK: arc.clock_tree - arc.clock_tree %false { - %1 = arc.state_read %in_a : - arc.state_write %0 = %1 : - } - // CHECK: arc.passthrough - arc.passthrough { - %1 = arc.state_read %0 : - arc.state_write %out_b = %1 : - } -} - -// CHECK-LABEL: arc.model @Memory -arc.model @Memory io !hw.modty<> { -^bb0(%arg0: !arc.storage): - %false = hw.constant false - // CHECK: [[MEM1:%.+]] = arc.alloc_memory %arg0 : - // CHECK: [[MEM2:%.+]] = arc.alloc_memory %arg0 : - %mem1 = arc.alloc_memory %arg0 : (!arc.storage) -> !arc.memory<2 x i32, i1> - %mem2 = arc.alloc_memory %arg0 : (!arc.storage) -> !arc.memory<2 x i32, i1> - %s1 = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - // CHECK: arc.clock_tree %false attributes {ct1} - arc.clock_tree %false attributes {ct1} { - // CHECK-NEXT: arc.state_read - // CHECK-NEXT: arc.memory_read [[MEM1]][%false] - // CHECK-NEXT: arc.memory_write [[MEM1]] - // CHECK-NEXT: arc.memory_read [[MEM2]][%false] - // CHECK-NEXT: arc.memory_write [[MEM2]] - %r1 = arc.state_read %s1 : - arc.memory_write %mem2[%false], %r1 : <2 x i32, i1> - arc.memory_write %mem1[%false], %r1 : <2 x i32, i1> - %mr1 = arc.memory_read %mem1[%false] : <2 x i32, i1> - %mr2 = arc.memory_read %mem2[%false] : <2 x i32, i1> - // CHECK-NEXT: } - } - // CHECK: arc.clock_tree %false attributes {ct2} - arc.clock_tree %false attributes {ct2} { - // CHECK-NEXT: arc.state_read - // CHECK-NEXT: arc.memory_read - // CHECK-NEXT: scf.if %false { - // CHECK-NEXT: arc.memory_read - // CHECK-NEXT: } - // CHECK-NEXT: arc.memory_write - %r1 = arc.state_read %s1 : - arc.memory_write %mem1[%false], %r1 : <2 x i32, i1> - %mr1 = arc.memory_read %mem1[%false] : <2 x i32, i1> - scf.if %false { - %mr2 = arc.memory_read %mem1[%false] : <2 x i32, i1> - } - // CHECK-NEXT: } - } - // CHECK: arc.clock_tree %false attributes {ct3} - arc.clock_tree %false attributes {ct3} { - // CHECK-NEXT: arc.memory_read [[MEM1]] - // CHECK-NEXT: arc.memory_read [[MEM2]] - // CHECK-NEXT: scf.if %false { - // CHECK-NEXT: arc.state_read - // CHECK-NEXT: scf.if %false { - // CHECK-NEXT: arc.memory_write [[MEM2]] - // CHECK-NEXT: arc.memory_read [[MEM1]] - // CHECK-NEXT: } - // CHECK-NEXT: arc.memory_write [[MEM1]] - // CHECK-NEXT: } - scf.if %false { - %r1 = arc.state_read %s1 : - arc.memory_write %mem1[%false], %r1 : <2 x i32, i1> - scf.if %false { - arc.memory_write %mem2[%false], %r1 : <2 x i32, i1> - %mr3 = arc.memory_read %mem1[%false] : <2 x i32, i1> - } - } - %mr1 = arc.memory_read %mem1[%false] : <2 x i32, i1> - %mr2 = arc.memory_read %mem2[%false] : <2 x i32, i1> - // CHECK-NEXT: } - } -} diff --git a/test/Dialect/Arc/lower-state-errors.mlir b/test/Dialect/Arc/lower-state-errors.mlir deleted file mode 100644 index 50edab4afdfd..000000000000 --- a/test/Dialect/Arc/lower-state-errors.mlir +++ /dev/null @@ -1,39 +0,0 @@ -// RUN: circt-opt %s --arc-lower-state --split-input-file --verify-diagnostics - -arc.define @DummyArc(%arg0: i42) -> i42 { - arc.output %arg0 : i42 -} - -// expected-error @+1 {{Value cannot be used in initializer.}} -hw.module @argInit(in %clk: !seq.clock, in %input: i42) { - %0 = arc.state @DummyArc(%0) clock %clk initial (%input : i42) latency 1 : (i42) -> i42 -} - - -// ----- - - -arc.define @DummyArc(%arg0: i42) -> i42 { - arc.output %arg0 : i42 -} - -hw.module @argInit(in %clk: !seq.clock, in %input: i42) { - // expected-error @+1 {{Value cannot be used in initializer.}} - %0 = arc.state @DummyArc(%0) clock %clk latency 1 : (i42) -> i42 - %1 = arc.state @DummyArc(%1) clock %clk initial (%0 : i42) latency 1 : (i42) -> i42 -} - -// ----- - -// expected-error @+1 {{initial ops cannot be topologically sorted}} -hw.module @toposort_failure(in %clk: !seq.clock, in %rst: i1, in %i: i32) { - %init = seq.initial (%add) { - ^bb0(%arg0: i32): - seq.yield %arg0 : i32 - } : (!seq.immutable) -> !seq.immutable - - %add = seq.initial (%init) { - ^bb0(%arg0 : i32): - seq.yield %arg0 : i32 - } : (!seq.immutable) -> !seq.immutable -} diff --git a/test/Dialect/Arc/lower-state.mlir b/test/Dialect/Arc/lower-state.mlir index d4940137b58c..a29fa81566a3 100644 --- a/test/Dialect/Arc/lower-state.mlir +++ b/test/Dialect/Arc/lower-state.mlir @@ -1,296 +1,611 @@ // RUN: circt-opt %s --arc-lower-state | FileCheck %s +func.func private @VoidFunc() +func.func private @RandomI42() -> i42 +func.func private @ConsumeI42(i42) +func.func private @IdI42(i42) -> i42 + +arc.define @Not(%arg0: i1) -> i1 { + %true = hw.constant true + %0 = comb.xor %arg0, %true : i1 + arc.output %0 : i1 +} + +arc.define @IdI42Arc(%arg0: i42) -> i42 { + arc.output %arg0 : i42 +} + +arc.define @IdI2AndI42Arc(%arg0: i2, %arg1: i42) -> (i2, i42) { + arc.output %arg0, %arg1 : i2, i42 +} + +arc.define @IdI2AndI42AndI1Arc(%arg0: i2, %arg1: i42, %arg2: i1) -> (i2, i42, i1) { + arc.output %arg0, %arg1, %arg2 : i2, i42, i1 +} + +arc.define @IdI2AndI42AndI1AndI42Arc(%arg0: i2, %arg1: i42, %arg2: i1, %arg3: i42) -> (i2, i42, i1, i42) { + arc.output %arg0, %arg1, %arg2, %arg3 : i2, i42, i1, i42 +} + +arc.define @RandomI42Arc() -> i42 { + %0 = hw.constant 42 : i42 + arc.output %0 : i42 +} + +arc.define @RandomI42AndI19Arc() -> (i42, i19) { + %0 = hw.constant 42 : i42 + %1 = hw.constant 1337 : i19 + arc.output %0, %1 : i42, i19 +} + // CHECK-LABEL: arc.model @Empty // CHECK-NEXT: ^bb0(%arg0: !arc.storage): // CHECK-NEXT: } -hw.module @Empty() { -} - -// CHECK-LABEL: arc.model @InputsAndOutputs -hw.module @InputsAndOutputs(in %a: i42, in %b: i17, out c: i42, out d: i17) { - %0 = comb.add %a, %a : i42 - %1 = comb.add %b, %b : i17 - hw.output %0, %1 : i42, i17 - // CHECK-NEXT: (%arg0: !arc.storage): - // CHECK-NEXT: [[INA:%.+]] = arc.root_input "a", %arg0 : (!arc.storage) -> !arc.state - // CHECK-NEXT: [[INB:%.+]] = arc.root_input "b", %arg0 : (!arc.storage) -> !arc.state - // CHECK-NEXT: [[OUTA:%.+]] = arc.root_output "c", %arg0 : (!arc.storage) -> !arc.state - // CHECK-NEXT: [[OUTB:%.+]] = arc.root_output "d", %arg0 : (!arc.storage) -> !arc.state - - // CHECK-NEXT: arc.passthrough { - // CHECK-NEXT: [[A:%.+]] = arc.state_read [[INA]] : - // CHECK-NEXT: [[TMP:%.+]] = comb.add [[A]], [[A]] : i42 - // CHECK-NEXT: arc.state_write [[OUTA]] = [[TMP]] : - // CHECK-NEXT: [[B:%.+]] = arc.state_read [[INB]] : - // CHECK-NEXT: [[TMP:%.+]] = comb.add [[B]], [[B]] : i17 - // CHECK-NEXT: arc.state_write [[OUTB]] = [[TMP]] : - // CHECK-NEXT: } +hw.module @Empty() {} + +// CHECK-LABEL: arc.model @InputToOutput +hw.module @InputToOutput(in %a: i42, out b: i42) { + // CHECK: [[TMP1:%.+]] = arc.state_read %in_a + // CHECK-NEXT: [[TMP2:%.+]] = comb.xor [[TMP1]] + // CHECK-NEXT: [[TMP3:%.+]] = func.call @IdI42([[TMP2]]) + %2 = comb.xor %1 : i42 + %1 = func.call @IdI42(%0) : (i42) -> i42 + %0 = comb.xor %a : i42 + // CHECK-NEXT: func.call @ConsumeI42([[TMP2]]) + func.call @ConsumeI42(%0) : (i42) -> () + // CHECK-NEXT: [[TMP4:%.+]] = comb.xor [[TMP3]] + // CHECK-NEXT: arc.state_write %out_b = [[TMP4]] + hw.output %2 : i42 } -// CHECK-LABEL: arc.model @State -hw.module @State(in %clk: !seq.clock, in %en: i1, in %en2: i1) { - %gclk = seq.clock_gate %clk, %en, %en2 - %3 = arc.state @DummyArc(%6) clock %clk latency 1 : (i42) -> i42 - %4 = arc.state @DummyArc(%5) clock %gclk latency 1 : (i42) -> i42 - %5 = comb.add %3, %3 : i42 - %6 = comb.add %4, %4 : i42 - // CHECK-NEXT: (%arg0: !arc.storage): - // CHECK-NEXT: [[INCLK:%.+]] = arc.root_input "clk", %arg0 : (!arc.storage) -> !arc.state - // CHECK-NEXT: [[INEN:%.+]] = arc.root_input "en", %arg0 : (!arc.storage) -> !arc.state - // CHECK-NEXT: [[INEN2:%.+]] = arc.root_input "en2", %arg0 : (!arc.storage) -> !arc.state - // CHECK-NEXT: [[CLK_OLD:%.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - // CHECK-NEXT: [[S0:%.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - // CHECK-NEXT: [[S1:%.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - - // CHECK-NEXT: [[TMP2:%.+]] = arc.state_read [[INCLK]] : - // CHECK-NEXT: arc.state_write [[CLK_OLD]] = [[TMP2]] : - // CHECK-NEXT: [[TMP1:%.+]] = arc.state_read [[CLK_OLD]] : - // CHECK-NEXT: [[TMP3:%.+]] = comb.icmp ne [[TMP1]], [[TMP2]] : i1 - // CHECK-NEXT: [[TMP4:%.+]] = comb.and [[TMP3]], [[TMP2]] : i1 - - // CHECK-NEXT: arc.clock_tree [[TMP4]] { - // CHECK-NEXT: [[TMP0:%.+]] = arc.state_read [[S1]] : - // CHECK-NEXT: [[TMP1:%.+]] = comb.add [[TMP0]], [[TMP0]] - // CHECK-NEXT: [[TMP2:%.+]] = arc.call @DummyArc([[TMP1]]) : (i42) -> i42 - // CHECK-NEXT: arc.state_write [[S0]] = [[TMP2]] : - // CHECK-NEXT: [[EN:%.+]] = arc.state_read [[INEN]] : - // CHECK-NEXT: [[EN2:%.+]] = arc.state_read [[INEN2]] : - // CHECK-NEXT: [[TMP3:%.+]] = comb.or [[EN]], [[EN2]] : i1 - // CHECK-NEXT: [[TMP0:%.+]] = arc.state_read [[S0]] : - // CHECK-NEXT: [[TMP1:%.+]] = comb.add [[TMP0]], [[TMP0]] - // CHECK-NEXT: [[TMP2:%.+]] = arc.call @DummyArc([[TMP1]]) : (i42) -> i42 - // CHECK-NEXT: arc.state_write [[S1]] = [[TMP2]] if [[TMP3]] : +// CHECK-LABEL: arc.model @ReadsBeforeUpdate +hw.module @ReadsBeforeUpdate(in %clock: !seq.clock, in %a: i42, out b: i42) { + // CHECK: [[Q1:%.+]] = arc.alloc_state %arg0 {name = "q1"} + // CHECK: [[Q0:%.+]] = arc.alloc_state %arg0 {name = "q0"} + // CHECK: scf.if {{%.+}} { + // CHECK-NEXT: [[TMP1:%.+]] = arc.state_read [[Q0]] + // CHECK-NEXT: [[TMP2:%.+]] = arc.call @IdI42Arc([[TMP1]]) + // CHECK-NEXT: arc.state_write [[Q1]] = [[TMP2]] + // CHECK-NEXT: [[TMP1:%.+]] = arc.state_read %in_a + // CHECK-NEXT: [[TMP2:%.+]] = arc.call @IdI42Arc([[TMP1]]) + // CHECK-NEXT: arc.state_write [[Q0]] = [[TMP2]] // CHECK-NEXT: } + %q1 = arc.state @IdI42Arc(%q0) clock %clock latency 1 {names = ["q1"]} : (i42) -> i42 + %q0 = arc.state @IdI42Arc(%a) clock %clock latency 1 {names = ["q0"]} : (i42) -> i42 + // CHECK-NEXT: [[TMP:%.+]] = arc.state_read [[Q1]] + // CHECK-NEXT: arc.state_write %out_b = [[TMP]] + hw.output %q1 : i42 } -// CHECK-LABEL: arc.model @State2 -hw.module @State2(in %clk: !seq.clock) { - %3 = arc.state @DummyArc(%3) clock %clk latency 1 : (i42) -> i42 - %4 = arc.state @DummyArc(%4) clock %clk latency 1 : (i42) -> i42 - // CHECK-NEXT: (%arg0: !arc.storage): - // CHECK-NEXT: [[INCLK:%.+]] = arc.root_input "clk", %arg0 : (!arc.storage) -> !arc.state - // CHECK-NEXT: [[CLK_OLD:%.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - // CHECK-NEXT: [[S0]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - // CHECK-NEXT: [[S1]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - - // CHECK-NEXT: [[TMP2:%.+]] = arc.state_read [[INCLK]] : - // CHECK-NEXT: arc.state_write [[CLK_OLD]] = [[TMP2]] : - // CHECK-NEXT: [[TMP1:%.+]] = arc.state_read [[CLK_OLD]] : - // CHECK-NEXT: [[TMP3:%.+]] = comb.icmp ne [[TMP1]], [[TMP2]] : i1 - // CHECK-NEXT: [[TMP4:%.+]] = comb.and [[TMP3]], [[TMP2]] : i1 - - // CHECK-NEXT: arc.clock_tree [[TMP4]] { - // CHECK-NEXT: [[TMP0:%.+]] = arc.state_read [[S0:%.+]] : - // CHECK-NEXT: [[TMP1:%.+]] = arc.call @DummyArc([[TMP0]]) : (i42) -> i42 - // CHECK-NEXT: arc.state_write [[S0]] = [[TMP1]] : - // CHECK-NEXT: [[TMP2:%.+]] = arc.state_read [[S1:%.+]] : - // CHECK-NEXT: [[TMP3:%.+]] = arc.call @DummyArc([[TMP2]]) : (i42) -> i42 - // CHECK-NEXT: arc.state_write [[S1]] = [[TMP3]] : +// CHECK-LABEL: arc.model @ReadsAfterUpdate +hw.module @ReadsAfterUpdate(in %clock: !seq.clock, in %a: i42, out b: i42) { + // CHECK: [[Q0:%.+]] = arc.alloc_state %arg0 {name = "q0"} + // CHECK: [[Q1:%.+]] = arc.alloc_state %arg0 {name = "q1"} + // CHECK: [[Q0_OLD:%.+]] = arc.state_read [[Q0]] + // CHECK: scf.if {{%.+}} { + // CHECK-NEXT: [[TMP1:%.+]] = arc.state_read %in_a + // CHECK-NEXT: [[TMP2:%.+]] = arc.call @IdI42Arc([[TMP1]]) + // CHECK-NEXT: arc.state_write [[Q0]] = [[TMP2]] + // CHECK-NEXT: [[TMP:%.+]] = arc.call @IdI42Arc([[Q0_OLD]]) + // CHECK-NEXT: arc.state_write [[Q1]] = [[TMP]] // CHECK-NEXT: } + %q0 = arc.state @IdI42Arc(%a) clock %clock latency 1 {names = ["q0"]} : (i42) -> i42 + %q1 = arc.state @IdI42Arc(%q0) clock %clock latency 1 {names = ["q1"]} : (i42) -> i42 + // CHECK-NEXT: [[TMP:%.+]] = arc.state_read [[Q1]] + // CHECK-NEXT: arc.state_write %out_b = [[TMP]] + hw.output %q1 : i42 } -arc.define @DummyArc(%arg0: i42) -> i42 { - arc.output %arg0 : i42 +// CHECK-LABEL: arc.model @ClockDivBy4 +hw.module @ClockDivBy4(in %clock: !seq.clock, out z: i1) { + // CHECK: [[Q0:%.+]] = arc.alloc_state %arg0 {name = "q0"} + // CHECK: [[Q1:%.+]] = arc.alloc_state %arg0 {name = "q1"} + // CHECK: [[TMP1:%.+]] = arc.state_read %in_clock + // CHECK: [[TMP2:%.+]] = seq.from_clock [[TMP1]] + // CHECK: [[CLOCK_EDGE:%.+]] = comb.and {{%.+}}, [[TMP2]] + // CHECK: scf.if [[CLOCK_EDGE]] { + // CHECK: arc.state_write [[Q0]] + // CHECK: } + %q0 = arc.state @Not(%q0) clock %clock latency 1 {names = ["q0"]} : (i1) -> i1 + // CHECK: [[TMP1:%.+]] = arc.state_read [[Q0]] + // CHECK: [[Q0_EDGE:%.+]] = comb.and {{%.+}}, [[TMP1]] + // CHECK: scf.if [[Q0_EDGE]] { + // CHECK: arc.state_write [[Q1]] + // CHECK: } + %0 = seq.to_clock %q0 + %q1 = arc.state @Not(%q1) clock %0 latency 1 {names = ["q1"]} : (i1) -> i1 + // CHECK: [[Q1_NEW:%.+]] = arc.state_read [[Q1]] + // CHECK: arc.state_write %out_z = [[Q1_NEW]] + hw.output %q1 : i1 +} + +// CHECK-LABEL: arc.model @EnablePort +hw.module @EnablePort(in %clock: !seq.clock, in %a: i42, in %en: i1) { + // CHECK: scf.if {{%.+}} { + // CHECK: [[EN:%.+]] = arc.state_read %in_en + // CHECK: scf.if [[EN]] { + // CHECK: arc.state_write + // CHECK: arc.state_write + // CHECK: } + // CHECK: } + %q0 = arc.state @IdI42Arc(%a) clock %clock enable %en latency 1 : (i42) -> i42 + %q1 = arc.state @IdI42Arc(%q0) clock %clock enable %en latency 1 : (i42) -> i42 +} + +// CHECK-LABEL: arc.model @EnableLocal +hw.module @EnableLocal(in %clock: !seq.clock, in %a: i42) { + // CHECK: [[TMP1:%.+]] = hw.constant 9001 + // CHECK: [[TMP2:%.+]] = arc.state_read %in_a + // CHECK: [[TMP3:%.+]] = comb.icmp ne [[TMP2]], [[TMP1]] + // CHECK: scf.if {{%.+}} { + // CHECK: scf.if [[TMP3]] { + // CHECK: arc.state_write + // CHECK: arc.state_write + // CHECK: } + // CHECK: } + %0 = hw.constant 9001 : i42 + %1 = comb.icmp ne %a, %0 : i42 + %q0 = arc.state @IdI42Arc(%a) clock %clock enable %1 latency 1 : (i42) -> i42 + %q1 = arc.state @IdI42Arc(%q0) clock %clock enable %1 latency 1 : (i42) -> i42 +} + +// CHECK-LABEL: arc.model @Reset +hw.module @Reset(in %clock: !seq.clock, in %a: i42, in %reset: i1) { + // CHECK: scf.if {{%.+}} { + // CHECK: [[RESET:%.+]] = arc.state_read %in_reset + // CHECK: scf.if [[RESET]] { + // CHECK: arc.state_write + // CHECK: arc.state_write + // CHECK: } + // CHECK: } + %q0 = arc.state @IdI42Arc(%a) clock %clock reset %reset latency 1 : (i42) -> i42 + %q1 = arc.state @IdI42Arc(%q0) clock %clock reset %reset latency 1 : (i42) -> i42 +} + +// CHECK-LABEL: arc.model @ResetLocal +hw.module @ResetLocal(in %clock: !seq.clock, in %a: i42) { + // CHECK: [[TMP1:%.+]] = hw.constant 9001 + // CHECK: [[TMP2:%.+]] = arc.state_read %in_a + // CHECK: [[TMP3:%.+]] = comb.icmp ne [[TMP2]], [[TMP1]] + // CHECK: scf.if {{%.+}} { + // CHECK: scf.if [[TMP3]] { + // CHECK: arc.state_write + // CHECK: arc.state_write + // CHECK: } + // CHECK: } + %0 = hw.constant 9001 : i42 + %1 = comb.icmp ne %a, %0 : i42 + %q0 = arc.state @IdI42Arc(%a) clock %clock reset %1 latency 1 : (i42) -> i42 + %q1 = arc.state @IdI42Arc(%q0) clock %clock reset %1 latency 1 : (i42) -> i42 +} + +// CHECK-LABEL: arc.model @ResetAndEnable +hw.module @ResetAndEnable(in %clock: !seq.clock, in %a: i42, in %reset: i1, in %en: i1) { + // CHECK: [[Q0:%.+]] = arc.alloc_state %arg0 {name = "q0"} + // CHECK: [[Q1:%.+]] = arc.alloc_state %arg0 {name = "q1"} + // CHECK: [[Q2:%.+]] = arc.alloc_state %arg0 {name = "q2"} + // CHECK: [[Q3:%.+]] = arc.alloc_state %arg0 {name = "q3"} + // CHECK: scf.if {{%.+}} { + // CHECK: [[EN:%.+]] = arc.state_read %in_en + // CHECK: scf.if [[EN]] { + // CHECK: arc.state_write [[Q0]] + // CHECK: } + // CHECK: [[RESET:%.+]] = arc.state_read %in_reset + // CHECK: scf.if [[RESET]] { + // CHECK: [[TMP:%.+]] = hw.constant 0 + // CHECK: arc.state_write [[Q1]] = [[TMP]] + // CHECK: [[TMP:%.+]] = hw.constant 0 + // CHECK: arc.state_write [[Q2]] = [[TMP]] + // CHECK: [[TMP:%.+]] = hw.constant 0 + // CHECK: arc.state_write [[Q3]] = [[TMP]] + // CHECK: } else { + // CHECK: [[EN:%.+]] = arc.state_read %in_en + // CHECK: scf.if [[EN]] { + // CHECK: arc.state_write [[Q1]] + // CHECK: arc.state_write [[Q2]] + // CHECK: } + // CHECK: arc.state_write [[Q3]] + // CHECK: } + // CHECK: } + %q0 = arc.state @IdI42Arc(%a) clock %clock enable %en latency 1 {names = ["q0"]} : (i42) -> i42 + %q1 = arc.state @IdI42Arc(%q0) clock %clock enable %en reset %reset latency 1 {names = ["q1"]} : (i42) -> i42 + %q2 = arc.state @IdI42Arc(%q1) clock %clock enable %en reset %reset latency 1 {names = ["q2"]} : (i42) -> i42 + %q3 = arc.state @IdI42Arc(%q2) clock %clock reset %reset latency 1 {names = ["q3"]} : (i42) -> i42 } -// CHECK-LABEL: arc.model @NonMaskedMemoryWrite -hw.module @NonMaskedMemoryWrite(in %clk0: !seq.clock) { - %c0_i2 = hw.constant 0 : i2 - %c9001_i42 = hw.constant 9001 : i42 +// CHECK-LABEL: arc.model @BlackBox +hw.module @BlackBox(in %clock: !seq.clock, in %a: i42, out b: i42) { + // CHECK: [[Q0:%.+]] = arc.alloc_state %arg0 {name = "q0"} + // CHECK: [[S:%.+]] = arc.alloc_state %arg0 {name = "ext/s"} + // CHECK: [[P:%.+]] = arc.alloc_state %arg0 {name = "ext/p"} + // CHECK: [[Q:%.+]] = arc.alloc_state %arg0 {name = "ext/q"} + // CHECK: [[R:%.+]] = arc.alloc_state %arg0 {name = "ext/r"} + // CHECK: [[Q1:%.+]] = arc.alloc_state %arg0 {name = "q1"} + // CHECK: scf.if {{%.+}} { + // CHECK: arc.state_write [[Q0]] + // CHECK: } + %0 = arc.state @IdI42Arc(%a) clock %clock latency 1 {names = ["q0"]} : (i42) -> i42 + // CHECK: [[Q0_NEW:%.+]] = arc.state_read [[Q0]] + // CHECK: [[S_NEW:%.+]] = arc.state_read [[S]] + // CHECK: [[TMP:%.+]] = comb.and [[Q0_NEW]], [[S_NEW]] + // CHECK: [[Q0_NEW:%.+]] = arc.state_read [[Q0]] + // CHECK: arc.state_write [[P]] = [[Q0_NEW]] + // CHECK: arc.state_write [[Q]] = [[TMP]] + %1 = comb.and %0, %3 : i42 + %2, %3 = hw.instance "ext" @BlackBoxExt(p: %0: i42, q: %1: i42) -> (r: i42, s: i42) + // CHECK: scf.if {{%.+}} { + // CHECK: [[S_NEW:%.+]] = arc.state_read [[S]] + // CHECK: [[TMP:%.+]] = arc.call @IdI42Arc([[S_NEW]]) + // CHECK: arc.state_write [[Q1]] = [[TMP]] + // CHECK: } + %4 = arc.state @IdI42Arc(%3) clock %clock latency 1 {names = ["q1"]} : (i42) -> i42 + // CHECK: [[R_NEW:%.+]] = arc.state_read [[R]] + // CHECK: [[Q1_NEW:%.+]] = arc.state_read [[Q1]] + // CHECK: [[TMP:%.+]] = comb.or [[R_NEW]], [[Q1_NEW]] + // CHECK: arc.state_write %out_b = [[TMP]] + %5 = comb.or %2, %4 : i42 + hw.output %5 : i42 +} + +hw.module.extern private @BlackBoxExt(in %p: i42, in %q: i42, out r: i42, out s: i42) + +// CHECK-LABEL: arc.model @MemoryInputToOutput +hw.module @MemoryInputToOutput(in %clock: !seq.clock, in %a: i2, in %b: i42, out c: i42) { + // CHECK: [[MEM:%.+]] = arc.alloc_memory + // CHECK: scf.if {{%.+}} { + // CHECK: [[A:%.+]] = arc.state_read %in_a + // CHECK: [[B:%.+]] = arc.state_read %in_b + // CHECK: [[TMP:%.+]]:2 = arc.call @IdI2AndI42Arc([[A]], [[B]]) + // CHECK: arc.memory_write [[MEM]][[[TMP]]#0], [[TMP]]#1 + // CHECK: } %mem = arc.memory <4 x i42, i2> - arc.memory_write_port %mem, @identity(%c0_i2, %c9001_i42) clock %clk0 latency 1 : <4 x i42, i2>, i2, i42 - - // CHECK-NEXT: (%arg0: !arc.storage): - // CHECK-NEXT: [[INCLK:%.+]] = arc.root_input "clk0", %arg0 : (!arc.storage) -> !arc.state - // CHECK-NEXT: [[MEM:%.+]] = arc.alloc_memory %arg0 : (!arc.storage) -> !arc.memory<4 x i42, i2> - // CHECK-NEXT: [[CLK_OLD:%.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - - // CHECK-NEXT: [[TMP2:%.+]] = arc.state_read [[INCLK]] : - // CHECK-NEXT: arc.state_write [[CLK_OLD]] = [[TMP2]] : - // CHECK-NEXT: [[TMP1:%.+]] = arc.state_read [[CLK_OLD]] : - // CHECK-NEXT: [[TMP3:%.+]] = comb.icmp ne [[TMP1]], [[TMP2]] : i1 - // CHECK-NEXT: [[TMP4:%.+]] = comb.and [[TMP3]], [[TMP2]] : i1 - - // CHECK-NEXT: arc.clock_tree [[TMP4]] { - // CHECK: [[RES:%.+]]:2 = arc.call @identity(%c0_i2, %c9001_i42) : (i2, i42) -> (i2, i42) - // CHECK: arc.memory_write [[MEM]][[[RES]]#0], [[RES]]#1 : <4 x i42, i2> - // CHECK-NEXT: } + %0 = arc.memory_read_port %mem[%a] : <4 x i42, i2> + arc.memory_write_port %mem, @IdI2AndI42Arc(%a, %b) clock %clock latency 1 : <4 x i42, i2>, i2, i42 + // CHECK: [[A:%.+]] = arc.state_read %in_a + // CHECK: [[TMP:%.+]] = arc.memory_read [[MEM]][[[A]]] + // CHECK: arc.state_write %out_c = [[TMP]] + hw.output %0 : i42 } -arc.define @identity(%arg0: i2, %arg1: i42) -> (i2, i42) { - arc.output %arg0, %arg1 : i2, i42 + +// CHECK-LABEL: arc.model @MemoryReadBeforeUpdate +hw.module @MemoryReadBeforeUpdate(in %clock: !seq.clock, in %a: i2, in %b: i42, in %en: i1, out c: i42) { + // CHECK: [[Q0:%.+]] = arc.alloc_state %arg0 {name = "q0"} + // CHECK: [[MEM:%.+]] = arc.alloc_memory + // Port ops are ignored. Writes are lowered with `arc.memory`, reads when they + // are used by an op. + %0 = arc.memory_read_port %mem[%a] : <4 x i42, i2> + arc.memory_write_port %mem, @IdI2AndI42Arc(%a, %b) clock %clock latency 1 : <4 x i42, i2>, i2, i42 + // CHECK: scf.if {{%.+}} { + // CHECK: [[A:%.+]] = arc.state_read %in_a + // CHECK: [[TMP1:%.+]] = arc.memory_read [[MEM]][[[A]]] + // CHECK: [[TMP2:%.+]] = arc.call @IdI42Arc([[TMP1]]) + // CHECK: arc.state_write [[Q0]] = [[TMP2]] + %q0 = arc.state @IdI42Arc(%0) clock %clock latency 1 {names = ["q0"]} : (i42) -> i42 + // CHECK: [[A:%.+]] = arc.state_read %in_a + // CHECK: [[B:%.+]] = arc.state_read %in_b + // CHECK: [[TMP:%.+]]:2 = arc.call @IdI2AndI42Arc([[A]], [[B]]) + // CHECK: arc.memory_write [[MEM]][[[TMP]]#0], [[TMP]]#1 + %mem = arc.memory <4 x i42, i2> + // CHECK: } + // CHECK: [[Q0_NEW:%.+]] = arc.state_read [[Q0]] + // CHECK: arc.state_write %out_c = [[Q0_NEW]] + hw.output %q0 : i42 } -// CHECK-LABEL: arc.model @lowerMemoryReadPorts -hw.module @lowerMemoryReadPorts(out out0: i42, out out1: i42) { - %c0_i2 = hw.constant 0 : i2 +// CHECK-LABEL: arc.model @MemoryReadAfterUpdate +hw.module @MemoryReadAfterUpdate(in %clock: !seq.clock, in %a: i2, in %b: i42, in %en: i1, out c: i42) { + // CHECK: [[MEM:%.+]] = arc.alloc_memory + // CHECK: [[Q0:%.+]] = arc.alloc_state %arg0 {name = "q0"} + // Port ops are ignored. Writes are lowered with `arc.memory`, reads when they + // are used by an op. + %0 = arc.memory_read_port %mem[%a] : <4 x i42, i2> + arc.memory_write_port %mem, @IdI2AndI42Arc(%a, %b) clock %clock latency 1 : <4 x i42, i2>, i2, i42 + // CHECK: [[A:%.+]] = arc.state_read %in_a + // CHECK: [[READ_OLD:%.+]] = arc.memory_read [[MEM]][[[A]]] + // CHECK: scf.if {{%.+}} { + // CHECK: [[A:%.+]] = arc.state_read %in_a + // CHECK: [[B:%.+]] = arc.state_read %in_b + // CHECK: [[TMP:%.+]]:2 = arc.call @IdI2AndI42Arc([[A]], [[B]]) + // CHECK: arc.memory_write [[MEM]][[[TMP]]#0], [[TMP]]#1 %mem = arc.memory <4 x i42, i2> - // CHECK: arc.memory_read {{%.+}}[%c0_i2] : <4 x i42, i2> - %0 = arc.memory_read_port %mem[%c0_i2] : <4 x i42, i2> - // CHECK: func.call @arcWithMemoryReadsIsLowered - %1 = arc.call @arcWithMemoryReadsIsLowered(%mem) : (!arc.memory<4 x i42, i2>) -> i42 - hw.output %0, %1 : i42, i42 -} - -// CHECK-LABEL: func.func @arcWithMemoryReadsIsLowered(%arg0: !arc.memory<4 x i42, i2>) -> i42 attributes {llvm.linkage = #llvm.linkage} -arc.define @arcWithMemoryReadsIsLowered(%mem: !arc.memory<4 x i42, i2>) -> i42 { - %c0_i2 = hw.constant 0 : i2 - // CHECK: arc.memory_read {{%.+}}[%c0_i2] : <4 x i42, i2> - %0 = arc.memory_read_port %mem[%c0_i2] : <4 x i42, i2> - // CHECK-NEXT: return - arc.output %0 : i42 + // CHECK: [[TMP:%.+]] = arc.call @IdI42Arc([[READ_OLD]]) + // CHECK: arc.state_write [[Q0]] = [[TMP]] + %q0 = arc.state @IdI42Arc(%0) clock %clock latency 1 {names = ["q0"]} : (i42) -> i42 + // CHECK: } + // CHECK: [[Q0_NEW:%.+]] = arc.state_read [[Q0]] + // CHECK: arc.state_write %out_c = [[Q0_NEW]] + hw.output %q0 : i42 } -// CHECK-LABEL: arc.model @maskedMemoryWrite -hw.module @maskedMemoryWrite(in %clk: !seq.clock) { - %true = hw.constant true - %c0_i2 = hw.constant 0 : i2 - %c9001_i42 = hw.constant 9001 : i42 - %c1010_i42 = hw.constant 1010 : i42 +// CHECK-LABEL: arc.model @MemoryEnable +hw.module @MemoryEnable(in %clock: !seq.clock, in %a: i2, in %b: i42, in %en: i1) { + // CHECK: [[MEM:%.+]] = arc.alloc_memory + // CHECK: scf.if {{%.+}} { + // CHECK: [[A:%.+]] = arc.state_read %in_a + // CHECK: [[B:%.+]] = arc.state_read %in_b + // CHECK: [[EN:%.+]] = arc.state_read %in_en + // CHECK: [[TMP:%.+]]:3 = arc.call @IdI2AndI42AndI1Arc([[A]], [[B]], [[EN]]) + // CHECK: scf.if [[TMP]]#2 { + // CHECK: arc.memory_write [[MEM]][[[TMP]]#0], [[TMP]]#1 + // CHECK: } + // CHECK: } %mem = arc.memory <4 x i42, i2> - arc.memory_write_port %mem, @identity2(%c0_i2, %c9001_i42, %true, %c1010_i42) clock %clk enable mask latency 1 : <4 x i42, i2>, i2, i42, i1, i42 + arc.memory_write_port %mem, @IdI2AndI42AndI1Arc(%a, %b, %en) clock %clock enable latency 1 : <4 x i42, i2>, i2, i42, i1 } -arc.define @identity2(%arg0: i2, %arg1: i42, %arg2: i1, %arg3: i42) -> (i2, i42, i1, i42) { - arc.output %arg0, %arg1, %arg2, %arg3 : i2, i42, i1, i42 + +// CHECK-LABEL: arc.model @MemoryEnableAndMask +hw.module @MemoryEnableAndMask(in %clock: !seq.clock, in %a: i2, in %b: i42, in %en: i1, in %mask: i42) { + // CHECK: [[MEM:%.+]] = arc.alloc_memory + // CHECK: scf.if {{%.+}} { + // CHECK: [[A:%.+]] = arc.state_read %in_a + // CHECK: [[B:%.+]] = arc.state_read %in_b + // CHECK: [[EN:%.+]] = arc.state_read %in_en + // CHECK: [[MASK:%.+]] = arc.state_read %in_mask + // CHECK: [[TMP:%.+]]:4 = arc.call @IdI2AndI42AndI1AndI42Arc([[A]], [[B]], [[EN]], [[MASK]]) + // CHECK: scf.if [[TMP]]#2 { + // CHECK: [[ALL_ONES:%.+]] = hw.constant -1 + // CHECK: [[MASK_INV:%.+]] = comb.xor bin [[TMP]]#3, [[ALL_ONES]] + // CHECK: [[DATA_OLD:%.+]] = arc.memory_read [[MEM]][[[TMP]]#0] + // CHECK: [[MASKED_OLD:%.+]] = comb.and bin [[MASK_INV]], [[DATA_OLD]] + // CHECK: [[MASKED_NEW:%.+]] = comb.and bin [[TMP]]#3, [[TMP]]#1 + // CHECK: [[DATA_NEW:%.+]] = comb.or bin [[MASKED_OLD]], [[MASKED_NEW]] + // CHECK: arc.memory_write [[MEM]][[[TMP]]#0], [[DATA_NEW]] + // CHECK: } + // CHECK: } + %mem = arc.memory <4 x i42, i2> + arc.memory_write_port %mem, @IdI2AndI42AndI1AndI42Arc(%a, %b, %en, %mask) clock %clock enable mask latency 1 : <4 x i42, i2>, i2, i42, i1, i42 +} + +// CHECK-LABEL: arc.model @SimpleInitial +hw.module @SimpleInitial() { + // CHECK: arc.initial { + // CHECK: func.call @VoidFunc() {initA} + // CHECK: func.call @VoidFunc() {initB} + // CHECK: } + // CHECK: func.call @VoidFunc() {body} + seq.initial() { + func.call @VoidFunc() {initA} : () -> () + } : () -> () + func.call @VoidFunc() {body} : () -> () + seq.initial() { + func.call @VoidFunc() {initB} : () -> () + } : () -> () +} + +// CHECK-LABEL: arc.model @InitialWithDependencies +hw.module @InitialWithDependencies() { + // CHECK: arc.initial { + // CHECK-NEXT: func.call @VoidFunc() {before} + // CHECK-NEXT: [[A:%.+]] = func.call @RandomI42() {initA} + // CHECK-NEXT: [[B:%.+]] = func.call @RandomI42() {initB} + // CHECK-NEXT: [[TMP:%.+]] = comb.add [[A]], [[B]] + // CHECK-NEXT: func.call @ConsumeI42([[TMP]]) + // CHECK-NEXT: func.call @VoidFunc() {after} + // CHECK-NEXT: } + + seq.initial() { + func.call @VoidFunc() {before} : () -> () + } : () -> () + + // This pulls up %initA, %initB, %initC since it depends on their SSA values. + seq.initial(%initC) { + ^bb0(%arg0: i42): + func.call @ConsumeI42(%arg0) : (i42) -> () + } : (!seq.immutable) -> () + + seq.initial() { + func.call @VoidFunc() {after} : () -> () + } : () -> () + + // The following is pulled up. + %initA = seq.initial() { + %1 = func.call @RandomI42() {initA} : () -> i42 + seq.yield %1 : i42 + } : () -> (!seq.immutable) + + %initB = seq.initial() { + %2 = func.call @RandomI42() {initB} : () -> i42 + seq.yield %2 : i42 + } : () -> (!seq.immutable) + + %initC = seq.initial(%initA, %initB) { + ^bb0(%arg0: i42, %arg1: i42): + %3 = comb.add %arg0, %arg1 : i42 + seq.yield %3 : i42 + } : (!seq.immutable, !seq.immutable) -> (!seq.immutable) +} + +// CHECK-LABEL: arc.model @FromImmutableCast +hw.module @FromImmutableCast() { + // CHECK: [[STORAGE:%.+]] = arc.alloc_state + // CHECK: arc.initial { + // CHECK: [[TMP:%.+]] = func.call @RandomI42() + // CHECK: arc.state_write [[STORAGE]] = [[TMP]] + // CHECK: } + // CHECK: [[TMP:%.+]] = arc.state_read [[STORAGE]] + // CHECK: func.call @ConsumeI42([[TMP]]) + func.call @ConsumeI42(%1) : (i42) -> () + %1 = seq.from_immutable %0 : (!seq.immutable) -> i42 + %0 = seq.initial() { + %2 = func.call @RandomI42() : () -> (i42) + seq.yield %2 : i42 + } : () -> (!seq.immutable) } -// CHECK: %c9001_i42 = hw.constant 9001 : i42 -// CHECK: %c1010_i42 = hw.constant 1010 : i42 -// CHECK: [[RES:%.+]]:4 = arc.call @identity2(%c0_i2, %c9001_i42, %true, %c1010_i42) : (i2, i42, i1, i42) -> (i2, i42, i1, i42) -// CHECK: [[RD:%.+]] = arc.memory_read [[MEM:%.+]][[[RES]]#0] : <4 x i42, i2> -// CHECK: %c-1_i42 = hw.constant -1 : i42 -// CHECK: [[NEG_MASK:%.+]] = comb.xor bin [[RES]]#3, %c-1_i42 : i42 -// CHECK: [[OLD_MASKED:%.+]] = comb.and bin [[NEG_MASK]], [[RD]] : i42 -// CHECK: [[NEW_MASKED:%.+]] = comb.and bin [[RES]]#3, [[RES]]#1 : i42 -// CHECK: [[DATA:%.+]] = comb.or bin [[OLD_MASKED]], [[NEW_MASKED]] : i42 -// CHECK: arc.memory_write [[MEM]][[[RES]]#0], [[DATA]] if [[RES]]#2 : <4 x i42, i2> // CHECK-LABEL: arc.model @Taps hw.module @Taps() { - // CHECK-NOT: arc.tap - // CHECK-DAG: [[VALUE:%.+]] = hw.constant 0 : i42 - // CHECK-DAG: [[STATE:%.+]] = arc.alloc_state %arg0 tap {name = "myTap"} - // CHECK-DAG: arc.state_write [[STATE]] = [[VALUE]] - %c0_i42 = hw.constant 0 : i42 - arc.tap %c0_i42 {name = "myTap"} : i42 + // CHECK: [[STORAGE:%.+]] = arc.alloc_state + // CHECK: [[TMP:%.+]] = func.call @RandomI42() + // CHECK: arc.state_write [[STORAGE]] = [[TMP]] + %0 = func.call @RandomI42() : () -> i42 + arc.tap %0 {name = "myTap"} : i42 } -// CHECK-LABEL: arc.model @MaterializeOpsWithRegions -hw.module @MaterializeOpsWithRegions(in %clk0: !seq.clock, in %clk1: !seq.clock, out z: i42) { - %true = hw.constant true - %c19_i42 = hw.constant 19 : i42 - %0 = scf.if %true -> (i42) { - scf.yield %c19_i42 : i42 - } else { - %c42_i42 = hw.constant 42 : i42 - scf.yield %c42_i42 : i42 - } +// CHECK-LABEL: arc.model @StateInitializerUsesOtherState +hw.module @StateInitializerUsesOtherState(in %clock: !seq.clock, in %a: i42, in %b: i19) { + // CHECK: [[Q2:%.+]] = arc.alloc_state %arg0 {name = "q2"} + // CHECK: [[Q3:%.+]] = arc.alloc_state %arg0 {name = "q3"} + // CHECK: [[Q0:%.+]] = arc.alloc_state %arg0 {name = "q0"} + // CHECK: [[Q1:%.+]] = arc.alloc_state %arg0 {name = "q1"} + // CHECK: arc.initial { + // CHECK-NEXT: [[A:%.+]] = arc.state_read %in_a + // CHECK-NEXT: arc.state_write [[Q0]] = [[A]] + // CHECK-NEXT: [[TMP1:%.+]] = arc.state_read [[Q0]] + // CHECK-NEXT: [[TMP2:%.+]] = comb.xor [[TMP1]] + // CHECK-NEXT: arc.state_write [[Q1]] = [[TMP2]] + // CHECK-NEXT: [[TMP:%.+]] = arc.state_read [[Q1]] + // CHECK-NEXT: arc.state_write [[Q2]] = [[TMP]] + // CHECK-NEXT: [[TMP:%.+]] = arc.state_read %in_b + // CHECK-NEXT: arc.state_write [[Q3]] = [[TMP]] + // CHECK-NEXT: } + arc.state @RandomI42AndI19Arc() clock %clock initial (%2, %b : i42, i19) latency 1 {names = ["q2", "q3"]} : () -> (i42, i19) + %2 = arc.state @RandomI42Arc() clock %clock initial (%1 : i42) latency 1 {names = ["q1"]} : () -> i42 + %1 = comb.xor %0 : i42 + %0 = arc.state @RandomI42Arc() clock %clock initial (%a : i42) latency 1 {names = ["q0"]} : () -> i42 +} - // CHECK: arc.passthrough { - // CHECK-NEXT: %true = hw.constant true - // CHECK-NEXT: %c19_i42 = hw.constant 19 - // CHECK-NEXT: [[TMP:%.+]] = scf.if %true -> (i42) { - // CHECK-NEXT: scf.yield %c19_i42 - // CHECK-NEXT: } else { - // CHECK-NEXT: %c42_i42 = hw.constant 42 - // CHECK-NEXT: scf.yield %c42_i42 - // CHECK-NEXT: } - // CHECK-NEXT: arc.state_write +// CHECK-LABEL: arc.model @StateInitializerUsesInitial +hw.module @StateInitializerUsesInitial(in %clock: !seq.clock) { + // CHECK: [[Q0:%.+]] = arc.alloc_state %arg0 {name = "q0"} + // CHECK: arc.initial { + // CHECK-NEXT: [[TMP1:%.+]] = hw.constant 9001 + // CHECK-NEXT: [[TMP2:%.+]] = comb.xor [[TMP1]] + // CHECK-NEXT: arc.state_write [[Q0]] = [[TMP2]] // CHECK-NEXT: } + %0 = seq.initial() { + %4 = hw.constant 9001 : i42 + seq.yield %4 : i42 + } : () -> !seq.immutable + %1 = seq.initial(%0) { + ^bb0(%5: i42): + %6 = comb.xor %5 : i42 + seq.yield %6 : i42 + } : (!seq.immutable) -> !seq.immutable + %2 = seq.from_immutable %1 : (!seq.immutable) -> i42 + %3 = arc.state @RandomI42Arc() clock %clock initial (%2 : i42) latency 1 {names = ["q0"]} : () -> i42 +} - // CHECK: [[CLK0:%.+]] = arc.state_read %in_clk0 - // CHECK: [[TMP:%.+]] = comb.and {{%.+}}, [[CLK0]] - // CHECK-NEXT: arc.clock_tree [[TMP]] { - // CHECK-NEXT: %true = hw.constant true - // CHECK-NEXT: %c19_i42 = hw.constant 19 - // CHECK-NEXT: [[TMP:%.+]] = scf.if %true -> (i42) { - // CHECK-NEXT: scf.yield %c19_i42 - // CHECK-NEXT: } else { - // CHECK-NEXT: %c42_i42 = hw.constant 42 - // CHECK-NEXT: scf.yield %c42_i42 - // CHECK-NEXT: } - // CHECK-NEXT: arc.call @DummyArc([[TMP]]) - // CHECK-NEXT: arc.state_write +// CHECK-LABEL: arc.model @SimpleFinal +hw.module @SimpleFinal() { + // CHECK: arc.final { + // CHECK-NEXT: func.call @VoidFunc() {finalA} + // CHECK-NEXT: func.call @VoidFunc() {finalB} // CHECK-NEXT: } + // CHECK-NEXT: func.call @VoidFunc() {body} + llhd.final { + func.call @VoidFunc() {finalA} : () -> () + llhd.halt + } + func.call @VoidFunc() {body} : () -> () + llhd.final { + func.call @VoidFunc() {finalB} : () -> () + llhd.halt + } +} + +// CHECK-LABEL: arc.model @FinalWithDependencies +hw.module @FinalWithDependencies() { + // CHECK: arc.final { + // CHECK-NEXT: [[TMP:%.+]] = hw.constant 9001 + // CHECK-NEXT: func.call @ConsumeI42([[TMP]]) {finalA} + // CHECK-NEXT: [[TMP:%.+]] = func.call @RandomI42() {sideEffect} + // CHECK-NEXT: func.call @ConsumeI42([[TMP]]) {finalB} + // CHECK-NEXT: } + // CHECK-NEXT: func.call @VoidFunc() {body} + // CHECK-NEXT: func.call @RandomI42() {sideEffect} + llhd.final { + func.call @ConsumeI42(%0) {finalA} : (i42) -> () + llhd.halt + } + func.call @VoidFunc() {body} : () -> () + llhd.final { + func.call @ConsumeI42(%1) {finalB} : (i42) -> () + llhd.halt + } + %0 = hw.constant 9001 : i42 + %1 = func.call @RandomI42() {sideEffect} : () -> i42 +} - // CHECK: [[CLK1:%.+]] = arc.state_read %in_clk1 - // CHECK: [[TMP:%.+]] = comb.and {{%.+}}, [[CLK1]] - // CHECK-NEXT: arc.clock_tree [[TMP]] { - // CHECK-NEXT: %true = hw.constant true - // CHECK-NEXT: %c19_i42 = hw.constant 19 - // CHECK-NEXT: [[TMP:%.+]] = scf.if %true -> (i42) { - // CHECK-NEXT: scf.yield %c19_i42 - // CHECK-NEXT: } else { - // CHECK-NEXT: %c42_i42 = hw.constant 42 - // CHECK-NEXT: scf.yield %c42_i42 +// CHECK-LABEL: arc.model @FinalWithControlFlow +hw.module @FinalWithControlFlow() { + // CHECK: arc.final { + // CHECK-NEXT: scf.execute_region { + // CHECK-NEXT: [[TMP:%.+]] = func.call @RandomI42() + // CHECK-NEXT: cf.br ^[[BB:.+]]([[TMP]] : i42) + // CHECK-NEXT: ^[[BB]]([[TMP:%.+]]: i42) + // CHECK-NEXT: func.call @ConsumeI42([[TMP]]) + // CHECK-NEXT: scf.yield // CHECK-NEXT: } - // CHECK-NEXT: arc.call @DummyArc([[TMP]]) - // CHECK-NEXT: arc.state_write // CHECK-NEXT: } + llhd.final { + %0 = func.call @RandomI42() : () -> i42 + cf.br ^bb0(%0 : i42) + ^bb0(%1: i42): + func.call @ConsumeI42(%1) : (i42) -> () + llhd.halt + } +} - %1 = arc.state @DummyArc(%0) clock %clk0 latency 1 : (i42) -> i42 - %2 = arc.state @DummyArc(%0) clock %clk1 latency 1 : (i42) -> i42 +// CHECK-LABEL: arc.model @UnclockedDpiCall +hw.module @UnclockedDpiCall(in %a: i42, out b: i42) { + // CHECK: [[TMP1:%.+]] = arc.state_read %in_a + // CHECK: [[TMP2:%.+]] = func.call @IdI42([[TMP1]]) + %0 = sim.func.dpi.call @IdI42(%a) : (i42) -> i42 + // CHECK: arc.state_write %out_b = [[TMP2]] hw.output %0 : i42 } -arc.define @i1Identity(%arg0: i1) -> i1 { - arc.output %arg0 : i1 +// CHECK-LABEL: arc.model @ClockedDpiCall +hw.module @ClockedDpiCall(in %clock: !seq.clock, in %a: i42, out b: i42) { + // CHECK: [[Q0:%.+]] = arc.alloc_state %arg0 {name = "q0"} + // CHECK: [[Q1:%.+]] = arc.alloc_state %arg0 {name = "q1"} + // CHECK: [[Q0_OLD:%.+]] = arc.state_read [[Q0]] + // CHECK: scf.if {{%.+}} { + // CHECK-NEXT: [[TMP1:%.+]] = arc.state_read %in_a + // CHECK-NEXT: [[TMP2:%.+]] = func.call @IdI42([[TMP1]]) + // CHECK-NEXT: arc.state_write [[Q0]] = [[TMP2]] + // CHECK-NEXT: [[TMP3:%.+]] = func.call @IdI42([[Q0_OLD]]) + // CHECK-NEXT: arc.state_write [[Q1]] = [[TMP3]] + // CHECK-NEXT: } + %0 = sim.func.dpi.call @IdI42(%a) clock %clock {names = ["q0"]} : (i42) -> i42 + %1 = sim.func.dpi.call @IdI42(%0) clock %clock {names = ["q1"]} : (i42) -> i42 + // CHECK-NEXT: [[TMP:%.+]] = arc.state_read [[Q1]] + // CHECK-NEXT: arc.state_write %out_b = [[TMP]] + hw.output %1 : i42 } -arc.define @DummyArc2(%arg0: i42) -> (i42, i42) { - arc.output %arg0, %arg0 : i42, i42 -} - -hw.module @stateReset(in %clk: !seq.clock, in %arg0: i42, in %rst: i1, out out0: i42, out out1: i42) { - %0 = arc.call @i1Identity(%rst) : (i1) -> (i1) - %1 = arc.call @i1Identity(%rst) : (i1) -> (i1) - %2, %3 = arc.state @DummyArc2(%arg0) clock %clk enable %0 reset %1 latency 1 : (i42) -> (i42, i42) - hw.output %2, %3 : i42, i42 -} -// CHECK-LABEL: arc.model @stateReset -// CHECK: [[ALLOC1:%.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state -// CHECK: [[ALLOC2:%.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state -// CHECK: arc.clock_tree %{{.*}} { -// CHECK: [[IN_RST:%.+]] = arc.state_read %in_rst : -// CHECK: [[EN:%.+]] = arc.call @i1Identity([[IN_RST]]) : (i1) -> i1 -// CHECK: [[RST:%.+]] = arc.call @i1Identity([[IN_RST]]) : (i1) -> i1 -// CHECK: scf.if [[RST]] { -// CHECK: arc.state_write [[ALLOC1]] = %c0_i42{{.*}} : -// CHECK: arc.state_write [[ALLOC2]] = %c0_i42{{.*}} : -// CHECK: } else { -// CHECK: [[ARG:%.+]] = arc.state_read %in_arg0 : -// CHECK: [[STATE:%.+]]:2 = arc.call @DummyArc2([[ARG]]) : (i42) -> (i42, i42) -// CHECK: arc.state_write [[ALLOC1]] = [[STATE]]#0 if [[EN]] : -// CHECK: arc.state_write [[ALLOC2]] = [[STATE]]#1 if [[EN]] : -// CHECK: } -// CHECK: } - -hw.module @SeparateResets(in %clock: !seq.clock, in %i0: i42, in %rst1: i1, in %rst2: i1, out out1: i42, out out2: i42) { - %0 = arc.state @DummyArc(%i0) clock %clock reset %rst1 latency 1 {names = ["foo"]} : (i42) -> i42 - %1 = arc.state @DummyArc(%i0) clock %clock reset %rst2 latency 1 {names = ["bar"]} : (i42) -> i42 - hw.output %0, %1 : i42, i42 -} - -// CHECK-LABEL: arc.model @SeparateResets -// CHECK: [[FOO_ALLOC:%.+]] = arc.alloc_state %arg0 {name = "foo"} : (!arc.storage) -> !arc.state -// CHECK: [[BAR_ALLOC:%.+]] = arc.alloc_state %arg0 {name = "bar"} : (!arc.storage) -> !arc.state -// CHECK: arc.clock_tree %{{.*}} { -// CHECK: [[IN_RST1:%.+]] = arc.state_read %in_rst1 : -// CHECK: scf.if [[IN_RST1]] { -// CHECK: %c0_i42{{.*}} = hw.constant 0 : i42 -// CHECK: arc.state_write [[FOO_ALLOC]] = %c0_i42{{.*}} : -// CHECK: } else { -// CHECK: [[IN_I0:%.+]] = arc.state_read %in_i0 : -// CHECK: [[STATE:%.+]] = arc.call @DummyArc([[IN_I0]]) : (i42) -> i42 -// CHECK: arc.state_write [[FOO_ALLOC]] = [[STATE]] : -// CHECK: } -// CHECK: [[IN_RST2:%.+]] = arc.state_read %in_rst2 : -// CHECK: scf.if [[IN_RST2]] { -// CHECK: %c0_i42{{.*}} = hw.constant 0 : i42 -// CHECK: arc.state_write [[BAR_ALLOC]] = %c0_i42{{.*}} : -// CHECK: } else { -// CHECK: [[IN_I0_2:%.+]] = arc.state_read %in_i0 : -// CHECK: [[STATE_2:%.+]] = arc.call @DummyArc([[IN_I0_2]]) : (i42) -> i42 -// CHECK: arc.state_write [[BAR_ALLOC]] = [[STATE_2]] : -// CHECK: } +// CHECK-LABEL: arc.model @OpsWithRegions +hw.module @OpsWithRegions(in %clock: !seq.clock, in %a: i42, in %b: i1, out c: i42) { + // CHECK: [[Q0:%.+]] = arc.alloc_state %arg0 {name = "q0"} + %0 = scf.if %b -> (i42) { + scf.yield %a : i42 + } else { + scf.yield %1 : i42 + } + // CHECK: [[A:%.+]] = arc.state_read %in_a + // CHECK: [[Q0_OLD:%.+]] = arc.state_read [[Q0]] + // CHECK: [[B:%.+]] = arc.state_read %in_b + // CHECK: [[IF_OLD:%.+]] = scf.if [[B]] -> (i42) { + // CHECK: scf.yield [[A]] + // CHECK: } else { + // CHECK: scf.yield [[Q0_OLD]] + // CHECK: } + // CHECK: scf.if {{%.+}} { + // CHECK: [[TMP:%.+]] = arc.call @IdI42Arc([[IF_OLD]]) + // CHECK: arc.state_write [[Q0]] = [[TMP]] + // CHECK: } + %1 = arc.state @IdI42Arc(%0) clock %clock latency 1 {names = ["q0"]} : (i42) -> i42 + // CHECK: [[A:%.+]] = arc.state_read %in_a + // CHECK: [[Q0_NEW:%.+]] = arc.state_read [[Q0]] + // CHECK: [[B:%.+]] = arc.state_read %in_b + // CHECK: [[IF_NEW:%.+]] = scf.if [[B]] -> (i42) { + // CHECK: scf.yield [[A]] + // CHECK: } else { + // CHECK: scf.yield [[Q0_NEW]] + // CHECK: } + // CHECK: arc.state_write %out_c = [[IF_NEW]] + hw.output %0 : i42 +} // Regression check on worklist producing false positive comb loop errors. // CHECK-LABEL: @CombLoopRegression @@ -309,139 +624,19 @@ arc.define @CombLoopRegressionArc2(%arg0: i1) -> (i1, i1) { // Regression check for invalid memory port lowering errors. // CHECK-LABEL: arc.model @MemoryPortRegression hw.module private @MemoryPortRegression(in %clock: !seq.clock, in %reset: i1, in %in: i3, out x: i3) { - %0 = arc.memory <2 x i3, i1> {name = "ram_ext"} + %0 = arc.memory <2 x i3, i1> %1 = arc.memory_read_port %0[%3] : <2 x i3, i1> - arc.memory_write_port %0, @identity3(%3, %in) clock %clock latency 1 : <2 x i3, i1>, i1, i3 - %3 = arc.state @Queue_arc_0(%reset) clock %clock latency 1 : (i1) -> i1 - %4 = arc.call @Queue_arc_1(%1) : (i3) -> i3 + arc.memory_write_port %0, @MemoryPortRegressionArc1(%3, %in) clock %clock latency 1 : <2 x i3, i1>, i1, i3 + %3 = arc.state @MemoryPortRegressionArc2(%reset) clock %clock latency 1 : (i1) -> i1 + %4 = arc.call @MemoryPortRegressionArc3(%1) : (i3) -> i3 hw.output %4 : i3 } -arc.define @identity3(%arg0: i1, %arg1: i3) -> (i1, i3) { +arc.define @MemoryPortRegressionArc1(%arg0: i1, %arg1: i3) -> (i1, i3) { arc.output %arg0, %arg1 : i1, i3 } -arc.define @Queue_arc_0(%arg0: i1) -> i1 { +arc.define @MemoryPortRegressionArc2(%arg0: i1) -> i1 { arc.output %arg0 : i1 } -arc.define @Queue_arc_1(%arg0: i3) -> i3 { +arc.define @MemoryPortRegressionArc3(%arg0: i3) -> i3 { arc.output %arg0 : i3 } - -// CHECK-LABEL: arc.model @BlackBox -hw.module @BlackBox(in %clk: !seq.clock) { - %0 = arc.state @DummyArc(%2) clock %clk latency 1 : (i42) -> i42 - %1 = comb.and %0, %0 : i42 - %ext.c, %ext.d = hw.instance "ext" @BlackBoxExt(a: %0: i42, b: %1: i42) -> (c: i42, d: i42) - %2 = comb.or %ext.c, %ext.d : i42 - // CHECK-DAG: [[EXT_A:%.+]] = arc.alloc_state %arg0 {name = "ext/a"} - // CHECK-DAG: [[EXT_B:%.+]] = arc.alloc_state %arg0 {name = "ext/b"} - // CHECK-DAG: [[EXT_C:%.+]] = arc.alloc_state %arg0 {name = "ext/c"} - // CHECK-DAG: [[EXT_D:%.+]] = arc.alloc_state %arg0 {name = "ext/d"} - // CHECK-DAG: [[STATE:%.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - - // Clock Tree - // CHECK-DAG: [[TMP1:%.+]] = arc.state_read [[EXT_C]] - // CHECK-DAG: [[TMP2:%.+]] = arc.state_read [[EXT_D]] - // CHECK-DAG: [[TMP3:%.+]] = comb.or [[TMP1]], [[TMP2]] - // CHECK-DAG: [[TMP4:%.+]] = arc.call @DummyArc([[TMP3]]) - // CHECK-DAG: arc.state_write [[STATE]] = [[TMP4]] - - // Passthrough - // CHECK-DAG: [[TMP1:%.+]] = arc.state_read [[STATE]] - // CHECK-DAG: [[TMP2:%.+]] = comb.and [[TMP1]], [[TMP1]] - // CHECK-DAG: arc.state_write [[EXT_A]] = [[TMP1]] - // CHECK-DAG: arc.state_write [[EXT_B]] = [[TMP2]] -} -// CHECK-NOT: hw.module.extern private @BlackBoxExt -hw.module.extern private @BlackBoxExt(in %a: i42, in %b: i42, out c: i42, out d: i42) - - -func.func private @func(%arg0: i32, %arg1: i32) -> i32 -// CHECK-LABEL: arc.model @adder -hw.module @adder(in %clock : i1, in %a : i32, in %b : i32, out c : i32) { - %0 = seq.to_clock %clock - %1 = sim.func.dpi.call @func(%a, %b) clock %0 : (i32, i32) -> i32 - // CHECK: arc.clock_tree - // CHECK-NEXT: %[[A:.+]] = arc.state_read %in_a : - // CHECK-NEXT: %[[B:.+]] = arc.state_read %in_b : - // CHECK-NEXT: %[[RESULT:.+]] = func.call @func(%6, %7) : (i32, i32) -> i32 - hw.output %1 : i32 -} - -// CHECK-LABEL: arc.model @InitializedStates -hw.module @InitializedStates(in %clk: !seq.clock, in %reset: i1, in %input: i42) { - -// CHECK: [[ST1:%.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state -// CHECK-NEXT: [[ST2:%.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state -// CHECK-NEXT: [[ST3:%.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state -// CHECK-NEXT: [[ST4:%.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state -// CHECK-NEXT: [[ST5:%.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - -// CHECK: arc.initial { - - %csta = hw.constant 1 : i42 - %cstb = hw.constant 10 : i42 - %cstc = hw.constant 100 : i42 - %cstd = hw.constant 1000 : i42 - %add = comb.add bin %cstb, %cstc, %csta : i42 - %mul = comb.mul bin %add, %csta : i42 - - // CHECK-NEXT: [[CSTD:%.+]] = hw.constant 1000 : i42 - // CHECK-NEXT: arc.state_write [[ST1]] = [[CSTD]] : - %0 = arc.state @DummyArc(%input) clock %clk initial (%cstd : i42) latency 1 : (i42) -> i42 - - // CHECK-DAG: [[CSTA:%.+]] = hw.constant 1 : i42 - // CHECK-DAG: [[CSTB:%.+]] = hw.constant 10 : i42 - // CHECK-DAG: [[CSTC:%.+]] = hw.constant 100 : i42 - // CHECK-DAG: [[ADD:%.+]] = comb.add bin [[CSTB]], [[CSTC]], [[CSTA]] : i42 - // CHECK-DAG: [[MUL:%.+]] = comb.mul bin [[ADD]], [[CSTA]] : i42 - - // CHECK: arc.state_write [[ST2]] = [[MUL]] : - %1 = arc.state @DummyArc(%0) clock %clk initial (%mul : i42) latency 1 : (i42) -> i42 - // CHECK-NEXT: arc.state_write [[ST3]] = [[CSTB]] : - %2 = arc.state @DummyArc(%1) clock %clk reset %reset initial (%cstb : i42) latency 1 : (i42) -> i42 - // CHECK-DAG: arc.state_write [[ST4]] = [[CSTB]] : - // CHECK-DAG: arc.state_write [[ST5]] = [[ADD]] : - %3, %4 = arc.state @DummyArc2(%2) clock %clk initial (%cstb, %add : i42, i42) latency 1 : (i42) -> (i42, i42) -// CHECK: } -} - -func.func private @random() -> i32 -arc.define @counter_arc(%arg0: i8) -> i8 { - %c1_i8 = hw.constant 1 : i8 - %0 = comb.add %arg0, %c1_i8 : i8 - arc.output %0 : i8 -} -arc.define @counter_arc_0(%arg0: i1) -> !seq.clock { - %0 = seq.to_clock %arg0 - arc.output %0 : !seq.clock -} -// CHECK-LABEL: arc.model @seqInitial -hw.module @seqInitial(in %clk : i1, out o1 : i8, out o2 : i8) { - // CHECK: %[[STATE1:.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - // CHECK-NEXT: %[[STATE2:.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state - - // CHECK: arc.initial { - // CHECK-NEXT: %c5_i8 = hw.constant 5 : i8 - // CHECK-NEXT: %[[RAND:.+]] = func.call @random() : () -> i32 - // CHECK-NEXT: %[[EXTRACT:.+]] = comb.extract %[[RAND]] from 0 : (i32) -> i8 - // CHECK-NEXT: arc.state_write %[[STATE1]] = %[[VAL1:.+]] : - // CHECK-NEXT: arc.state_write %[[STATE2]] = %[[VAL2:.+]] : - // CHECK-NEXT: } - %0 = seq.from_immutable %7 : (!seq.immutable) -> i8 - %1 = seq.from_immutable %2#1 : (!seq.immutable) -> i8 - %2:2 = seq.initial() { - %c5_i8 = hw.constant 5 : i8 - %6 = func.call @random() : () -> i32 - seq.yield %6, %c5_i8 : i32, i8 - } : () -> (!seq.immutable, !seq.immutable) - %7 = seq.initial(%2#0) { - ^bb0(%arg0 : i32): - %ext = comb.extract %arg0 from 0 : (i32) -> i8 - seq.yield %ext: i8 - } : (!seq.immutable) -> (!seq.immutable) - - %3 = arc.state @counter_arc(%3) clock %4 initial (%0 : i8) latency 1 : (i8) -> i8 - %4 = arc.call @counter_arc_0(%clk) : (i1) -> !seq.clock - %5 = arc.state @counter_arc(%5) clock %4 initial (%1 : i8) latency 1 : (i8) -> i8 - hw.output %3, %5 : i8, i8 -} diff --git a/test/arcilator/arcilator.mlir b/test/arcilator/arcilator.mlir index af9fe66e3fec..7aaf2f29d3fe 100644 --- a/test/arcilator/arcilator.mlir +++ b/test/arcilator/arcilator.mlir @@ -1,33 +1,7 @@ -// RUN: arcilator %s --inline=0 --until-before=llvm-lowering | FileCheck %s // RUN: arcilator %s | FileCheck %s --check-prefix=LLVM // RUN: arcilator --print-debug-info %s | FileCheck %s --check-prefix=LLVM-DEBUG -// CHECK: func.func @[[XOR_ARC:.+]]( -// CHECK-NEXT: comb.xor -// CHECK-NEXT: return -// CHECK-NEXT: } - -// CHECK: func.func @[[ADD_ARC:.+]]( -// CHECK-NEXT: comb.add -// CHECK-NEXT: return -// CHECK-NEXT: } - -// CHECK: func.func @[[MUL_ARC:.+]]( -// CHECK-NEXT: comb.mul -// CHECK-NEXT: return -// CHECK-NEXT: } - -// CHECK: func.func @Top_passthrough -// CHECK: func.func @Top_clock - -// CHECK-NOT: hw.module @Top -// CHECK-LABEL: arc.model @Top io !hw.modty -// CHECK-NEXT: ^bb0(%arg0: !arc.storage<8>): hw.module @Top(in %clock : !seq.clock, in %i0 : i4, in %i1 : i4, out out : i4) { - // CHECK: func.call @Top_passthrough(%arg0) - // CHECK: scf.if {{%.+}} { - // CHECK: func.call @Top_clock(%arg0) - // CHECK: } %0 = comb.add %i0, %i1 : i4 %1 = comb.xor %0, %i0 : i4 %2 = comb.xor %0, %i1 : i4 @@ -37,22 +11,14 @@ hw.module @Top(in %clock : !seq.clock, in %i0 : i4, in %i1 : i4, out out : i4) { hw.output %3 : i4 } -// LLVM: define void @Top_passthrough(ptr %0) -// LLVM: mul i4 -// LLVM: define void @Top_clock(ptr %0) +// LLVM: define void @Top_eval(ptr %0) // LLVM: add i4 // LLVM: xor i4 // LLVM: xor i4 -// LLVM: define void @Top_eval(ptr %0) -// LLVM: call void @Top_passthrough(ptr %0) -// LLVM: call void @Top_clock(ptr %0) +// LLVM: mul i4 -// LLVM-DEBUG: define void @Top_passthrough(ptr %0){{.*}}!dbg -// LLVM-DEBUG: mul i4{{.*}}!dbg -// LLVM-DEBUG: define void @Top_clock(ptr %0){{.*}}!dbg +// LLVM-DEBUG: define void @Top_eval(ptr %0){{.*}}!dbg // LLVM-DEBUG: add i4{{.*}}!dbg // LLVM-DEBUG: xor i4{{.*}}!dbg // LLVM-DEBUG: xor i4{{.*}}!dbg -// LLVM-DEBUG: define void @Top_eval(ptr %0){{.*}}!dbg -// LLVM-DEBUG: call void @Top_passthrough(ptr %0){{.*}}!dbg -// LLVM-DEBUG: call void @Top_clock(ptr %0){{.*}}!dbg +// LLVM-DEBUG: mul i4{{.*}}!dbg diff --git a/tools/arcilator/arcilator.cpp b/tools/arcilator/arcilator.cpp index dc5aa8403106..9f092e800a86 100644 --- a/tools/arcilator/arcilator.cpp +++ b/tools/arcilator/arcilator.cpp @@ -321,8 +321,6 @@ static void populateHwModuleToArcPipeline(PassManager &pm) { if (untilReached(UntilStateLowering)) return; pm.addPass(arc::createLowerStatePass()); - pm.addPass(createCSEPass()); - pm.addPass(arc::createArcCanonicalizerPass()); // TODO: LowerClocksToFuncsPass might not properly consider scf.if operations // (or nested regions in general) and thus errors out when muxes are also @@ -330,15 +328,10 @@ static void populateHwModuleToArcPipeline(PassManager &pm) { // TODO: InlineArcs seems to not properly handle scf.if operations, thus the // following is commented out // pm.addPass(arc::createMuxToControlFlowPass()); - - if (shouldInline) { + if (shouldInline) pm.addPass(arc::createInlineArcsPass()); - pm.addPass(arc::createArcCanonicalizerPass()); - pm.addPass(createCSEPass()); - } pm.addPass(arc::createMergeIfsPass()); - pm.addPass(arc::createLegalizeStateUpdatePass()); pm.addPass(createCSEPass()); pm.addPass(arc::createArcCanonicalizerPass());