Skip to content

Commit

Permalink
Squashed WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
fzi-hielscher committed Aug 8, 2024
1 parent c08ac4b commit 3815ec1
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 35 deletions.
59 changes: 48 additions & 11 deletions include/circt/Dialect/Arc/ArcOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ include "circt/Dialect/HW/HWTypes.td"
include "circt/Dialect/Seq/SeqTypes.td"
include "mlir/IR/EnumAttr.td"
include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/IR/SymbolInterfaces.td"
Expand All @@ -27,6 +28,10 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
class ArcOp<string mnemonic, list<Trait> traits = []> :
Op<ArcDialect, mnemonic, traits>;

class StateAndValueTypesMatch<string state, string value> : TypesMatchWith<
"state and value types must match", state, value,
"llvm::cast<StateType>($_self).getType()">;

def DefineOp : ArcOp<"define", [
IsolatedFromAbove,
FunctionOpInterface,
Expand Down Expand Up @@ -458,11 +463,19 @@ def PassThroughOp : ArcOp<"passthrough", [NoTerminator, NoRegionArguments]> {

def AllocStateOp : ArcOp<"alloc_state", [MemoryEffects<[MemAlloc]>]> {
let summary = "Allocate internal state";
let arguments = (ins StorageType:$storage, UnitAttr:$tap);
let arguments = (ins StorageType:$storage, UnitAttr:$tap, OptionalAttr<TypedAttrInterface>:$initial);
let results = (outs StateType:$state);
let assemblyFormat = [{
$storage (`tap` $tap^)? attr-dict `:` functional-type($storage, $state)
$storage (`tap` $tap^)? (`init` `(` $initial^ `)`)? attr-dict `:` functional-type($storage, $state)
}];

let builders = [
OpBuilder<(ins "::mlir::Type":$type, "::mlir::Value":$value, "bool":$tap),
"build($_builder, $_state, type, value, tap, {});">
];

// OptionalTypesMatchWith doesn't work with optional attributes :(
let hasVerifier = true;
}

def AllocMemoryOp : ArcOp<"alloc_memory", [MemoryEffects<[MemAlloc]>]> {
Expand Down Expand Up @@ -526,10 +539,6 @@ def StorageGetOp : ArcOp<"storage.get", [Pure]> {
// State Read/Write
//===----------------------------------------------------------------------===//

class StateAndValueTypesMatch<string state, string value> : TypesMatchWith<
"state and value types must match", state, value,
"llvm::cast<StateType>($_self).getType()">;

def StateReadOp : ArcOp<"state_read", [
MemoryEffects<[MemRead]>,
StateAndValueTypesMatch<"state", "value">
Expand Down Expand Up @@ -651,31 +660,59 @@ def TapOp : ArcOp<"tap"> {
let assemblyFormat = [{ $value attr-dict `:` type($value) }];
}

def ModelOp : ArcOp<"model", [RegionKindInterface, IsolatedFromAbove,
NoTerminator, Symbol]> {
def ModelOp : ArcOp<"model", [
RegionKindInterface, IsolatedFromAbove, Symbol,
SingleBlockImplicitTerminator<"YieldStorageOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getSuccessorRegions", "getRegionInvocationBounds"]>
]> {
let summary = "A model with stratified clocks";
let description = [{
A model with stratified clocks. The `io` optional attribute
specifies the I/O of the module associated to this model.
}];
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<ModuleType>:$io);
let regions = (region SizedRegion<1>:$body);
let regions = (region SizedRegion<1>:$body, MaxSizedRegion<1>:$initialRegion);

let assemblyFormat = [{
$sym_name `io` $io attr-dict-with-keyword $body
$sym_name `io` $io attr-dict-with-keyword $body (`initial` $initialRegion^)?
}];

let extraClassDeclaration = [{
static mlir::RegionKind getRegionKind(unsigned index) {
return mlir::RegionKind::Graph;
return index == 0 ? mlir::RegionKind::Graph : mlir::RegionKind::SSACFG ;
}
mlir::Block &getBodyBlock() { return getBody().front(); }

bool hasTrivialInitialRegion() {
if (getInitialRegion().empty())
return true;
return &getInitialRegion().front().front() == &getInitialRegion().front().back();
}

}];

let hasVerifier = 1;
}

def YieldStorageOp : ArcOp<"yield_storage",
[Terminator, Pure, HasParent<"arc::ModelOp">,
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>,
]> {
let arguments = (ins Variadic<StorageType>:$storages);
let assemblyFormat = [{ attr-dict ($storages^ `:` type($storages))? }];
let builders = [
OpBuilder<(ins), "build($_builder, $_state, ::mlir::ValueRange{} );">
];
}

def ConstantInitializeOp : ArcOp<"const_init", [MemoryEffects<[MemWrite]>, SameTypeOperands]> {
let summary = "An ugly hack.";
let arguments = (ins AnyType:$state, TypedAttrInterface:$constant);
let assemblyFormat = "` ` `(` $constant `)` `->` $state attr-dict `:` type($state)";
}

def LutOp : ArcOp<"lut", [
IsolatedFromAbove,
SingleBlockImplicitTerminator<"arc::OutputOp">,
Expand Down
44 changes: 34 additions & 10 deletions lib/Conversion/ArcToLLVM/LowerArcToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,30 +52,53 @@ static llvm::Twine evalSymbolFromModelName(StringRef modelName) {
return modelName + "_eval";
}

static llvm::Twine initSymbolFromModelName(StringRef modelName) {
return modelName + "_init";
}

namespace {

struct ModelOpLowering : public OpConversionPattern<arc::ModelOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arc::ModelOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
{
IRRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(&op.getBodyBlock());
rewriter.create<func::ReturnOp>(op.getLoc());
}
auto funcName =
auto bodyFuncName =
rewriter.getStringAttr(evalSymbolFromModelName(op.getName()));
auto funcType =
auto bodyFuncType =
rewriter.getFunctionType(op.getBody().getArgumentTypes(), {});
auto func =
rewriter.create<mlir::func::FuncOp>(op.getLoc(), funcName, funcType);
rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end());
auto bodyFunc = rewriter.create<mlir::func::FuncOp>(
op.getLoc(), bodyFuncName, bodyFuncType);
rewriter.inlineRegionBefore(op.getBody(), bodyFunc.getBody(),
bodyFunc.end());

if (!op.hasTrivialInitialRegion()) {
auto initFuncName =
rewriter.getStringAttr(initSymbolFromModelName(op.getName()));
auto initFuncType = rewriter.getFunctionType(
op.getInitialRegion().getArgumentTypes(), {});
auto initFunc = rewriter.create<mlir::func::FuncOp>(
op.getLoc(), initFuncName, initFuncType);
rewriter.inlineRegionBefore(op.getInitialRegion(), initFunc.getBody(),
initFunc.end());
}

rewriter.eraseOp(op);
return success();
}
};

struct YieldStorageOpLowering
: public OpConversionPattern<arc::YieldStorageOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arc::YieldStorageOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op);
return success();
}
};

struct AllocStorageOpLowering
: public OpConversionPattern<arc::AllocStorageOp> {
using OpConversionPattern::OpConversionPattern;
Expand Down Expand Up @@ -623,6 +646,7 @@ void LowerArcToLLVMPass::runOnOperation() {
MemoryReadOpLowering,
MemoryWriteOpLowering,
ModelOpLowering,
YieldStorageOpLowering,
ReplaceOpWithInputPattern<seq::ToClockOp>,
ReplaceOpWithInputPattern<seq::FromClockOp>,
SeqConstClockLowering,
Expand Down
51 changes: 46 additions & 5 deletions lib/Dialect/Arc/ArcOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,18 @@ LogicalResult StateOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// AllocStateOp
//===----------------------------------------------------------------------===//

LogicalResult AllocStateOp::verify() {
if (auto init = getInitial())
if (init->getType() != getType().getType())
return emitOpError(
"type of initial value must match inner type of state");
return success();
}

//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -296,17 +308,46 @@ void RootOutputOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
//===----------------------------------------------------------------------===//

LogicalResult ModelOp::verify() {
if (getBodyBlock().getArguments().size() != 1)
return emitOpError("must have exactly one argument");
if (auto type = getBodyBlock().getArgument(0).getType();
!isa<StorageType>(type))
return emitOpError("argument must be of storage type");
if (llvm::any_of(getBodyBlock().getArguments(), [](auto arg) -> bool {
return !isa<StorageType>(arg.getType());
}))
return emitOpError("arguments must be of storage type");

for (const hw::ModulePort &port : getIo().getPorts())
if (port.dir == hw::ModulePort::Direction::InOut)
return emitOpError("inout ports are not supported");
return success();
}

void ModelOp::getRegionInvocationBounds(
ArrayRef<Attribute> operands,
SmallVectorImpl<InvocationBounds> &invocationBounds) {
invocationBounds.assign(2, {getInitialRegion().empty() ? 0U : 1U, 1U});
}

void ModelOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
if (point.isParent()) {
if (!getInitialRegion().empty())
regions.emplace_back(&getInitialRegion());
else
regions.emplace_back(&getBody());
return;
}

if (point.getRegionOrNull() == &getBody()) {
regions.emplace_back(RegionSuccessor());
return;
}

regions.emplace_back(&getBody(), getBody().getArguments());
}

MutableOperandRange
YieldStorageOp::getMutableSuccessorOperands(RegionBranchPoint point) {
return getStoragesMutable();
}

//===----------------------------------------------------------------------===//
// LutOp
//===----------------------------------------------------------------------===//
Expand Down
91 changes: 90 additions & 1 deletion lib/Dialect/Arc/Transforms/AllocateState.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "circt/Dialect/Arc/ArcOps.h"
#include "circt/Dialect/Arc/ArcPasses.h"
#include "circt/Dialect/HW/HWDialect.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -37,17 +38,40 @@ struct AllocateStatePass
void runOnOperation() override;
void allocateBlock(Block *block);
void allocateOps(Value storage, Block *block, ArrayRef<Operation *> ops);
void prepareInitialRegion();
void initializeStorage(OpBuilder &builder, AllocStateOp &allocOp,
IntegerAttr offset);

private:
hw::HWDialect *hwDialect;
};
} // namespace

void AllocateStatePass::runOnOperation() {

ModelOp modelOp = getOperation();
LLVM_DEBUG(llvm::dbgs() << "Allocating state in `" << modelOp.getName()
<< "`\n");

hwDialect = getContext().getLoadedDialect<hw::HWDialect>();

prepareInitialRegion();

// Walk the blocks from innermost to outermost and group all state allocations
// in that block in one larger allocation.
modelOp.walk([&](Block *block) { allocateBlock(block); });
modelOp.getBody().walk([&](Block *block) { allocateBlock(block); });

// Update initial storage types
auto bodyArgs = modelOp.getBody().getArguments();
auto initArgs = modelOp.getInitialRegion().getArguments();
assert(bodyArgs.size() == initArgs.size());

for (auto [arg, yield] : llvm::zip(bodyArgs, initArgs)) {
assert(isa<StorageType>(arg.getType()) &&
isa<StorageType>(yield.getType()));
// (chuckles) I'm in danger.
yield.setType(arg.getType());
}
}

void AllocateStatePass::allocateBlock(Block *block) {
Expand All @@ -68,6 +92,67 @@ void AllocateStatePass::allocateBlock(Block *block) {
allocateOps(storage, block, ops);
}

void AllocateStatePass::prepareInitialRegion() {
auto &initRegion = getOperation().getInitialRegion();

ImplicitLocOpBuilder initBuilder(getOperation().getLoc(), &getContext());

Block *initBlock;
YieldStorageOp yield;
if (initRegion.empty()) {
initBlock = &initRegion.emplaceBlock();
initBuilder.setInsertionPointToEnd(initBlock);
yield = initBuilder.create<YieldStorageOp>(ValueRange{});
} else {
initBlock = &initRegion.front();
assert(initBlock->getArguments().empty());
yield = cast<YieldStorageOp>(initBlock->getTerminator());
}

initRegion.addArguments(getOperation().getBody().getArgumentTypes(),
getOperation().getLoc());
yield->setOperands(initRegion.getArguments());
}

void AllocateStatePass::initializeStorage(OpBuilder &builder,
AllocStateOp &allocOp,
IntegerAttr offset) {
if (!allocOp.getInitial())
return;

assert(isa<ModelOp>(allocOp->getParentOp()) &&
"Unsupported nested allocation");
auto storageArg = dyn_cast<BlockArgument>(allocOp.getStorage());
assert(!!storageArg && "Unknown storage value");
auto argIdx = storageArg.getArgNumber();

OpBuilder initBuilder(&getContext());
Block *initBlock = &getOperation().getInitialRegion().front();
Operation *initCstOp;
initBuilder.setInsertionPoint(initBlock->getTerminator());

if (auto intAttr = dyn_cast<IntegerAttr>(*allocOp.getInitial())) {
initCstOp = hwDialect->materializeConstant(
initBuilder, intAttr, intAttr.getType(), allocOp.getLoc());
} else {
auto initial = *allocOp.getInitial();
auto *opDialect = &initial.getType().getDialect();
initCstOp = opDialect->materializeConstant(
initBuilder, initial, initial.getType(), allocOp.getLoc());
}
assert(!!initCstOp && initCstOp->getNumResults() == 1 &&
"Failed to materialize single constatnt value");

auto getOp = initBuilder.create<StorageGetOp>(
allocOp.getLoc(), StateType::get(initCstOp->getResult(0).getType()),
initBlock->getArgument(argIdx), offset);

initBuilder.create<StateWriteOp>(allocOp.getLoc(), getOp.getResult(),
initCstOp->getResult(0), Value());

allocOp.removeInitialAttr();
}

void AllocateStatePass::allocateOps(Value storage, Block *block,
ArrayRef<Operation *> ops) {
SmallVector<std::tuple<Value, Value, IntegerAttr>> gettersToCreate;
Expand All @@ -85,6 +170,8 @@ void AllocateStatePass::allocateOps(Value storage, Block *block,

// Allocate storage for the operations.
OpBuilder builder(block->getParentOp());
OpBuilder initBuilder(&getContext());

for (auto *op : ops) {
if (isa<AllocStateOp, RootInputOp, RootOutputOp>(op)) {
auto result = op->getResult(0);
Expand All @@ -93,6 +180,8 @@ void AllocateStatePass::allocateOps(Value storage, Block *block,
auto offset = builder.getI32IntegerAttr(allocBytes(numBytes));
op->setAttr("offset", offset);
gettersToCreate.emplace_back(result, storage, offset);
if (auto allocOp = dyn_cast<AllocStateOp>(op))
initializeStorage(initBuilder, allocOp, offset);
continue;
}

Expand Down
Loading

0 comments on commit 3815ec1

Please sign in to comment.