Skip to content

Commit

Permalink
Support scf.if Op Lowering to Calyx (#6256)
Browse files Browse the repository at this point in the history
* support lowering scf if op and add a corresponding test
  • Loading branch information
jiahanxie353 authored Jul 31, 2024
1 parent 05136f0 commit 3c12682
Show file tree
Hide file tree
Showing 2 changed files with 284 additions and 19 deletions.
234 changes: 215 additions & 19 deletions lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ class ScfForOp : public calyx::RepeatOpInterface<scf::ForOp> {
// Lowering state classes
//===----------------------------------------------------------------------===//

struct IfScheduleable {
scf::IfOp ifOp;
};

struct WhileScheduleable {
/// While operation to schedule.
ScfWhileOp whileOp;
Expand All @@ -115,8 +119,63 @@ struct CallScheduleable {
};

/// A variant of types representing scheduleable operations.
using Scheduleable = std::variant<calyx::GroupOp, WhileScheduleable,
ForScheduleable, CallScheduleable>;
using Scheduleable =
std::variant<calyx::GroupOp, WhileScheduleable, ForScheduleable,
IfScheduleable, CallScheduleable>;

class IfLoweringStateInterface {
public:
void setThenGroup(scf::IfOp op, calyx::GroupOp group) {
Operation *operation = op.getOperation();
assert(thenGroup.count(operation) == 0 &&
"A then group was already set for this scf::IfOp!\n");
thenGroup[operation] = group;
}

calyx::GroupOp getThenGroup(scf::IfOp op) {
auto it = thenGroup.find(op.getOperation());
assert(it != thenGroup.end() &&
"No then group was set for this scf::IfOp!\n");
return it->second;
}

void setElseGroup(scf::IfOp op, calyx::GroupOp group) {
Operation *operation = op.getOperation();
assert(elseGroup.count(operation) == 0 &&
"An else group was already set for this scf::IfOp!\n");
elseGroup[operation] = group;
}

calyx::GroupOp getElseGroup(scf::IfOp op) {
auto it = elseGroup.find(op.getOperation());
assert(it != elseGroup.end() &&
"No else group was set for this scf::IfOp!\n");
return it->second;
}

void setResultRegs(scf::IfOp op, calyx::RegisterOp reg, unsigned idx) {
assert(resultRegs[op.getOperation()].count(idx) == 0 &&
"A register was already registered for the given yield result.\n");
assert(idx < op->getNumOperands());
resultRegs[op.getOperation()][idx] = reg;
}

const DenseMap<unsigned, calyx::RegisterOp> &getResultRegs(scf::IfOp op) {
return resultRegs[op.getOperation()];
}

calyx::RegisterOp getResultRegs(scf::IfOp op, unsigned idx) {
auto regs = getResultRegs(op);
auto it = regs.find(idx);
assert(it != regs.end() && "resultReg not found");
return it->second;
}

private:
DenseMap<Operation *, calyx::GroupOp> thenGroup;
DenseMap<Operation *, calyx::GroupOp> elseGroup;
DenseMap<Operation *, DenseMap<unsigned, calyx::RegisterOp>> resultRegs;
};

class WhileLoopLoweringStateInterface
: calyx::LoopLoweringStateInterface<ScfWhileOp> {
Expand Down Expand Up @@ -187,6 +246,7 @@ class ForLoopLoweringStateInterface
class ComponentLoweringState : public calyx::ComponentLoweringStateInterface,
public WhileLoopLoweringStateInterface,
public ForLoopLoweringStateInterface,
public IfLoweringStateInterface,
public calyx::SchedulerInterface<Scheduleable> {
public:
ComponentLoweringState(calyx::ComponentOp component)
Expand All @@ -213,7 +273,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
TypeSwitch<mlir::Operation *, bool>(_op)
.template Case<arith::ConstantOp, ReturnOp, BranchOpInterface,
/// SCF
scf::YieldOp, scf::WhileOp, scf::ForOp,
scf::YieldOp, scf::WhileOp, scf::ForOp, scf::IfOp,
/// memref
memref::AllocOp, memref::AllocaOp, memref::LoadOp,
memref::StoreOp,
Expand Down Expand Up @@ -272,6 +332,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
LogicalResult buildOp(PatternRewriter &rewriter, memref::StoreOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, scf::WhileOp whileOp) const;
LogicalResult buildOp(PatternRewriter &rewriter, scf::ForOp forOp) const;
LogicalResult buildOp(PatternRewriter &rewriter, scf::IfOp ifOp) const;
LogicalResult buildOp(PatternRewriter &rewriter, CallOp callOp) const;

/// buildLibraryOp will build a TCalyxLibOp inside a TGroupOp based on the
Expand Down Expand Up @@ -720,22 +781,53 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
"loops. Run --scf-for-to-while before running --scf-to-calyx.";
}

auto whileOp = dyn_cast<scf::WhileOp>(yieldOp->getParentOp());
if (!whileOp) {
return yieldOp.getOperation()->emitError()
<< "Currently only support yield operations inside for and while "
"loops.";
}
ScfWhileOp whileOpInterface(whileOp);

auto assignGroup =
getState<ComponentLoweringState>().buildWhileLoopIterArgAssignments(
rewriter, whileOpInterface,
getState<ComponentLoweringState>().getComponentOp(),
getState<ComponentLoweringState>().getUniqueName(whileOp) + "_latch",
yieldOp->getOpOperands());
getState<ComponentLoweringState>().setWhileLoopLatchGroup(whileOpInterface,
assignGroup);
if (auto whileOp = dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
ScfWhileOp whileOpInterface(whileOp);

auto assignGroup =
getState<ComponentLoweringState>().buildWhileLoopIterArgAssignments(
rewriter, whileOpInterface,
getState<ComponentLoweringState>().getComponentOp(),
getState<ComponentLoweringState>().getUniqueName(whileOp) +
"_latch",
yieldOp->getOpOperands());
getState<ComponentLoweringState>().setWhileLoopLatchGroup(whileOpInterface,
assignGroup);
return success();
}

if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
auto resultRegs = getState<ComponentLoweringState>().getResultRegs(ifOp);

if (yieldOp->getParentRegion() == &ifOp.getThenRegion()) {
auto thenGroup = getState<ComponentLoweringState>().getThenGroup(ifOp);
for (auto op : enumerate(yieldOp.getOperands())) {
auto resultReg =
getState<ComponentLoweringState>().getResultRegs(ifOp, op.index());
buildAssignmentsForRegisterWrite(
rewriter, thenGroup,
getState<ComponentLoweringState>().getComponentOp(), resultReg,
op.value());
getState<ComponentLoweringState>().registerEvaluatingGroup(
ifOp.getResult(op.index()), thenGroup);
}
}

if (!ifOp.getElseRegion().empty() &&
(yieldOp->getParentRegion() == &ifOp.getElseRegion())) {
auto elseGroup = getState<ComponentLoweringState>().getElseGroup(ifOp);
for (auto op : enumerate(yieldOp.getOperands())) {
auto resultReg =
getState<ComponentLoweringState>().getResultRegs(ifOp, op.index());
buildAssignmentsForRegisterWrite(
rewriter, elseGroup,
getState<ComponentLoweringState>().getComponentOp(), resultReg,
op.value());
getState<ComponentLoweringState>().registerEvaluatingGroup(
ifOp.getResult(op.index()), elseGroup);
}
}
}
return success();
}

Expand Down Expand Up @@ -945,6 +1037,13 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
return success();
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
scf::IfOp ifOp) const {
getState<ComponentLoweringState>().addBlockScheduleable(
ifOp.getOperation()->getBlock(), IfScheduleable{ifOp});
return success();
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
CallOp callOp) const {
std::string instanceName = calyx::getInstanceName(callOp);
Expand Down Expand Up @@ -1291,6 +1390,51 @@ class BuildForGroups : public calyx::FuncOpPartialLoweringPattern {
}
};

class BuildIfGroups : public calyx::FuncOpPartialLoweringPattern {
using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;

LogicalResult
partiallyLowerFuncToComp(FuncOp funcOp,
PatternRewriter &rewriter) const override {
LogicalResult res = success();
funcOp.walk([&](Operation *op) {
if (!isa<scf::IfOp>(op))
return WalkResult::advance();

auto scfIfOp = cast<scf::IfOp>(op);

calyx::ComponentOp componentOp =
getState<ComponentLoweringState>().getComponentOp();

std::string thenGroupName =
getState<ComponentLoweringState>().getUniqueName("then_br");
auto thenGroupOp = calyx::createGroup<calyx::GroupOp>(
rewriter, componentOp, scfIfOp.getLoc(), thenGroupName);
getState<ComponentLoweringState>().setThenGroup(scfIfOp, thenGroupOp);

if (!scfIfOp.getElseRegion().empty()) {
std::string elseGroupName =
getState<ComponentLoweringState>().getUniqueName("else_br");
auto elseGroupOp = calyx::createGroup<calyx::GroupOp>(
rewriter, componentOp, scfIfOp.getLoc(), elseGroupName);
getState<ComponentLoweringState>().setElseGroup(scfIfOp, elseGroupOp);
}

for (auto ifOpRes : scfIfOp.getResults()) {
auto reg = createRegister(
scfIfOp.getLoc(), rewriter, getComponent(),
ifOpRes.getType().getIntOrFloatBitWidth(),
getState<ComponentLoweringState>().getUniqueName("if_res"));
getState<ComponentLoweringState>().setResultRegs(
scfIfOp, reg, ifOpRes.getResultNumber());
}

return WalkResult::advance();
});
return res;
}
};

/// Builds a control schedule by traversing the CFG of the function and
/// associating this with the previously created groups.
/// For simplicity, the generated control flow is expanded for all possible
Expand Down Expand Up @@ -1384,6 +1528,50 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern {
forLatchGroup.getName());
if (res.failed())
return res;
} else if (auto *ifSchedPtr = std::get_if<IfScheduleable>(&group);
ifSchedPtr) {
auto ifOp = ifSchedPtr->ifOp;

Location loc = ifOp->getLoc();

auto cond = ifOp.getCondition();
auto condGroup = getState<ComponentLoweringState>()
.getEvaluatingGroup<calyx::CombGroupOp>(cond);

auto symbolAttr = FlatSymbolRefAttr::get(
StringAttr::get(getContext(), condGroup.getSymName()));

bool initElse = !ifOp.getElseRegion().empty();
auto ifCtrlOp = rewriter.create<calyx::IfOp>(
loc, cond, symbolAttr, /*initializeElseBody=*/initElse);

rewriter.setInsertionPointToEnd(ifCtrlOp.getBodyBlock());

auto thenSeqOp =
rewriter.create<calyx::SeqOp>(ifOp.getThenRegion().getLoc());
auto *thenSeqOpBlock = thenSeqOp.getBodyBlock();

rewriter.setInsertionPointToEnd(thenSeqOpBlock);

calyx::GroupOp thenGroup =
getState<ComponentLoweringState>().getThenGroup(ifOp);
rewriter.create<calyx::EnableOp>(thenGroup.getLoc(),
thenGroup.getName());

if (!ifOp.getElseRegion().empty()) {
rewriter.setInsertionPointToEnd(ifCtrlOp.getElseBody());

auto elseSeqOp =
rewriter.create<calyx::SeqOp>(ifOp.getElseRegion().getLoc());
auto *elseSeqOpBlock = elseSeqOp.getBodyBlock();

rewriter.setInsertionPointToEnd(elseSeqOpBlock);

calyx::GroupOp elseGroup =
getState<ComponentLoweringState>().getElseGroup(ifOp);
rewriter.create<calyx::EnableOp>(elseGroup.getLoc(),
elseGroup.getName());
}
} else if (auto *callSchedPtr = std::get_if<CallScheduleable>(&group)) {
auto instanceOp = callSchedPtr->instanceOp;
OpBuilder::InsertionGuard g(rewriter);
Expand Down Expand Up @@ -1540,6 +1728,12 @@ class LateSSAReplacement : public calyx::FuncOpPartialLoweringPattern {

LogicalResult partiallyLowerFuncToComp(FuncOp funcOp,
PatternRewriter &) const override {
funcOp.walk([&](scf::IfOp op) {
for (auto res : getState<ComponentLoweringState>().getResultRegs(op))
op.getOperation()->getResults()[res.first].replaceAllUsesWith(
res.second.getOut());
});

funcOp.walk([&](scf::WhileOp op) {
/// The yielded values returned from the while op will be present in the
/// iterargs registers post execution of the loop.
Expand Down Expand Up @@ -1790,6 +1984,8 @@ void SCFToCalyxPass::runOnOperation() {
addOncePattern<BuildForGroups>(loweringPatterns, patternState, funcMap,
*loweringState);

addOncePattern<BuildIfGroups>(loweringPatterns, patternState, funcMap,
*loweringState);
/// This pattern converts operations within basic blocks to Calyx library
/// operators. Combinational operations are assigned inside a
/// calyx::CombGroupOp, and sequential inside calyx::GroupOps.
Expand Down
69 changes: 69 additions & 0 deletions test/Conversion/SCFToCalyx/convert_controlflow.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -572,3 +572,72 @@ module {
return
}
}

// -----

// Test if op with else branch.

module {
// CHECK-LABEL: calyx.component @main(
// CHECK-SAME: %[[VAL_0:in0]]: i32,
// CHECK-SAME: %[[VAL_1:in1]]: i32,
// CHECK-SAME: %[[VAL_2:.*]]: i1 {clk},
// CHECK-SAME: %[[VAL_3:.*]]: i1 {reset},
// CHECK-SAME: %[[VAL_4:.*]]: i1 {go}) -> (
// CHECK-SAME: %[[VAL_5:out0]]: i32,
// CHECK-SAME: %[[VAL_6:.*]]: i1 {done}) {
// CHECK: %[[VAL_7:.*]] = hw.constant true
// CHECK: %[[VAL_8:.*]], %[[VAL_9:.*]], %[[VAL_10:.*]] = calyx.std_add @std_add_0 : i32, i32, i32
// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = calyx.std_slt @std_slt_0 : i32, i32, i1
// CHECK: %[[VAL_14:.*]], %[[VAL_15:.*]], %[[VAL_16:.*]], %[[VAL_17:.*]], %[[VAL_18:.*]], %[[VAL_19:.*]] = calyx.register @if_res_0_reg : i32, i1, i1, i1, i32, i1
// CHECK: %[[VAL_20:.*]], %[[VAL_21:.*]], %[[VAL_22:.*]], %[[VAL_23:.*]], %[[VAL_24:.*]], %[[VAL_25:.*]] = calyx.register @ret_arg0_reg : i32, i1, i1, i1, i32, i1
// CHECK: calyx.wires {
// CHECK: calyx.assign %[[VAL_5]] = %[[VAL_24]] : i32
// CHECK: calyx.group @then_br_0 {
// CHECK: calyx.assign %[[VAL_14]] = %[[VAL_10]] : i32
// CHECK: calyx.assign %[[VAL_15]] = %[[VAL_7]] : i1
// CHECK: calyx.assign %[[VAL_8]] = %[[VAL_0]] : i32
// CHECK: calyx.assign %[[VAL_9]] = %[[VAL_1]] : i32
// CHECK: calyx.group_done %[[VAL_19]] : i1
// CHECK: }
// CHECK: calyx.group @else_br_0 {
// CHECK: calyx.assign %[[VAL_14]] = %[[VAL_1]] : i32
// CHECK: calyx.assign %[[VAL_15]] = %[[VAL_7]] : i1
// CHECK: calyx.group_done %[[VAL_19]] : i1
// CHECK: }
// CHECK: calyx.comb_group @bb0_0 {
// CHECK: calyx.assign %[[VAL_11]] = %[[VAL_0]] : i32
// CHECK: calyx.assign %[[VAL_12]] = %[[VAL_1]] : i32
// CHECK: }
// CHECK: calyx.group @ret_assign_0 {
// CHECK: calyx.assign %[[VAL_20]] = %[[VAL_18]] : i32
// CHECK: calyx.assign %[[VAL_21]] = %[[VAL_7]] : i1
// CHECK: calyx.group_done %[[VAL_25]] : i1
// CHECK: }
// CHECK: }
// CHECK: calyx.control {
// CHECK: calyx.seq {
// CHECK: calyx.if %[[VAL_13]] with @bb0_0 {
// CHECK: calyx.seq {
// CHECK: calyx.enable @then_br_0
// CHECK: }
// CHECK: } else {
// CHECK: calyx.seq {
// CHECK: calyx.enable @else_br_0
// CHECK: }
// CHECK: }
// CHECK: calyx.enable @ret_assign_0
// CHECK: }
// CHECK: }
// CHECK: } {toplevel}
func.func @main(%arg0 : i32, %arg1 : i32) -> i32 {
%0 = arith.cmpi slt, %arg0, %arg1 : i32
%1 = scf.if %0 -> i32 {
%3 = arith.addi %arg0, %arg1 : i32
scf.yield %3 : i32
} else {
scf.yield %arg1 : i32
}
return %1 : i32
}
}

0 comments on commit 3c12682

Please sign in to comment.