Skip to content

Commit

Permalink
[Arc] Improve LowerState to never produce read-after-write conflicts
Browse files Browse the repository at this point in the history
This is a complete rewrite of the `LowerState` pass that makes the
`LegalizeStateUpdate` pass obsolete.

The old implementation of `LowerState` produces `arc.model`s that still
contain read-after-write conflicts. This primarily happens because the
pass simply emits `arc.state_write` operations that write updated values
to simulation memory for each `arc.state`, and any user of `arc.state`
would use an `arc.state_read` operation to retrieve the original value
of the state before any writes occurred. Memories are similar. The Arc
dialect considers `arc.state_write` and `arc.memory_write` operations to
be _deferred_ writes until the `LegalizeStateUpdate` pass runs, at which
point they become _immediate_ since the legalization inserts the
necessary temporaries to resolve the read-after-write conflicts.

The previous implementation would also not handle state-to-output and
state-to-side-effecting-op propagation paths correctly. When a model's
eval function is called, registers are updated to their new value, and
any outputs that combinatorially depend on those new values should also
immediately update. Similarly, operations such as asserts or debug
trackers should observe new values for states immediately after they
have been written. However, since all writes are considered deferred,
there is no way for `LowerState` to produce a mixture of operations that
depend on a register's _old_ state (because they are used to compute a
register's new state), and on a _new_ state because they are
combinatorially derived values.

This new implementation of `LowerState` completely avoids
read-after-write conflicts. It does this by changing the way modules are
lowered in two ways:

**Phases:** The pass tracks in which _phase_ of the simulation lifecycle
a value is needed and allows for operations to have different lowerings
in different phases. An `arc.state` operation for example requires its
inputs, enable, and reset to be computed based on the _old_ value they
had, i.e. the value the end of the previous call to the model's eval
function. The clock however has to be computed based on the _new_ value
it has in the current call to eval. Therefore, the ops defining the
inputs, enable, and reset are lowered in the _old_ phase, while the ops
defining the clock are lowered in the _new_ phase. The `arc.state` op
lowering will then write its _new_ value to simulation storage.

This phase tracking allows registers to be used as the clock for other
registers: since the clocks require _new_ values, registers serving as
clock to others are lowered first, such that the dependent registers can
immediately react to the updated clock. It also allows for module
outputs and side-effecting ops based on `arc.state`s to be scheduled
after the states have been updated, since they depend on the state's
_new_ value.

The pass also covers the situation where an operation depends on a
module input and a state, and feeds into a module output as well as
another state. In this situation that operation has to be lowered twice:
once for the _old_ phase to serve as input to the subsequent state, and
once for the _new_ phase to compute the new module output.

In addition to the _old_ and _new_ phases representing the previous and
current call to eval, the pass also models an _initial_ and _final_
phase. These are used for `seq.initial` and `llhd.final` ops, and in
order to compute the initial values for states. If an `arc.state` op has
an initial value operand it is lowered in the _initial_ phase. Similarly
for the ops in `llhd.final`. The pass places all ops lowered in the
initial and final phases into corresponding `arc.initial` and
`arc.final` ops. At a later point we may want to generate the
`*_initial`, `*_eval`, and `*_final` functions directly.

**No more clock trees:** The new implementation also no longer generates
`arc.clock_tree` and `arc.passthrough` operations. These were a holdover
from the early days of the Arc dialect, where no eval function would be
generated. Instead, the user was required to directly call clock
functions. This was never able to model clocks changing at the exact
same moment, or having clocks derived from registers and other
combinatorial operations. Since Arc has since switched to generating an
eval function that can accurately interleave the effects of different
clocks, grouping ops by clock tree is no longer needed. In fact,
removing the clock tree ops allows for the pass to more efficiently
interleave the operations from different clock domains.

The Rocket core in the circt/arc-tests repository still works with this
new implementation of LowerState. In combination with the MergeIfs pass
the performance stays the same.

I have renamed the implementation and test files to make the git diffs
easier to read. The names will be changed back in a follow-up commit.
  • Loading branch information
fabianschuiki committed Oct 15, 2024
1 parent 2085d0d commit ca2369f
Show file tree
Hide file tree
Showing 17 changed files with 1,809 additions and 2,316 deletions.
2 changes: 0 additions & 2 deletions include/circt/Dialect/Arc/ArcPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,9 @@ createInferMemoriesPass(const InferMemoriesOptions &options = {});
std::unique_ptr<mlir::Pass> createInlineArcsPass();
std::unique_ptr<mlir::Pass> createIsolateClocksPass();
std::unique_ptr<mlir::Pass> createLatencyRetimingPass();
std::unique_ptr<mlir::Pass> createLegalizeStateUpdatePass();
std::unique_ptr<mlir::Pass> createLowerArcsToFuncsPass();
std::unique_ptr<mlir::Pass> createLowerClocksToFuncsPass();
std::unique_ptr<mlir::Pass> createLowerLUTPass();
std::unique_ptr<mlir::Pass> createLowerStatePass();
std::unique_ptr<mlir::Pass> createLowerVectorizationsPass(
LowerVectorizationsModeEnum mode = LowerVectorizationsModeEnum::Full);
std::unique_ptr<mlir::Pass> createMakeTablesPass();
Expand Down
17 changes: 7 additions & 10 deletions include/circt/Dialect/Arc/ArcPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,6 @@ def LatencyRetiming : Pass<"arc-latency-retiming", "mlir::ModuleOp"> {
];
}

def LegalizeStateUpdate : Pass<"arc-legalize-state-update", "mlir::ModuleOp"> {
let summary = "Insert temporaries such that state reads don't see writes";
let constructor = "circt::arc::createLegalizeStateUpdatePass()";
let dependentDialects = ["arc::ArcDialect"];
}

def LowerArcsToFuncs : Pass<"arc-lower-arcs-to-funcs", "mlir::ModuleOp"> {
let summary = "Lower arc definitions into functions";
let constructor = "circt::arc::createLowerArcsToFuncsPass()";
Expand All @@ -187,12 +181,15 @@ def LowerLUT : Pass<"arc-lower-lut", "arc::DefineOp"> {
let dependentDialects = ["hw::HWDialect", "comb::CombDialect"];
}

def LowerState : Pass<"arc-lower-state", "mlir::ModuleOp"> {
def LowerStatePass : Pass<"arc-lower-state", "mlir::ModuleOp"> {
let summary = "Split state into read and write ops grouped by clock tree";
let constructor = "circt::arc::createLowerStatePass()";
let dependentDialects = [
"arc::ArcDialect", "mlir::scf::SCFDialect", "mlir::func::FuncDialect",
"mlir::LLVM::LLVMDialect", "comb::CombDialect", "seq::SeqDialect"
"arc::ArcDialect",
"comb::CombDialect",
"mlir::LLVM::LLVMDialect",
"mlir::func::FuncDialect",
"mlir::scf::SCFDialect",
"seq::SeqDialect",
];
}

Expand Down
9 changes: 4 additions & 5 deletions integration_test/arcilator/JIT/dpi.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ func.func @add_mlir_impl(%arg0: i32, %arg1: i32, %arg2: !llvm.ptr) {
llvm.store %0, %arg2 : i32, !llvm.ptr
return
}

hw.module @arith(in %clock : i1, in %a : i32, in %b : i32, out c : i32, out d : i32) {
%seq_clk = seq.to_clock %clock

%0 = sim.func.dpi.call @add_mlir(%a, %b) clock %seq_clk : (i32, i32) -> i32
%1 = sim.func.dpi.call @mul_shared(%a, %b) clock %seq_clk : (i32, i32) -> i32
hw.output %0, %1 : i32, i32
}

func.func @main() {
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
Expand All @@ -34,18 +35,16 @@ func.func @main() {
arc.sim.instantiate @arith as %arg0 {
arc.sim.set_input %arg0, "a" = %c2_i32 : i32, !arc.sim.instance<@arith>
arc.sim.set_input %arg0, "b" = %c3_i32 : i32, !arc.sim.instance<@arith>
arc.sim.set_input %arg0, "clock" = %one : i1, !arc.sim.instance<@arith>

arc.sim.step %arg0 : !arc.sim.instance<@arith>
arc.sim.set_input %arg0, "clock" = %zero : i1, !arc.sim.instance<@arith>
arc.sim.step %arg0 : !arc.sim.instance<@arith>
%0 = arc.sim.get_port %arg0, "c" : i32, !arc.sim.instance<@arith>
%1 = arc.sim.get_port %arg0, "d" : i32, !arc.sim.instance<@arith>

arc.sim.emit "c", %0 : i32
arc.sim.emit "d", %1 : i32

arc.sim.step %arg0 : !arc.sim.instance<@arith>
arc.sim.set_input %arg0, "clock" = %one : i1, !arc.sim.instance<@arith>
arc.sim.step %arg0 : !arc.sim.instance<@arith>
%2 = arc.sim.get_port %arg0, "c" : i32, !arc.sim.instance<@arith>
%3 = arc.sim.get_port %arg0, "d" : i32, !arc.sim.instance<@arith>
arc.sim.emit "c", %2 : i32
Expand Down
1 change: 1 addition & 0 deletions integration_test/arcilator/JIT/initial-shift-reg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ module {
%true = arith.constant 1 : i1

arc.sim.instantiate @shiftreg as %model {
arc.sim.step %model : !arc.sim.instance<@shiftreg>
arc.sim.set_input %model, "en" = %false : i1, !arc.sim.instance<@shiftreg>
arc.sim.set_input %model, "reset" = %false : i1, !arc.sim.instance<@shiftreg>
arc.sim.set_input %model, "din" = %ff : i8, !arc.sim.instance<@shiftreg>
Expand Down
3 changes: 2 additions & 1 deletion integration_test/arcilator/JIT/reg.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: arcilator %s --run --jit-entry=main | FileCheck %s
// REQUIRES: arcilator-jit

// CHECK: o1 = 2
// CHECK: o1 = 2
// CHECK-NEXT: o2 = 5
// CHECK-NEXT: o1 = 3
// CHECK-NEXT: o2 = 6
Expand Down Expand Up @@ -41,6 +41,7 @@ func.func @main() {
%step = arith.constant 1 : index

arc.sim.instantiate @counter as %model {
arc.sim.step %model : !arc.sim.instance<@counter>
%init_val1 = arc.sim.get_port %model, "o1" : i8, !arc.sim.instance<@counter>
%init_val2 = arc.sim.get_port %model, "o2" : i8, !arc.sim.instance<@counter>

Expand Down
5 changes: 3 additions & 2 deletions lib/Conversion/ConvertToArcs/ConvertToArcs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ using llvm::MapVector;

static bool isArcBreakingOp(Operation *op) {
return op->hasTrait<OpTrait::ConstantLike>() ||
isa<hw::InstanceOp, seq::CompRegOp, MemoryOp, ClockedOpInterface,
seq::InitialOp, seq::ClockGateOp, sim::DPICallOp>(op) ||
isa<hw::InstanceOp, seq::CompRegOp, MemoryOp, MemoryReadPortOp,
ClockedOpInterface, seq::InitialOp, seq::ClockGateOp,
sim::DPICallOp>(op) ||
op->getNumResults() > 1;
}

Expand Down
8 changes: 7 additions & 1 deletion lib/Dialect/Arc/ArcTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,17 @@ using namespace mlir;
#define GET_TYPEDEF_CLASSES
#include "circt/Dialect/Arc/ArcTypes.cpp.inc"

unsigned StateType::getBitWidth() { return hw::getBitWidth(getType()); }
unsigned StateType::getBitWidth() {
if (llvm::isa<seq::ClockType>(getType()))
return 1;
return hw::getBitWidth(getType());
}

LogicalResult
StateType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
Type innerType) {
if (llvm::isa<seq::ClockType>(innerType))
return success();
if (hw::getBitWidth(innerType) < 0)
return emitError() << "state type must have a known bit width; got "
<< innerType;
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/Arc/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@ add_circt_dialect_library(CIRCTArcTransforms
InlineArcs.cpp
IsolateClocks.cpp
LatencyRetiming.cpp
LegalizeStateUpdate.cpp
LowerArcsToFuncs.cpp
LowerClocksToFuncs.cpp
LowerLUT.cpp
LowerState.cpp
LowerStateRewrite.cpp
LowerVectorizations.cpp
MakeTables.cpp
MergeIfs.cpp
Expand All @@ -33,6 +32,7 @@ add_circt_dialect_library(CIRCTArcTransforms
CIRCTComb
CIRCTEmit
CIRCTHW
CIRCTLLHD
CIRCTOM
CIRCTSV
CIRCTSeq
Expand Down
Loading

0 comments on commit ca2369f

Please sign in to comment.