From 420103999d0bd104a71a2a920206bb4104f11fa1 Mon Sep 17 00:00:00 2001 From: fzi-hielscher <47524191+fzi-hielscher@users.noreply.github.com> Date: Thu, 15 Aug 2024 17:52:11 +0200 Subject: [PATCH] [Arc] Add InitialOp and lowering support for FirReg preset values. (#7480) --- include/circt/Conversion/Passes.td | 2 +- include/circt/Dialect/Arc/ArcOps.td | 78 ++++++++++------ include/circt/Dialect/Arc/ModelInfo.h | 6 +- .../arcilator/JIT/initial-shift-reg.mlir | 69 ++++++++++++++ integration_test/arcilator/JIT/initial.mlir | 25 +++++ lib/Conversion/ArcToLLVM/LowerArcToLLVM.cpp | 14 ++- .../ConvertToArcs/ConvertToArcs.cpp | 42 ++++++++- lib/Dialect/Arc/ArcOps.cpp | 25 +++++ lib/Dialect/Arc/CMakeLists.txt | 1 + lib/Dialect/Arc/ModelInfo.cpp | 6 +- .../Arc/Transforms/LegalizeStateUpdate.cpp | 2 + .../Arc/Transforms/LowerClocksToFuncs.cpp | 91 +++++++++++++++---- lib/Dialect/Arc/Transforms/LowerState.cpp | 74 ++++++++++++--- lib/Dialect/Arc/Transforms/StripSV.cpp | 12 ++- .../ConvertToArcs/convert-to-arcs.mlir | 46 ++++++++++ test/Dialect/Arc/basic-errors.mlir | 39 ++++++++ .../Arc/lower-clocks-to-funcs-errors.mlir | 30 +++++- test/Dialect/Arc/lower-clocks-to-funcs.mlir | 14 ++- test/Dialect/Arc/lower-state-errors.mlir | 24 +++++ test/Dialect/Arc/lower-state.mlir | 38 ++++++++ tools/arcilator/arcilator-header-cpp.py | 10 +- 21 files changed, 584 insertions(+), 64 deletions(-) create mode 100644 integration_test/arcilator/JIT/initial-shift-reg.mlir create mode 100644 integration_test/arcilator/JIT/initial.mlir create mode 100644 test/Dialect/Arc/lower-state-errors.mlir diff --git a/include/circt/Conversion/Passes.td b/include/circt/Conversion/Passes.td index dddc507bd80f..6eb67b865093 100644 --- a/include/circt/Conversion/Passes.td +++ b/include/circt/Conversion/Passes.td @@ -654,7 +654,7 @@ def ConvertToArcs : Pass<"convert-to-arcs", "mlir::ModuleOp"> { latency. }]; let constructor = "circt::createConvertToArcsPass()"; - let dependentDialects = ["circt::arc::ArcDialect"]; + let dependentDialects = ["circt::arc::ArcDialect", "circt::hw::HWDialect"]; let options = [ Option<"tapRegisters", "tap-registers", "bool", "true", "Make registers observable">, diff --git a/include/circt/Dialect/Arc/ArcOps.td b/include/circt/Dialect/Arc/ArcOps.td index a4ebdae14dde..02781979d217 100644 --- a/include/circt/Dialect/Arc/ArcOps.td +++ b/include/circt/Dialect/Arc/ArcOps.td @@ -134,6 +134,9 @@ def StateOp : ArcOp<"state", [ DeclareOpInterfaceMethods, AttrSizedOperandSegments, DeclareOpInterfaceMethods, + PredOpTrait<"types of initial arguments match result types", + CPred<[{getInitials().empty() || + llvm::equal(getInitials().getType(), getResults().getType())}]>> ]> { let summary = "State transfer arc"; @@ -143,13 +146,15 @@ def StateOp : ArcOp<"state", [ Optional:$enable, Optional:$reset, I32Attr:$latency, - Variadic:$inputs); + Variadic:$inputs, + Variadic:$initials); let results = (outs Variadic:$outputs); let assemblyFormat = [{ $arc `(` $inputs `)` (`clock` $clock^)? (`enable` $enable^)? - (`reset` $reset^)? `latency` $latency attr-dict - `:` functional-type($inputs, results) + (`reset` $reset^)? + ( `initial` ` ` `(` $initials^ `:` type($initials) `)`)? + `latency` $latency attr-dict `:` functional-type($inputs, results) }]; let hasFolder = 1; @@ -157,21 +162,24 @@ def StateOp : ArcOp<"state", [ let builders = [ OpBuilder<(ins "DefineOp":$arc, "mlir::Value":$clock, "mlir::Value":$enable, - "unsigned":$latency, CArg<"mlir::ValueRange", "{}">:$inputs), [{ + "unsigned":$latency, CArg<"mlir::ValueRange", "{}">:$inputs, + CArg<"mlir::ValueRange", "{}">:$initials), [{ build($_builder, $_state, mlir::SymbolRefAttr::get(arc), arc.getFunctionType().getResults(), clock, enable, latency, - inputs); + inputs, initials); }]>, OpBuilder<(ins "mlir::SymbolRefAttr":$arc, "mlir::TypeRange":$results, "mlir::Value":$clock, "mlir::Value":$enable, "unsigned":$latency, - CArg<"mlir::ValueRange", "{}">:$inputs + CArg<"mlir::ValueRange", "{}">:$inputs, + CArg<"mlir::ValueRange", "{}">:$initials ), [{ build($_builder, $_state, arc, results, clock, enable, Value(), latency, - inputs); + inputs, initials); }]>, OpBuilder<(ins "mlir::SymbolRefAttr":$arc, "mlir::TypeRange":$results, "mlir::Value":$clock, "mlir::Value":$enable, "mlir::Value":$reset, - "unsigned":$latency, CArg<"mlir::ValueRange", "{}">:$inputs + "unsigned":$latency, CArg<"mlir::ValueRange", "{}">:$inputs, + CArg<"mlir::ValueRange", "{}">:$initials ), [{ if (clock) $_state.addOperands(clock); @@ -180,6 +188,7 @@ def StateOp : ArcOp<"state", [ if (reset) $_state.addOperands(reset); $_state.addOperands(inputs); + $_state.addOperands(initials); $_state.addAttribute("arc", arc); $_state.addAttribute("latency", $_builder.getI32IntegerAttr(latency)); $_state.addAttribute(getOperandSegmentSizeAttr(), @@ -187,23 +196,26 @@ def StateOp : ArcOp<"state", [ clock ? 1 : 0, enable ? 1 : 0, reset ? 1 : 0, - static_cast(inputs.size())})); + static_cast(inputs.size()), + static_cast(initials.size())})); $_state.addTypes(results); }]>, OpBuilder<(ins "mlir::StringAttr":$arc, "mlir::TypeRange":$results, "mlir::Value":$clock, "mlir::Value":$enable, "unsigned":$latency, - CArg<"mlir::ValueRange", "{}">:$inputs + CArg<"mlir::ValueRange", "{}">:$inputs, + CArg<"mlir::ValueRange", "{}">:$initials ), [{ build($_builder, $_state, mlir::SymbolRefAttr::get(arc), results, clock, - enable, latency, inputs); + enable, latency, inputs, initials); }]>, OpBuilder<(ins "mlir::StringRef":$arc, "mlir::TypeRange":$results, "mlir::Value":$clock, "mlir::Value":$enable, "unsigned":$latency, - CArg<"mlir::ValueRange", "{}">:$inputs + CArg<"mlir::ValueRange", "{}">:$inputs, + CArg<"mlir::ValueRange", "{}">:$initials ), [{ build($_builder, $_state, mlir::StringAttr::get($_builder.getContext(), arc), - results, clock, enable, latency, inputs); + results, clock, enable, latency, inputs, initials); }]> ]; let skipDefaultBuilders = 1; @@ -429,26 +441,37 @@ def ClockDomainOp : ArcOp<"clock_domain", [ let hasCanonicalizeMethod = 1; } -def ClockTreeOp : ArcOp<"clock_tree", [NoTerminator, NoRegionArguments]> { +//===----------------------------------------------------------------------===// +// (Pseudo) Clock Trees +//===----------------------------------------------------------------------===// + +class ClockTreeLikeOp traits = []>: + ArcOp +])> { + let regions = (region SizedRegion<1>:$body); +} + +def ClockTreeOp : ClockTreeLikeOp<"clock_tree"> { let summary = "A clock tree"; let arguments = (ins I1:$clock); - let regions = (region SizedRegion<1>:$body); let assemblyFormat = [{ $clock attr-dict-with-keyword $body }]; - let extraClassDeclaration = [{ - mlir::Block &getBodyBlock() { return getBody().front(); } - }]; } -def PassThroughOp : ArcOp<"passthrough", [NoTerminator, NoRegionArguments]> { +def PassThroughOp : ClockTreeLikeOp<"passthrough"> { let summary = "Clock-less logic that is on the pass-through path"; - let regions = (region SizedRegion<1>:$body); let assemblyFormat = [{ attr-dict-with-keyword $body }]; - let extraClassDeclaration = [{ - mlir::Block &getBodyBlock() { return getBody().front(); } +} + +def InitialOp : ClockTreeLikeOp<"initial"> { + let summary = "Clock-less logic called at the start of simulation"; + let assemblyFormat = [{ + attr-dict-with-keyword $body }]; } @@ -651,19 +674,22 @@ def TapOp : ArcOp<"tap"> { let assemblyFormat = [{ $value attr-dict `:` type($value) }]; } -def ModelOp : ArcOp<"model", [RegionKindInterface, IsolatedFromAbove, - NoTerminator, Symbol]> { +def ModelOp : ArcOp<"model", [ + RegionKindInterface, IsolatedFromAbove, NoTerminator, Symbol, + DeclareOpInterfaceMethods +]> { let summary = "A model with stratified clocks"; let description = [{ A model with stratified clocks. The `io` optional attribute specifies the I/O of the module associated to this model. }]; let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$io); + TypeAttrOf:$io, + OptionalAttr:$initialFn); let regions = (region SizedRegion<1>:$body); let assemblyFormat = [{ - $sym_name `io` $io attr-dict-with-keyword $body + $sym_name `io` $io (`initializer` $initialFn^)? attr-dict-with-keyword $body }]; let extraClassDeclaration = [{ diff --git a/include/circt/Dialect/Arc/ModelInfo.h b/include/circt/Dialect/Arc/ModelInfo.h index 322f38b774c5..ca3918a772f0 100644 --- a/include/circt/Dialect/Arc/ModelInfo.h +++ b/include/circt/Dialect/Arc/ModelInfo.h @@ -36,11 +36,13 @@ struct ModelInfo { std::string name; size_t numStateBytes; llvm::SmallVector states; + mlir::FlatSymbolRefAttr initialFnSym; ModelInfo(std::string name, size_t numStateBytes, - llvm::SmallVector states) + llvm::SmallVector states, + mlir::FlatSymbolRefAttr initialFnSym) : name(std::move(name)), numStateBytes(numStateBytes), - states(std::move(states)) {} + states(std::move(states)), initialFnSym(initialFnSym) {} }; /// Collects information about states within the provided Arc model storage diff --git a/integration_test/arcilator/JIT/initial-shift-reg.mlir b/integration_test/arcilator/JIT/initial-shift-reg.mlir new file mode 100644 index 000000000000..3724962d8a7f --- /dev/null +++ b/integration_test/arcilator/JIT/initial-shift-reg.mlir @@ -0,0 +1,69 @@ +// RUN: arcilator %s --run --jit-entry=main | FileCheck %s +// REQUIRES: arcilator-jit + +// CHECK-LABEL: output = ca +// CHECK-NEXT: output = ca +// CHECK-NEXT: output = 0 +// CHECK-NEXT: output = fe +// CHECK-NEXT: output = ff + +module { + + hw.module @shiftreg(in %clock : i1, in %reset : i1, in %en : i1, in %din : i8, out dout : i8) { + %seq_clk = seq.to_clock %clock + %srA = seq.firreg %0 clock %seq_clk preset 0xFE : i8 + %srB = seq.firreg %1 clock %seq_clk : i8 + %srC = seq.firreg %2 clock %seq_clk preset 0xCA : i8 + %0 = comb.mux bin %en, %din, %srA : i8 + %1 = comb.mux bin %en, %srA, %srB : i8 + %2 = comb.mux bin %en, %srB, %srC : i8 + hw.output %srC : i8 + } + + func.func @main() { + %ff = arith.constant 0xFF : i8 + %false = arith.constant 0 : i1 + %true = arith.constant 1 : i1 + + arc.sim.instantiate @shiftreg as %model { + arc.sim.set_input %model, "en" = %false : i1, !arc.sim.instance<@shiftreg> + arc.sim.set_input %model, "reset" = %false : i1, !arc.sim.instance<@shiftreg> + arc.sim.set_input %model, "din" = %ff : i8, !arc.sim.instance<@shiftreg> + + %res0 = arc.sim.get_port %model, "dout" : i8, !arc.sim.instance<@shiftreg> + arc.sim.emit "output", %res0 : i8 + + arc.sim.set_input %model, "clock" = %true : i1, !arc.sim.instance<@shiftreg> + arc.sim.step %model : !arc.sim.instance<@shiftreg> + arc.sim.set_input %model, "clock" = %false : i1, !arc.sim.instance<@shiftreg> + arc.sim.step %model : !arc.sim.instance<@shiftreg> + + %res1 = arc.sim.get_port %model, "dout" : i8, !arc.sim.instance<@shiftreg> + arc.sim.emit "output", %res1 : i8 + + arc.sim.set_input %model, "en" = %true : i1, !arc.sim.instance<@shiftreg> + + arc.sim.set_input %model, "clock" = %true : i1, !arc.sim.instance<@shiftreg> + arc.sim.step %model : !arc.sim.instance<@shiftreg> + arc.sim.set_input %model, "clock" = %false : i1, !arc.sim.instance<@shiftreg> + arc.sim.step %model : !arc.sim.instance<@shiftreg> + %res2 = arc.sim.get_port %model, "dout" : i8, !arc.sim.instance<@shiftreg> + arc.sim.emit "output", %res2 : i8 + + arc.sim.set_input %model, "clock" = %true : i1, !arc.sim.instance<@shiftreg> + arc.sim.step %model : !arc.sim.instance<@shiftreg> + arc.sim.set_input %model, "clock" = %false : i1, !arc.sim.instance<@shiftreg> + arc.sim.step %model : !arc.sim.instance<@shiftreg> + %res3 = arc.sim.get_port %model, "dout" : i8, !arc.sim.instance<@shiftreg> + arc.sim.emit "output", %res3 : i8 + + arc.sim.set_input %model, "clock" = %true : i1, !arc.sim.instance<@shiftreg> + arc.sim.step %model : !arc.sim.instance<@shiftreg> + arc.sim.set_input %model, "clock" = %false : i1, !arc.sim.instance<@shiftreg> + arc.sim.step %model : !arc.sim.instance<@shiftreg> + %res4 = arc.sim.get_port %model, "dout" : i8, !arc.sim.instance<@shiftreg> + arc.sim.emit "output", %res4 : i8 + } + return + } +} diff --git a/integration_test/arcilator/JIT/initial.mlir b/integration_test/arcilator/JIT/initial.mlir new file mode 100644 index 000000000000..7cde23b21072 --- /dev/null +++ b/integration_test/arcilator/JIT/initial.mlir @@ -0,0 +1,25 @@ +// RUN: arcilator %s --run --jit-entry=main 2>&1 >/dev/null | FileCheck %s +// REQUIRES: arcilator-jit + +// CHECK: - Init - + +module { + llvm.func @_arc_env_get_print_stream(i32) -> !llvm.ptr + llvm.func @_arc_libc_fputs(!llvm.ptr, !llvm.ptr) -> i32 + llvm.mlir.global internal constant @global_init_str(" - Init -\0A\00") {addr_space = 0 : i32} + + arc.model @initmodel io !hw.modty<> { + ^bb0(%arg0: !arc.storage): + arc.initial { + %cst0 = llvm.mlir.constant(0 : i32) : i32 + %stderr = llvm.call @_arc_env_get_print_stream(%cst0) : (i32) -> !llvm.ptr + %str = llvm.mlir.addressof @global_init_str : !llvm.ptr + %0 = llvm.call @_arc_libc_fputs(%str, %stderr) : (!llvm.ptr, !llvm.ptr) -> i32 + } + } + func.func @main() { + arc.sim.instantiate @initmodel as %arg0 { + } + return + } +} diff --git a/lib/Conversion/ArcToLLVM/LowerArcToLLVM.cpp b/lib/Conversion/ArcToLLVM/LowerArcToLLVM.cpp index 086038148c04..d6e4f5bcc130 100644 --- a/lib/Conversion/ArcToLLVM/LowerArcToLLVM.cpp +++ b/lib/Conversion/ArcToLLVM/LowerArcToLLVM.cpp @@ -318,6 +318,7 @@ namespace { struct ModelInfoMap { size_t numStateBytes; llvm::DenseMap states; + mlir::FlatSymbolRefAttr initialFnSymbol; }; template @@ -378,6 +379,16 @@ struct SimInstantiateOpLowering Value zero = rewriter.create(loc, rewriter.getI8Type(), 0); rewriter.create(loc, allocated, zero, numStateBytes, false); + + // Call the model's 'initial' function if present. + if (model.initialFnSymbol) { + auto initialFnType = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(op.getContext()), + {LLVM::LLVMPointerType::get(op.getContext())}); + rewriter.create(loc, initialFnType, model.initialFnSymbol, + ValueRange{allocated}); + } + rewriter.inlineBlockBefore(&adaptor.getBody().getBlocks().front(), op, {allocated}); rewriter.create(loc, freeFunc, ValueRange{allocated}); @@ -646,7 +657,8 @@ void LowerArcToLLVMPass::runOnOperation() { for (StateInfo &stateInfo : modelInfo.states) states.insert({stateInfo.name, stateInfo}); modelMap.insert({modelInfo.name, - ModelInfoMap{modelInfo.numStateBytes, std::move(states)}}); + ModelInfoMap{modelInfo.numStateBytes, std::move(states), + modelInfo.initialFnSym}}); } patterns.add(callOp.getOperation()); Value clock = stateOp ? stateOp.getClock() : Value{}; Value reset; + SmallVector initialValues; SmallVector absorbedRegs; SmallVector absorbedNames(callOp->getNumResults(), {}); if (auto names = callOp->getAttrOfType("names")) @@ -307,6 +308,8 @@ LogicalResult Converter::absorbRegs(HWModuleOp module) { } } + initialValues.push_back(regOp.getPowerOnValue()); + absorbedRegs.push_back(regOp); // If we absorb a register into the arc, the arc effectively produces that // register's value. So if the register had a name, ensure that we assign @@ -345,6 +348,28 @@ LogicalResult Converter::absorbRegs(HWModuleOp module) { "had a reset."); arc.getResetMutable().assign(reset); } + + bool onlyDefaultInitializers = + llvm::all_of(initialValues, [](auto val) -> bool { return !val; }); + + if (!onlyDefaultInitializers) { + if (!arc.getInitials().empty()) { + return arc.emitError( + "StateOp tried to infer initial values from CompReg, but already " + "had an initial value."); + } + // Create 0 constants for default initialization + for (unsigned i = 0; i < initialValues.size(); ++i) { + if (!initialValues[i]) { + OpBuilder zeroBuilder(arc); + initialValues[i] = zeroBuilder.createOrFold( + arc.getLoc(), + zeroBuilder.getIntegerAttr(arc.getResult(i).getType(), 0)); + } + } + arc.getInitialsMutable().assign(initialValues); + } + if (tapRegisters && llvm::any_of(absorbedNames, [](auto name) { return !cast(name).getValue().empty(); })) @@ -385,6 +410,7 @@ LogicalResult Converter::absorbRegs(HWModuleOp module) { SmallVector outputs; SmallVector names; SmallVector types; + SmallVector initialValues; SmallDenseMap mapping; SmallVector regToOutputMapping; for (auto regOp : regOps) { @@ -395,6 +421,7 @@ LogicalResult Converter::absorbRegs(HWModuleOp module) { types.push_back(regOp.getType()); outputs.push_back(block->addArgument(regOp.getType(), regOp.getLoc())); names.push_back(regOp->getAttrOfType("name")); + initialValues.push_back(regOp.getPowerOnValue()); } regToOutputMapping.push_back(it->second); } @@ -411,9 +438,22 @@ LogicalResult Converter::absorbRegs(HWModuleOp module) { defOp.getBody().push_back(block.release()); builder.setInsertionPoint(module.getBodyBlock()->getTerminator()); + + bool onlyDefaultInitializers = + llvm::all_of(initialValues, [](auto val) -> bool { return !val; }); + + if (onlyDefaultInitializers) + initialValues.clear(); + else + for (unsigned i = 0; i < initialValues.size(); ++i) { + if (!initialValues[i]) + initialValues[i] = builder.createOrFold( + loc, builder.getIntegerAttr(types[i], 0)); + } + auto arcOp = builder.create(loc, defOp, std::get<0>(clockAndResetAndOp), - /*enable=*/Value{}, 1, inputs); + /*enable=*/Value{}, 1, inputs, initialValues); auto reset = std::get<1>(clockAndResetAndOp); if (reset) arcOp.getResetMutable().assign(reset); diff --git a/lib/Dialect/Arc/ArcOps.cpp b/lib/Dialect/Arc/ArcOps.cpp index 87070ed8c398..76e457b8d77d 100644 --- a/lib/Dialect/Arc/ArcOps.cpp +++ b/lib/Dialect/Arc/ArcOps.cpp @@ -8,6 +8,7 @@ #include "circt/Dialect/Arc/ArcOps.h" #include "circt/Dialect/HW/HWOpInterfaces.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" @@ -307,6 +308,30 @@ LogicalResult ModelOp::verify() { return success(); } +LogicalResult ModelOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + if (!getInitialFn().has_value()) + return success(); + + auto referencedOp = + symbolTable.lookupNearestSymbolFrom(*this, getInitialFnAttr()); + if (!referencedOp) + return emitError("Cannot find declaration of initializer function '") + << *getInitialFn() << "'."; + auto funcOp = dyn_cast(referencedOp); + if (!funcOp) { + auto diag = emitError("Referenced initializer must be a 'func.func' op."); + diag.attachNote(referencedOp->getLoc()) << "Initializer declared here:"; + return diag; + } + if (!llvm::equal(funcOp.getArgumentTypes(), getBody().getArgumentTypes())) { + auto diag = emitError("Arguments of initializer function must match " + "arguments of model body."); + diag.attachNote(referencedOp->getLoc()) << "Initializer declared here:"; + return diag; + } + return success(); +} + //===----------------------------------------------------------------------===// // LutOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Arc/CMakeLists.txt b/lib/Dialect/Arc/CMakeLists.txt index e6ed7c9c4533..c1703987ed40 100644 --- a/lib/Dialect/Arc/CMakeLists.txt +++ b/lib/Dialect/Arc/CMakeLists.txt @@ -32,6 +32,7 @@ add_circt_dialect_library(CIRCTArc CIRCTSeq MLIRIR MLIRInferTypeOpInterface + MLIRFuncDialect MLIRSideEffectInterfaces MLIRFuncDialect ) diff --git a/lib/Dialect/Arc/ModelInfo.cpp b/lib/Dialect/Arc/ModelInfo.cpp index a16dc0b68e24..91e0449df396 100644 --- a/lib/Dialect/Arc/ModelInfo.cpp +++ b/lib/Dialect/Arc/ModelInfo.cpp @@ -105,6 +105,7 @@ LogicalResult circt::arc::collectStates(Value storage, unsigned offset, LogicalResult circt::arc::collectModels(mlir::ModuleOp module, SmallVector &models) { + for (auto modelOp : module.getOps()) { auto storageArg = modelOp.getBody().getArgument(0); auto storageType = cast(storageArg.getType()); @@ -115,7 +116,7 @@ LogicalResult circt::arc::collectModels(mlir::ModuleOp module, llvm::sort(states, [](auto &a, auto &b) { return a.offset < b.offset; }); models.emplace_back(std::string(modelOp.getName()), storageType.getSize(), - std::move(states)); + std::move(states), modelOp.getInitialFnAttr()); } return success(); @@ -130,6 +131,9 @@ void circt::arc::serializeModelInfoToJson(llvm::raw_ostream &outputStream, json.object([&] { json.attribute("name", model.name); json.attribute("numStateBytes", model.numStateBytes); + json.attribute("initialFnSym", !model.initialFnSym + ? "" + : model.initialFnSym.getValue()); json.attributeArray("states", [&] { for (const auto &state : model.states) { json.object([&] { diff --git a/lib/Dialect/Arc/Transforms/LegalizeStateUpdate.cpp b/lib/Dialect/Arc/Transforms/LegalizeStateUpdate.cpp index 486235d0bf5b..b68abef62513 100644 --- a/lib/Dialect/Arc/Transforms/LegalizeStateUpdate.cpp +++ b/lib/Dialect/Arc/Transforms/LegalizeStateUpdate.cpp @@ -30,6 +30,8 @@ using namespace arc; /// Check if an operation partakes in state accesses. static bool isOpInteresting(Operation *op) { + if (isa(op)) + return false; if (isa(op)) return true; if (op->getNumRegions() > 0) diff --git a/lib/Dialect/Arc/Transforms/LowerClocksToFuncs.cpp b/lib/Dialect/Arc/Transforms/LowerClocksToFuncs.cpp index 0ef473e428cf..c76467e80f38 100644 --- a/lib/Dialect/Arc/Transforms/LowerClocksToFuncs.cpp +++ b/lib/Dialect/Arc/Transforms/LowerClocksToFuncs.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Pass/Pass.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "arc-lower-clocks-to-funcs" @@ -50,6 +51,9 @@ struct LowerClocksToFuncsPass Statistic numOpsCopied{this, "ops-copied", "Ops copied into clock trees"}; Statistic numOpsMoved{this, "ops-moved", "Ops moved into clock trees"}; + +private: + bool hasPassthroughOp; }; } // namespace @@ -65,11 +69,38 @@ LogicalResult LowerClocksToFuncsPass::lowerModel(ModelOp modelOp) { << "`\n"); // Find the clocks to extract. + SmallVector initialOps; + SmallVector passthroughOps; SmallVector clocks; modelOp.walk([&](Operation *op) { - if (isa(op)) - clocks.push_back(op); + TypeSwitch(op) + .Case([&](auto) { clocks.push_back(op); }) + .Case([&](auto initOp) { + initialOps.push_back(initOp); + clocks.push_back(initOp); + }) + .Case([&](auto ptOp) { + passthroughOps.push_back(ptOp); + clocks.push_back(ptOp); + }); }); + hasPassthroughOp = !passthroughOps.empty(); + + // Sanity check + if (passthroughOps.size() > 1) { + auto diag = modelOp.emitOpError() + << "containing multiple PassThroughOps cannot be lowered."; + for (auto ptOp : passthroughOps) + diag.attachNote(ptOp.getLoc()) << "Conflicting PassThroughOp:"; + } + if (initialOps.size() > 1) { + auto diag = modelOp.emitOpError() + << "containing multiple InitialOps is currently unsupported."; + for (auto initOp : initialOps) + diag.attachNote(initOp.getLoc()) << "Conflicting InitialOp:"; + } + if (passthroughOps.size() > 1 || initialOps.size() > 1) + return failure(); // Perform the actual extraction. OpBuilder funcBuilder(modelOp); @@ -84,7 +115,7 @@ LogicalResult LowerClocksToFuncsPass::lowerClock(Operation *clockOp, Value modelStorageArg, OpBuilder &funcBuilder) { LLVM_DEBUG(llvm::dbgs() << "- Lowering clock " << clockOp->getName() << "\n"); - assert((isa(clockOp))); + assert((isa(clockOp))); // Add a `StorageType` block argument to the clock's body block which we are // going to use to pass the storage pointer to the clock once it has been @@ -103,8 +134,16 @@ LogicalResult LowerClocksToFuncsPass::lowerClock(Operation *clockOp, // Pick a name for the clock function. SmallString<32> funcName; - funcName.append(clockOp->getParentOfType().getName()); - funcName.append(isa(clockOp) ? "_passthrough" : "_clock"); + auto modelOp = clockOp->getParentOfType(); + funcName.append(modelOp.getName()); + + if (isa(clockOp)) + funcName.append("_passthrough"); + else if (isa(clockOp)) + funcName.append("_initial"); + else + funcName.append("_clock"); + auto funcOp = funcBuilder.create( clockOp->getLoc(), funcName, builder.getFunctionType({modelStorageArg.getType()}, {})); @@ -114,21 +153,41 @@ LogicalResult LowerClocksToFuncsPass::lowerClock(Operation *clockOp, // Create a call to the function within the model. builder.setInsertionPoint(clockOp); - if (auto treeOp = dyn_cast(clockOp)) { - auto ifOp = - builder.create(clockOp->getLoc(), treeOp.getClock(), false); - auto builder = ifOp.getThenBodyBuilder(); - builder.create(clockOp->getLoc(), funcOp, - ValueRange{modelStorageArg}); - } else { - builder.create(clockOp->getLoc(), funcOp, - ValueRange{modelStorageArg}); - } + TypeSwitch(clockOp) + .Case([&](auto treeOp) { + auto ifOp = builder.create(clockOp->getLoc(), + treeOp.getClock(), false); + auto builder = ifOp.getThenBodyBuilder(); + builder.template create(clockOp->getLoc(), funcOp, + ValueRange{modelStorageArg}); + }) + .Case([&](auto) { + builder.template create(clockOp->getLoc(), funcOp, + ValueRange{modelStorageArg}); + }) + .Case([&](auto) { + if (modelOp.getInitialFn().has_value()) + modelOp.emitWarning() << "Existing model initializer '" + << modelOp.getInitialFnAttr().getValue() + << "' will be overridden."; + modelOp.setInitialFnAttr( + FlatSymbolRefAttr::get(funcOp.getSymNameAttr())); + }); // Move the clock's body block to the function and remove the old clock op. funcOp.getBody().takeBody(clockRegion); - clockOp->erase(); + if (isa(clockOp) && hasPassthroughOp) { + // Call PassThroughOp after init + builder.setInsertionPoint(funcOp.getBlocks().front().getTerminator()); + funcName.clear(); + funcName.append(modelOp.getName()); + funcName.append("_passthrough"); + builder.create(clockOp->getLoc(), funcName, TypeRange{}, + ValueRange{funcOp.getBody().getArgument(0)}); + } + + clockOp->erase(); return success(); } diff --git a/lib/Dialect/Arc/Transforms/LowerState.cpp b/lib/Dialect/Arc/Transforms/LowerState.cpp index 7f5025b8e457..d659e8870397 100644 --- a/lib/Dialect/Arc/Transforms/LowerState.cpp +++ b/lib/Dialect/Arc/Transforms/LowerState.cpp @@ -8,6 +8,7 @@ #include "circt/Dialect/Arc/ArcOps.h" #include "circt/Dialect/Arc/ArcPasses.h" +#include "circt/Dialect/Comb/CombDialect.h" #include "circt/Dialect/Comb/CombOps.h" #include "circt/Dialect/HW/HWOps.h" #include "circt/Dialect/Seq/SeqOps.h" @@ -63,7 +64,7 @@ struct Statistics { struct ClockLowering { /// The root clock this lowering is for. Value clock; - /// A `ClockTreeOp` or `PassThroughOp`. + /// A `ClockTreeOp` or `PassThroughOp` or `InitialOp`. Operation *treeOp; /// Pass statistics. Statistics &stats; @@ -76,15 +77,21 @@ struct ClockLowering { /// A cache of OR gates created for aggregating enable conditions. DenseMap, Value> orCache; + // Prevent accidental construction and copying + ClockLowering() = delete; + ClockLowering(const ClockLowering &other) = delete; + ClockLowering(Value clock, Operation *treeOp, Statistics &stats) : clock(clock), treeOp(treeOp), stats(stats), builder(treeOp) { - assert((isa(treeOp))); + assert((isa(treeOp))); builder.setInsertionPointToStart(&treeOp->getRegion(0).front()); } Value materializeValue(Value value); Value getOrCreateAnd(Value lhs, Value rhs, Location loc); Value getOrCreateOr(Value lhs, Value rhs, Location loc); + + bool isInitialTree() const { return isa(treeOp); } }; struct GatedClockLowering { @@ -102,6 +109,7 @@ struct ModuleLowering { MLIRContext *context; DenseMap> clockLowerings; DenseMap gatedClockLowerings; + std::unique_ptr initialLowering; Value storageArg; OpBuilder clockBuilder; OpBuilder stateBuilder; @@ -112,6 +120,7 @@ struct ModuleLowering { GatedClockLowering getOrCreateClockLowering(Value clock); ClockLowering &getOrCreatePassThrough(); + ClockLowering &getInitial(); Value replaceValueWithStateRead(Value value, Value state); void addStorageArg(); @@ -121,7 +130,8 @@ struct ModuleLowering { template LogicalResult lowerStateLike(Operation *op, Value clock, Value enable, Value reset, ArrayRef inputs, - FlatSymbolRefAttr callee); + FlatSymbolRefAttr callee, + ArrayRef initialValues = {}); LogicalResult lowerState(StateOp stateOp); LogicalResult lowerState(sim::DPICallOp dpiCallOp); LogicalResult lowerState(MemoryOp memOp); @@ -159,6 +169,17 @@ static bool shouldMaterialize(Value value) { return shouldMaterialize(op); } +static bool canBeMaterializedInInitializer(Operation *op) { + if (!op) + return false; + if (op->hasTrait()) + return true; + if (isa(op->getDialect())) + return true; + // TODO: There are some other ops we probably want to allow + return false; +} + /// Materialize a value within this clock tree. This clones or moves all /// operations required to produce this value inside the clock tree. Value ClockLowering::materializeValue(Value value) { @@ -206,6 +227,10 @@ Value ClockLowering::materializeValue(Value value) { while (!worklist.empty()) { auto &workItem = worklist.back(); + if (isInitialTree() && !canBeMaterializedInInitializer(workItem.op)) { + workItem.op->emitError("Value cannot be used in initializer."); + return {}; + } if (!workItem.operands.empty()) { auto operand = workItem.operands.pop_back_val(); if (materializedValues.contains(operand) || !shouldMaterialize(operand)) @@ -317,6 +342,11 @@ ClockLowering &ModuleLowering::getOrCreatePassThrough() { return *slot; } +ClockLowering &ModuleLowering::getInitial() { + assert(!!initialLowering && "Initial tree op should have been constructed"); + return *initialLowering; +} + /// Replace all uses of a value with a `StateReadOp` on a state. Value ModuleLowering::replaceValueWithStateRead(Value value, Value state) { OpBuilder builder(state.getContext()); @@ -415,7 +445,8 @@ LogicalResult ModuleLowering::lowerStates() { template LogicalResult ModuleLowering::lowerStateLike( Operation *stateOp, Value stateClock, Value stateEnable, Value stateReset, - ArrayRef stateInputs, FlatSymbolRefAttr callee) { + ArrayRef stateInputs, FlatSymbolRefAttr callee, + ArrayRef initialValues) { // Grab all operands from the state op at the callsite and make it drop all // its references. This allows `materializeValue` to move an operation if this // state was the last user. @@ -470,10 +501,23 @@ LogicalResult ModuleLowering::lowerStateLike( thenBuilder.create(stateOp->getLoc(), alloc, constZero, Value()); } - nonResetBuilder = ifOp.getElseBodyBuilder(); } + if (!initialValues.empty()) { + assert(initialValues.size() == allocatedStates.size() && + "Unexpected number of initializers"); + auto &initialTree = getInitial(); + for (auto [alloc, init] : llvm::zip(allocatedStates, initialValues)) { + // TODO: Can we get away without materialization? + auto matierializedInit = initialTree.materializeValue(init); + if (!matierializedInit) + return failure(); + initialTree.builder.create(stateOp->getLoc(), alloc, + matierializedInit, Value()); + } + } + stateOp->dropAllReferences(); auto newStateOp = nonResetBuilder.create( @@ -501,10 +545,11 @@ LogicalResult ModuleLowering::lowerState(StateOp stateOp) { return stateOp.emitError("state with latency > 1 not supported"); auto stateInputs = SmallVector(stateOp.getInputs()); + auto stateInitializers = SmallVector(stateOp.getInitials()); - return lowerStateLike(stateOp, stateOp.getClock(), - stateOp.getEnable(), stateOp.getReset(), - stateInputs, stateOp.getArcAttr()); + return lowerStateLike( + stateOp, stateOp.getClock(), stateOp.getEnable(), stateOp.getReset(), + stateInputs, stateOp.getArcAttr(), stateInitializers); } LogicalResult ModuleLowering::lowerState(sim::DPICallOp callOp) { @@ -829,6 +874,13 @@ LogicalResult LowerStatePass::runOnModule(HWModuleOp moduleOp, Operation *clockSentinel = lowering.stateBuilder.create(moduleOp.getLoc()); + // Create the 'initial' pseudo clock tree. + auto initialTreeOp = + lowering.stateBuilder.create(moduleOp.getLoc()); + initialTreeOp.getBody().emplaceBlock(); + lowering.initialLowering = + std::make_unique(Value{}, initialTreeOp, stats); + lowering.stateBuilder.setInsertionPoint(stateSentinel); lowering.clockBuilder.setInsertionPoint(clockSentinel); @@ -856,9 +908,9 @@ LogicalResult LowerStatePass::runOnModule(HWModuleOp moduleOp, moduleOp.getBodyBlock()->eraseArguments( [&](auto arg) { return arg != lowering.storageArg; }); ImplicitLocOpBuilder builder(moduleOp.getLoc(), moduleOp); - auto modelOp = - builder.create(moduleOp.getLoc(), moduleOp.getModuleNameAttr(), - TypeAttr::get(moduleOp.getModuleType())); + auto modelOp = builder.create( + moduleOp.getLoc(), moduleOp.getModuleNameAttr(), + TypeAttr::get(moduleOp.getModuleType()), mlir::FlatSymbolRefAttr()); modelOp.getBody().takeBody(moduleOp.getBody()); moduleOp->erase(); sortTopologically(&modelOp.getBodyBlock()); diff --git a/lib/Dialect/Arc/Transforms/StripSV.cpp b/lib/Dialect/Arc/Transforms/StripSV.cpp index f18602261b40..dbe400e81bfb 100644 --- a/lib/Dialect/Arc/Transforms/StripSV.cpp +++ b/lib/Dialect/Arc/Transforms/StripSV.cpp @@ -151,9 +151,19 @@ void StripSVPass::runOnOperation() { else next = reg.getNext(); + Value presetValue; + // Materialize initial value, assume zero initialization as default. + if (reg.getPreset() && !reg.getPreset()->isZero()) { + assert(hw::type_isa(reg.getType()) && + "cannot lower non integer preset"); + presetValue = builder.createOrFold( + reg.getLoc(), IntegerAttr::get(reg.getType(), *reg.getPreset())); + } + Value compReg = builder.create( reg.getLoc(), next.getType(), next, reg.getClk(), reg.getNameAttr(), - Value{}, Value{}, Value{}, reg.getInnerSymAttr()); + Value{}, Value{}, /*powerOnValue*/ presetValue, + reg.getInnerSymAttr()); reg.replaceAllUsesWith(compReg); opsToDelete.push_back(reg); continue; diff --git a/test/Conversion/ConvertToArcs/convert-to-arcs.mlir b/test/Conversion/ConvertToArcs/convert-to-arcs.mlir index 21fac9ecc0b5..047a82f59eb3 100644 --- a/test/Conversion/ConvertToArcs/convert-to-arcs.mlir +++ b/test/Conversion/ConvertToArcs/convert-to-arcs.mlir @@ -110,6 +110,36 @@ hw.module @Reshuffling(in %clockA: !seq.clock, in %clockB: !seq.clock, out z0: i hw.module.extern private @Reshuffling2(out z0: i4, out z1: i4, out z2: i4, out z3: i4) +// CHECK-LABEL: arc.define @ReshufflingInit_arc(%arg0: i4, %arg1: i4) +// CHECK-NEXT: arc.output %arg0, %arg1 +// CHECK-NEXT: } + +// CHECK-LABEL: arc.define @ReshufflingInit_arc_0(%arg0: i4, %arg1: i4) +// CHECK-NEXT: arc.output %arg0, %arg1 +// CHECK-NEXT: } + +// CHECK-LABEL: hw.module @ReshufflingInit +hw.module @ReshufflingInit(in %clockA: !seq.clock, in %clockB: !seq.clock, out z0: i4, out z1: i4, out z2: i4, out z3: i4) { + // CHECK-NEXT: [[C1:%.+]] = hw.constant 1 : i4 + // CHECK-NEXT: [[C2:%.+]] = hw.constant 2 : i4 + // CHECK-NEXT: [[C3:%.+]] = hw.constant 3 : i4 + // CHECK-NEXT: hw.instance "x" @Reshuffling2() + // CHECK-NEXT: [[C0:%.+]] = hw.constant 0 : i4 + // CHECK-NEXT: arc.state @ReshufflingInit_arc(%x.z0, %x.z1) clock %clockA initial ([[C0]], [[C1]] : i4, i4) latency 1 + // CHECK-NEXT: arc.state @ReshufflingInit_arc_0(%x.z2, %x.z3) clock %clockB initial ([[C2]], [[C3]] : i4, i4) latency 1 + // CHECK-NEXT: hw.output + %cst1 = hw.constant 1 : i4 + %cst2 = hw.constant 2 : i4 + %cst3 = hw.constant 3 : i4 + %x.z0, %x.z1, %x.z2, %x.z3 = hw.instance "x" @Reshuffling2() -> (z0: i4, z1: i4, z2: i4, z3: i4) + %4 = seq.compreg %x.z0, %clockA : i4 + %5 = seq.compreg %x.z1, %clockA powerOn %cst1 : i4 + %6 = seq.compreg %x.z2, %clockB powerOn %cst2 : i4 + %7 = seq.compreg %x.z3, %clockB powerOn %cst3 : i4 + hw.output %4, %5, %6, %7 : i4, i4, i4, i4 +} +// CHECK-NEXT: } + // CHECK-LABEL: arc.define @FactorOutCommonOps_arc( // CHECK-NEXT: comb.xor @@ -196,6 +226,22 @@ hw.module @Trivial(in %clock: !seq.clock, in %i0: i4, in %reset: i1, out out: i4 } // CHECK-NEXT: } +// CHECK: arc.define @[[TRIVIALINIT_ARC:.+]]([[ARG0:%.+]]: i4) +// CHECK-NEXT: arc.output [[ARG0]] +// CHECK-NEXT: } + +// CHECK-LABEL: hw.module @TrivialWithInit( +hw.module @TrivialWithInit(in %clock: !seq.clock, in %i0: i4, in %reset: i1, out out: i4) { + // CHECK: [[CST2:%.+]] = hw.constant 2 : i4 + // CHECK: [[RES0:%.+]] = arc.state @[[TRIVIALINIT_ARC]](%i0) clock %clock reset %reset initial ([[CST2]] : i4) latency 1 {names = ["foo"] + // CHECK-NEXT: hw.output [[RES0:%.+]] + %0 = hw.constant 0 : i4 + %cst2 = hw.constant 2 : i4 + %foo = seq.compreg %i0, %clock reset %reset, %0 powerOn %cst2: i4 + hw.output %foo : i4 +} +// CHECK-NEXT: } + // CHECK-NEXT: arc.define @[[NONTRIVIAL_ARC_0:.+]]([[ARG0_1:%.+]]: i4) // CHECK-NEXT: arc.output [[ARG0_1]] // CHECK-NEXT: } diff --git a/test/Dialect/Arc/basic-errors.mlir b/test/Dialect/Arc/basic-errors.mlir index 6fe538beef05..723a29ebd5ee 100644 --- a/test/Dialect/Arc/basic-errors.mlir +++ b/test/Dialect/Arc/basic-errors.mlir @@ -524,3 +524,42 @@ hw.module @vectorize(in %in0: i4, in %in1: i4, out out0: i4) { // expected-error @below {{state type must have a known bit width}} func.func @InvalidStateType(%arg0: !arc.state) + +// ----- + +// expected-error @below {{Cannot find declaration of initializer function 'MissingInitilaizer_initial'.}} +arc.model @MissingInitilaizer io !hw.modty<> initializer @MissingInitilaizer_initial { + ^bb0(%arg0: !arc.storage<42>): +} + +// ----- + +// expected-note @below {{Initializer declared here:}} +hw.module @NonFuncInitilaizer_initial() { +} + +// expected-error @below {{Referenced initializer must be a 'func.func' op.}} +arc.model @NonFuncInitilaizer io !hw.modty<> initializer @NonFuncInitilaizer_initial { + ^bb0(%arg0: !arc.storage<42>): +} + +// ----- + +// expected-note @below {{Initializer declared here:}} +func.func @IncorrectArg_initial(!arc.storage<24>) { + ^bb0(%arg0: !arc.storage<24>): + return +} + +// expected-error @below {{Arguments of initializer function must match arguments of model body.}} +arc.model @IncorrectArg io !hw.modty<> initializer @IncorrectArg_initial { + ^bb0(%arg0: !arc.storage<42>): +} + +// ----- + +hw.module @InvalidInitType(in %clock: !seq.clock, in %input: i7) { + %cst = hw.constant 0 : i8 + // expected-error @below {{failed to verify that types of initial arguments match result types}} + %res = arc.state @Bar(%input) clock %clock initial (%cst: i8) latency 1 : (i7) -> i7 +} diff --git a/test/Dialect/Arc/lower-clocks-to-funcs-errors.mlir b/test/Dialect/Arc/lower-clocks-to-funcs-errors.mlir index a80487779825..a9100f6b9430 100644 --- a/test/Dialect/Arc/lower-clocks-to-funcs-errors.mlir +++ b/test/Dialect/Arc/lower-clocks-to-funcs-errors.mlir @@ -1,4 +1,4 @@ -// RUN: circt-opt %s --arc-lower-clocks-to-funcs --verify-diagnostics +// RUN: circt-opt %s --arc-lower-clocks-to-funcs --split-input-file --verify-diagnostics arc.model @NonConstExternalValue io !hw.modty<> { ^bb0(%arg0: !arc.storage<42>): @@ -12,3 +12,31 @@ arc.model @NonConstExternalValue io !hw.modty<> { %1 = comb.sub %0, %0 : i9001 } } + +// ----- + +func.func @VictimInit(%arg0: !arc.storage<42>) { + return +} + +// expected-warning @below {{Existing model initializer 'VictimInit' will be overridden.}} +arc.model @ExistingInit io !hw.modty<> initializer @VictimInit { +^bb0(%arg0: !arc.storage<42>): + arc.initial {} +} + +// ----- + +// expected-error @below {{op containing multiple PassThroughOps cannot be lowered.}} +// expected-error @below {{op containing multiple InitialOps is currently unsupported.}} +arc.model @MultiInitAndPassThrough io !hw.modty<> { +^bb0(%arg0: !arc.storage<1>): + // expected-note @below {{Conflicting PassThroughOp:}} + arc.passthrough {} + // expected-note @below {{Conflicting InitialOp:}} + arc.initial {} + // expected-note @below {{Conflicting PassThroughOp:}} + arc.passthrough {} + // expected-note @below {{Conflicting InitialOp:}} + arc.initial {} +} diff --git a/test/Dialect/Arc/lower-clocks-to-funcs.mlir b/test/Dialect/Arc/lower-clocks-to-funcs.mlir index 4226706910ee..9dd7d8899965 100644 --- a/test/Dialect/Arc/lower-clocks-to-funcs.mlir +++ b/test/Dialect/Arc/lower-clocks-to-funcs.mlir @@ -14,7 +14,15 @@ // CHECK-NEXT: return // CHECK-NEXT: } -// CHECK-LABEL: arc.model @Trivial io !hw.modty<> { +// CHECK-LABEL: func.func @Trivial_initial(%arg0: !arc.storage<42>) { +// CHECK-NEXT: %true = hw.constant true +// CHECK-NEXT: %c1_i9002 = hw.constant 1 : i9002 +// CHECK-NEXT: %0 = comb.mux %true, %c1_i9002, %c1_i9002 : i9002 +// CHECK-NEXT: call @Trivial_passthrough(%arg0) : (!arc.storage<42>) -> () +// CHECK-NEXT: return +// CHECK-NEXT: } + +// CHECK-LABEL: arc.model @Trivial io !hw.modty<> initializer @Trivial_initial { // CHECK-NEXT: ^bb0(%arg0: !arc.storage<42>): // CHECK-NEXT: %true = hw.constant true // CHECK-NEXT: %false = hw.constant false @@ -36,6 +44,10 @@ arc.model @Trivial io !hw.modty<> { %c1_i9001 = hw.constant 1 : i9001 %0 = comb.mux %true, %c1_i9001, %c1_i9001 : i9001 } + arc.initial { + %c1_i9002 = hw.constant 1 : i9002 + %0 = comb.mux %true, %c1_i9002, %c1_i9002 : i9002 + } } //===----------------------------------------------------------------------===// diff --git a/test/Dialect/Arc/lower-state-errors.mlir b/test/Dialect/Arc/lower-state-errors.mlir new file mode 100644 index 000000000000..cd6135a3e072 --- /dev/null +++ b/test/Dialect/Arc/lower-state-errors.mlir @@ -0,0 +1,24 @@ +// RUN: circt-opt %s --arc-lower-state --split-input-file --verify-diagnostics + +arc.define @DummyArc(%arg0: i42) -> i42 { + arc.output %arg0 : i42 +} + +// expected-error @+1 {{Value cannot be used in initializer.}} +hw.module @argInit(in %clk: !seq.clock, in %input: i42) { + %0 = arc.state @DummyArc(%0) clock %clk initial (%input : i42) latency 1 : (i42) -> i42 +} + + +// ----- + + +arc.define @DummyArc(%arg0: i42) -> i42 { + arc.output %arg0 : i42 +} + +hw.module @argInit(in %clk: !seq.clock, in %input: i42) { + // expected-error @+1 {{Value cannot be used in initializer.}} + %0 = arc.state @DummyArc(%0) clock %clk latency 1 : (i42) -> i42 + %1 = arc.state @DummyArc(%1) clock %clk initial (%0 : i42) latency 1 : (i42) -> i42 +} diff --git a/test/Dialect/Arc/lower-state.mlir b/test/Dialect/Arc/lower-state.mlir index b312bb115c8d..87f3e19cd6c6 100644 --- a/test/Dialect/Arc/lower-state.mlir +++ b/test/Dialect/Arc/lower-state.mlir @@ -366,3 +366,41 @@ hw.module @adder(in %clock : i1, in %a : i32, in %b : i32, out c : i32) { // CHECK-NEXT: %[[RESULT:.+]] = func.call @func(%6, %7) : (i32, i32) -> i32 hw.output %1 : i32 } + +// CHECK-LABEL: arc.model @InitializedStates +hw.module @InitializedStates(in %clk: !seq.clock, in %reset: i1, in %input: i42) { + +// CHECK: [[ST1:%.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state +// CHECK-NEXT: [[ST2:%.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state +// CHECK-NEXT: [[ST3:%.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state +// CHECK-NEXT: [[ST4:%.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state +// CHECK-NEXT: [[ST5:%.+]] = arc.alloc_state %arg0 : (!arc.storage) -> !arc.state + +// CHECK: arc.initial { + + %csta = hw.constant 1 : i42 + %cstb = hw.constant 10 : i42 + %cstc = hw.constant 100 : i42 + %cstd = hw.constant 1000 : i42 + %add = comb.add bin %cstb, %cstc, %csta : i42 + %mul = comb.mul bin %add, %csta : i42 + + // CHECK-NEXT: [[CSTD:%.+]] = hw.constant 1000 : i42 + // CHECK-NEXT: arc.state_write [[ST1]] = [[CSTD]] : + %0 = arc.state @DummyArc(%input) clock %clk initial (%cstd : i42) latency 1 : (i42) -> i42 + + // CHECK-DAG: [[CSTA:%.+]] = hw.constant 1 : i42 + // CHECK-DAG: [[CSTB:%.+]] = hw.constant 10 : i42 + // CHECK-DAG: [[CSTC:%.+]] = hw.constant 100 : i42 + // CHECK-DAG: [[ADD:%.+]] = comb.add bin [[CSTB]], [[CSTC]], [[CSTA]] : i42 + // CHECK-DAG: [[MUL:%.+]] = comb.mul bin [[ADD]], [[CSTA]] : i42 + + // CHECK: arc.state_write [[ST2]] = [[MUL]] : + %1 = arc.state @DummyArc(%0) clock %clk initial (%mul : i42) latency 1 : (i42) -> i42 + // CHECK-NEXT: arc.state_write [[ST3]] = [[CSTB]] : + %2 = arc.state @DummyArc(%1) clock %clk reset %reset initial (%cstb : i42) latency 1 : (i42) -> i42 + // CHECK-DAG: arc.state_write [[ST4]] = [[CSTB]] : + // CHECK-DAG: arc.state_write [[ST5]] = [[ADD]] : + %3, %4 = arc.state @DummyArc2(%2) clock %clk initial (%cstb, %add : i42, i42) latency 1 : (i42) -> (i42, i42) +// CHECK: } +} diff --git a/tools/arcilator/arcilator-header-cpp.py b/tools/arcilator/arcilator-header-cpp.py index 6a1545d950fc..113aa3094be2 100755 --- a/tools/arcilator/arcilator-header-cpp.py +++ b/tools/arcilator/arcilator-header-cpp.py @@ -63,12 +63,13 @@ class StateHierarchy: class ModelInfo: name: str numStateBytes: int + initialFnSym: str states: List[StateInfo] io: List[StateInfo] hierarchy: List[StateHierarchy] def decode(d: dict) -> "ModelInfo": - return ModelInfo(d["name"], d["numStateBytes"], + return ModelInfo(d["name"], d["numStateBytes"], d.get("initialFnSym", ""), [StateInfo.decode(d) for d in d["states"]], list(), list()) @@ -240,6 +241,8 @@ def indent(s: str, amount: int = 1): io.name = io.name + "_" print('extern "C" {') + if model.initialFnSym: + print(f"void {model.name}_initial(void* state);") print(f"void {model.name}_eval(void* state);") print('}') @@ -297,8 +300,11 @@ def indent(s: str, amount: int = 1): print(f" {model.name}View view;") print() print( - f" {model.name}() : storage({model.name}Layout::numStateBytes, 0), view(&storage[0]) {{}}" + f" {model.name}() : storage({model.name}Layout::numStateBytes, 0), view(&storage[0]) {{" ) + if model.initialFnSym: + print(f" {model.initialFnSym}(&storage[0]);") + print(" }") print(f" void eval() {{ {model.name}_eval(&storage[0]); }}") print( f" ValueChangeDump<{model.name}Layout> vcd(std::basic_ostream &os) {{"