Skip to content

Commit

Permalink
[mlir][emitc] Add 'emitc.switch' op to the dialect (#102331)
Browse files Browse the repository at this point in the history
This PR is continuation of the [previous
one](#101478). As a result, the
`emitc::SwitchOp` op was developed inspired by `scf::IndexSwitchOp`.

Main points of PR:

- Added the `emitc::SwitchOp` op  to the EmitC dialect + CppEmitter
- Corresponding tests were added
- Conversion from the SCF dialect to the EmitC dialect for the op
  • Loading branch information
EtoAndruwa authored Aug 16, 2024
1 parent 7afb51e commit 97f0ab7
Show file tree
Hide file tree
Showing 9 changed files with 1,391 additions and 20 deletions.
84 changes: 83 additions & 1 deletion mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -1188,7 +1188,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
}

def EmitC_YieldOp : EmitC_Op<"yield",
[Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp"]>]> {
[Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp", "SwitchOp"]>]> {
let summary = "Block termination operation";
let description = [{
The `emitc.yield` terminates its parent EmitC op's region, optionally yielding
Expand Down Expand Up @@ -1302,5 +1302,87 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
let assemblyFormat = "$value `[` $indices `]` attr-dict `:` functional-type(operands, results)";
}

def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"emitc::YieldOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getRegionInvocationBounds",
"getEntrySuccessorRegions"]>]> {
let summary = "Switch operation";
let description = [{
The `emitc.switch` is a control-flow operation that branches to one of
the given regions based on the values of the argument and the cases.
The operand to a switch operation is a opaque, integral or pointer
wide types.

The operation always has a "default" region and any number of case regions
denoted by integer constants. Control-flow transfers to the case region
whose constant value equals the value of the argument. If the argument does
not equal any of the case values, control-flow transfer to the "default"
region.

The operation does not return any value. Moreover, case regions must be
explicitly terminated using the `emitc.yield` operation. Default region is
yielded implicitly.

Example:

```mlir
// Example:
emitc.switch %0 : i32
case 2 {
%1 = emitc.call_opaque "func_b" () : () -> i32
emitc.yield
}
case 5 {
%2 = emitc.call_opaque "func_a" () : () -> i32
emitc.yield
}
default {
%3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
emitc.call_opaque "func2" (%3) : (f32) -> ()
}
```
```c++
// Code emitted for the operations above.
switch (v1) {
case 2: {
int32_t v2 = func_b();
break;
}
case 5: {
int32_t v3 = func_a();
break;
}
default: {
float v4 = 4.200000000e+01f;
func2(v4);
break;
}
```
}];

let arguments = (ins IntegerIndexOrOpaqueType:$arg, DenseI64ArrayAttr:$cases);
let results = (outs);
let regions = (region SizedRegion<1>:$defaultRegion,
VariadicRegion<SizedRegion<1>>:$caseRegions);

let assemblyFormat = [{
$arg `:` type($arg) attr-dict custom<SwitchCases>($cases, $caseRegions) `\n`
`` `default` $defaultRegion
}];

let extraClassDeclaration = [{
/// Get the number of cases.
unsigned getNumCases();

/// Get the default region body.
Block &getDefaultBlock();

/// Get the body of a case region.
Block &getCaseBlock(unsigned idx);
}];

let hasVerifier = 1;
}

#endif // MLIR_DIALECT_EMITC_IR_EMITC
70 changes: 55 additions & 15 deletions mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,19 @@ static void lowerYield(SmallVector<Value> &resultVariables,
rewriter.eraseOp(yield);
}

// Lower the contents of an scf::if/scf::index_switch regions to an
// emitc::if/emitc::switch region. The contents of the lowering region is
// moved into the respective lowered region, but the scf::yield is replaced not
// only with an emitc::yield, but also with a sequence of emitc::assign ops that
// set the yielded values into the result variables.
static void lowerRegion(SmallVector<Value> &resultVariables,
PatternRewriter &rewriter, Region &region,
Region &loweredRegion) {
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
Operation *terminator = loweredRegion.back().getTerminator();
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
}

LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const {
Location loc = forOp.getLoc();
Expand Down Expand Up @@ -145,18 +158,6 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
SmallVector<Value> resultVariables =
createVariablesForResults(ifOp, rewriter);

// Utility function to lower the contents of an scf::if region to an emitc::if
// region. The contents of the scf::if regions is moved into the respective
// emitc::if regions, but the scf::yield is replaced not only with an
// emitc::yield, but also with a sequence of emitc::assign ops that set the
// yielded values into the result variables.
auto lowerRegion = [&resultVariables, &rewriter](Region &region,
Region &loweredRegion) {
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
Operation *terminator = loweredRegion.back().getTerminator();
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
};

Region &thenRegion = ifOp.getThenRegion();
Region &elseRegion = ifOp.getElseRegion();

Expand All @@ -166,20 +167,59 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);

Region &loweredThenRegion = loweredIf.getThenRegion();
lowerRegion(thenRegion, loweredThenRegion);
lowerRegion(resultVariables, rewriter, thenRegion, loweredThenRegion);

if (hasElseBlock) {
Region &loweredElseRegion = loweredIf.getElseRegion();
lowerRegion(elseRegion, loweredElseRegion);
lowerRegion(resultVariables, rewriter, elseRegion, loweredElseRegion);
}

rewriter.replaceOp(ifOp, resultVariables);
return success();
}

// Lower scf::index_switch to emitc::switch, implementing result values as
// emitc::variable's updated within the case and default regions.
struct IndexSwitchOpLowering : public OpRewritePattern<IndexSwitchOp> {
using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;

LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp,
PatternRewriter &rewriter) const override;
};

LogicalResult
IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
PatternRewriter &rewriter) const {
Location loc = indexSwitchOp.getLoc();

// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the case and default regions.
SmallVector<Value> resultVariables =
createVariablesForResults(indexSwitchOp, rewriter);

auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
loc, indexSwitchOp.getArg(), indexSwitchOp.getCases(),
indexSwitchOp.getNumCases());

// Lowering all case regions.
for (auto pair : llvm::zip(indexSwitchOp.getCaseRegions(),
loweredSwitch.getCaseRegions())) {
lowerRegion(resultVariables, rewriter, std::get<0>(pair),
std::get<1>(pair));
}

// Lowering default region.
lowerRegion(resultVariables, rewriter, indexSwitchOp.getDefaultRegion(),
loweredSwitch.getDefaultRegion());

rewriter.replaceOp(indexSwitchOp, resultVariables);
return success();
}

void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
patterns.add<ForLowering>(patterns.getContext());
patterns.add<IfLowering>(patterns.getContext());
patterns.add<IndexSwitchOpLowering>(patterns.getContext());
}

void SCFToEmitCPass::runOnOperation() {
Expand All @@ -188,7 +228,7 @@ void SCFToEmitCPass::runOnOperation() {

// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
target.addIllegalOp<scf::ForOp, scf::IfOp>();
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
Expand Down
132 changes: 132 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,138 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}

//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//

/// Parse the case regions and values.
static ParseResult
parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases,
SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
SmallVector<int64_t> caseValues;
while (succeeded(parser.parseOptionalKeyword("case"))) {
int64_t value;
Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
if (parser.parseInteger(value) ||
parser.parseRegion(region, /*arguments=*/{}))
return failure();
caseValues.push_back(value);
}
cases = parser.getBuilder().getDenseI64ArrayAttr(caseValues);
return success();
}

/// Print the case regions and values.
static void printSwitchCases(OpAsmPrinter &p, Operation *op,
DenseI64ArrayAttr cases, RegionRange caseRegions) {
for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
p.printNewline();
p << "case " << value << ' ';
p.printRegion(*region, /*printEntryBlockArgs=*/false);
}
}

static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region,
const Twine &name) {
auto yield = dyn_cast<emitc::YieldOp>(region.front().back());
if (!yield)
return op.emitOpError("expected region to end with emitc.yield, but got ")
<< region.front().back().getName();

if (yield.getNumOperands() != 0) {
return (op.emitOpError("expected each region to return ")
<< "0 values, but " << name << " returns "
<< yield.getNumOperands())
.attachNote(yield.getLoc())
<< "see yield operation here";
}

return success();
}

LogicalResult emitc::SwitchOp::verify() {
if (!isIntegerIndexOrOpaqueType(getArg().getType()))
return emitOpError("unsupported type ") << getArg().getType();

if (getCases().size() != getCaseRegions().size()) {
return emitOpError("has ")
<< getCaseRegions().size() << " case regions but "
<< getCases().size() << " case values";
}

DenseSet<int64_t> valueSet;
for (int64_t value : getCases())
if (!valueSet.insert(value).second)
return emitOpError("has duplicate case value: ") << value;

if (failed(verifyRegion(*this, getDefaultRegion(), "default region")))
return failure();

for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
if (failed(verifyRegion(*this, caseRegion, "case region #" + Twine(idx))))
return failure();

return success();
}

unsigned emitc::SwitchOp::getNumCases() { return getCases().size(); }

Block &emitc::SwitchOp::getDefaultBlock() { return getDefaultRegion().front(); }

Block &emitc::SwitchOp::getCaseBlock(unsigned idx) {
assert(idx < getNumCases() && "case index out-of-bounds");
return getCaseRegions()[idx].front();
}

void SwitchOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
llvm::copy(getRegions(), std::back_inserter(successors));
}

void SwitchOp::getEntrySuccessorRegions(
ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &successors) {
FoldAdaptor adaptor(operands, *this);

// If a constant was not provided, all regions are possible successors.
auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
if (!arg) {
llvm::copy(getRegions(), std::back_inserter(successors));
return;
}

// Otherwise, try to find a case with a matching value. If not, the
// default region is the only successor.
for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
if (caseValue == arg.getInt()) {
successors.emplace_back(&caseRegion);
return;
}
}
successors.emplace_back(&getDefaultRegion());
}

void SwitchOp::getRegionInvocationBounds(
ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
if (!operandValue) {
// All regions are invoked at most once.
bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
return;
}

unsigned liveIndex = getNumRegions() - 1;
const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt());

liveIndex = iteratorToInt != getCases().end()
? std::distance(getCases().begin(), iteratorToInt)
: liveIndex;

for (unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum;
++regIndex)
bounds.emplace_back(/*lb=*/0, /*ub=*/regIndex == liveIndex);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 97f0ab7

Please sign in to comment.