Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][emitc] Add 'emitc.switch' op to the dialect #102331

Merged
merged 25 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
11a57c3
[mlir][emitc] Add 'emitc.switch' op to the dialect
EtoAndruwa Aug 7, 2024
82b8c25
[mlir][emitc] Add support of opaque type and all integral types
EtoAndruwa Aug 8, 2024
89b6204
[mlir][emict] Fix description of the switch op
EtoAndruwa Aug 8, 2024
e99904d
[mlir][emitc] NFC
EtoAndruwa Aug 9, 2024
5c44e22
[mlir][emitc] Add test cases for '-declare-variables-at-top'
EtoAndruwa Aug 9, 2024
246e9f7
[mlir][emitc] Fix 'isSwitchOperandType' function
EtoAndruwa Aug 9, 2024
6f53683
[mlir][emitc] NFC
EtoAndruwa Aug 9, 2024
9990307
[mlir][emitc] Add example of emitted code to the op's description
EtoAndruwa Aug 9, 2024
b9793a6
[mlir][emitc] NFC
EtoAndruwa Aug 9, 2024
07f8db0
[mlir][emitc] NFC. Fix 'test_misplaced_yield' test case
EtoAndruwa Aug 9, 2024
e7d1c50
[mlir][emitc] Add conversion from 'scf::index_switch' to 'emitc::switch'
EtoAndruwa Aug 9, 2024
df8e92b
[mlir][emitc] Add test case for the index type
EtoAndruwa Aug 9, 2024
dfacff1
[mlir][emitc] NFC. Fix op's example in the description
EtoAndruwa Aug 9, 2024
1188105
[emitc][mlir] Refactor 'lowerRegion' lambda function
EtoAndruwa Aug 12, 2024
125466a
[mlir][emitc] Add support of pointer wide types
EtoAndruwa Aug 12, 2024
67c45e5
[mlir][emitc] NFC
EtoAndruwa Aug 12, 2024
3704607
[mlir][emitc] Delete unused function 'isSwitchOperandType'
EtoAndruwa Aug 12, 2024
64e58a8
[mlir][emitc] Fix missing arguments for the 'lowerRegion' call
EtoAndruwa Aug 12, 2024
c9a9744
[mlir][emitc] NFC. Delete parentheses aroound a case value
EtoAndruwa Aug 12, 2024
626ed80
[mlir][emitc] NFC. Tests adjusted for updates
EtoAndruwa Aug 12, 2024
453d127
[mlir][emitc] Fix op's description
EtoAndruwa Aug 12, 2024
f6ba7a9
[mlir][emitc] Moved some test from 'invalid' to 'invalid_ops'
EtoAndruwa Aug 12, 2024
3c0971a
[mlir][emitc] Switching to custom assembly in tablegen similar to the…
EtoAndruwa Aug 13, 2024
33a8667
[mlir][emitc] NFC
EtoAndruwa Aug 14, 2024
408f667
[mlir][emitc] NFC
EtoAndruwa Aug 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 82 additions & 1 deletion mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -1188,7 +1188,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
}

def EmitC_YieldOp : EmitC_Op<"yield",
[Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp"]>]> {
[Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp", "SwitchOp"]>]> {
EtoAndruwa marked this conversation as resolved.
Show resolved Hide resolved
let summary = "Block termination operation";
let description = [{
The `emitc.yield` terminates its parent EmitC op's region, optionally yielding
Expand Down Expand Up @@ -1302,5 +1302,86 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
let assemblyFormat = "$value `[` $indices `]` attr-dict `:` functional-type(operands, results)";
}

def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"emitc::YieldOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getRegionInvocationBounds",
"getEntrySuccessorRegions"]>]> {
let summary = "Switch operation";
let description = [{
The `emitc.switch` is a control-flow operation that branches to one of
the given regions based on the values of the argument and the cases.
The operand to a switch operation is a opaque, integral or pointer
wide types.

The operation always has a "default" region and any number of case regions
denoted by integer constants. Control-flow transfers to the case region
whose constant value equals the value of the argument. If the argument does
not equal any of the case values, control-flow transfer to the "default"
region.

The operation does not return any value. Moreover, case regions must be
explicitly terminated using the `emitc.yield` operation. Default region is
yielded implicitly.

Example:
EtoAndruwa marked this conversation as resolved.
Show resolved Hide resolved

```mlir
// Example:
emitc.switch %0 : i32
case 2: {
%1 = emitc.call_opaque "func_b" () : () -> i32
emitc.yield
}
case 5: {
%2 = emitc.call_opaque "func_a" () : () -> i32
emitc.yield
}
default: {
EtoAndruwa marked this conversation as resolved.
Show resolved Hide resolved
%3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
emitc.call_opaque "func2" (%3) : (f32) -> ()
}

// Output:
EtoAndruwa marked this conversation as resolved.
Show resolved Hide resolved
switch (v1) {
case (2): {
int32_t v2 = func_b();
break;
}
case (5): {
EtoAndruwa marked this conversation as resolved.
Show resolved Hide resolved
int32_t v3 = func_a();
break;
}
default: {
float v4 = 4.200000000e+01f;
func2(v4);
break;
}
```
}];

let arguments = (ins IntegerIndexOrOpaqueType:$arg, DenseI64ArrayAttr:$cases);
let results = (outs);
let regions = (region SizedRegion<1>:$defaultRegion,
VariadicRegion<SizedRegion<1>>:$caseRegions);

let assemblyFormat = [{
$arg `:` type($arg) attr-dict custom<SwitchCases>($cases, $caseRegions) `\n`
`` `default` $defaultRegion
}];

let extraClassDeclaration = [{
/// Get the number of cases.
unsigned getNumCases();

/// Get the default region body.
Block &getDefaultBlock();

/// Get the body of a case region.
Block &getCaseBlock(unsigned idx);
}];

let hasVerifier = 1;
}

#endif // MLIR_DIALECT_EMITC_IR_EMITC
70 changes: 55 additions & 15 deletions mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,19 @@ static void lowerYield(SmallVector<Value> &resultVariables,
rewriter.eraseOp(yield);
}

// Lower the contents of an scf::if/scf::index_switch regions to an
// emitc::if/emitc::switch regions. The contents of the lowering region is
EtoAndruwa marked this conversation as resolved.
Show resolved Hide resolved
// moved into the respective lowered region, but the scf::yield is replaced not
// only with an emitc::yield, but also with a sequence of emitc::assign ops that
// set the yielded values into the result variables.
static void lowerRegion(SmallVector<Value> &resultVariables,
PatternRewriter &rewriter, Region &region,
Region &loweredRegion) {
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
Operation *terminator = loweredRegion.back().getTerminator();
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
}

LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const {
Location loc = forOp.getLoc();
Expand Down Expand Up @@ -145,18 +158,6 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
SmallVector<Value> resultVariables =
createVariablesForResults(ifOp, rewriter);

// Utility function to lower the contents of an scf::if region to an emitc::if
// region. The contents of the scf::if regions is moved into the respective
// emitc::if regions, but the scf::yield is replaced not only with an
// emitc::yield, but also with a sequence of emitc::assign ops that set the
// yielded values into the result variables.
auto lowerRegion = [&resultVariables, &rewriter](Region &region,
Region &loweredRegion) {
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
Operation *terminator = loweredRegion.back().getTerminator();
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
};

Region &thenRegion = ifOp.getThenRegion();
Region &elseRegion = ifOp.getElseRegion();

Expand All @@ -166,20 +167,59 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);

Region &loweredThenRegion = loweredIf.getThenRegion();
lowerRegion(thenRegion, loweredThenRegion);
lowerRegion(resultVariables, rewriter, thenRegion, loweredThenRegion);

if (hasElseBlock) {
Region &loweredElseRegion = loweredIf.getElseRegion();
lowerRegion(elseRegion, loweredElseRegion);
lowerRegion(resultVariables, rewriter, elseRegion, loweredElseRegion);
}

rewriter.replaceOp(ifOp, resultVariables);
return success();
}

// Lower scf::index_switch to emitc::switch, implementing result values as
// emitc::variable's updated within the case and default regions.
struct IndexSwitchOpLowering : public OpRewritePattern<IndexSwitchOp> {
using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;

LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp,
PatternRewriter &rewriter) const override;
};

LogicalResult
IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
PatternRewriter &rewriter) const {
Location loc = indexSwitchOp.getLoc();

// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the case and default regions.
SmallVector<Value> resultVariables =
createVariablesForResults(indexSwitchOp, rewriter);

auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
loc, indexSwitchOp.getArg(), indexSwitchOp.getCases(),
indexSwitchOp.getNumCases());

// Lowering all case regions.
for (auto pair : llvm::zip(indexSwitchOp.getCaseRegions(),
loweredSwitch.getCaseRegions())) {
lowerRegion(resultVariables, rewriter, std::get<0>(pair),
std::get<1>(pair));
}

// Lowering default region.
lowerRegion(resultVariables, rewriter, indexSwitchOp.getDefaultRegion(),
loweredSwitch.getDefaultRegion());

rewriter.replaceOp(indexSwitchOp, resultVariables);
return success();
}

void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
patterns.add<ForLowering>(patterns.getContext());
patterns.add<IfLowering>(patterns.getContext());
patterns.add<IndexSwitchOpLowering>(patterns.getContext());
}

void SCFToEmitCPass::runOnOperation() {
Expand All @@ -188,7 +228,7 @@ void SCFToEmitCPass::runOnOperation() {

// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
target.addIllegalOp<scf::ForOp, scf::IfOp>();
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
Expand Down
132 changes: 132 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,138 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}

//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//

/// Parse the case regions and values.
static ParseResult
parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases,
SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
SmallVector<int64_t> caseValues;
while (succeeded(parser.parseOptionalKeyword("case"))) {
int64_t value;
Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
if (parser.parseInteger(value) ||
parser.parseRegion(region, /*arguments=*/{}))
return failure();
caseValues.push_back(value);
}
cases = parser.getBuilder().getDenseI64ArrayAttr(caseValues);
return success();
}

/// Print the case regions and values.
static void printSwitchCases(OpAsmPrinter &p, Operation *op,
DenseI64ArrayAttr cases, RegionRange caseRegions) {
for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
p.printNewline();
p << "case " << value << ' ';
p.printRegion(*region, /*printEntryBlockArgs=*/false);
}
}

static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region,
const Twine &name) {
auto yield = dyn_cast<emitc::YieldOp>(region.front().back());
if (!yield)
return op.emitOpError("expected region to end with emitc.yield, but got ")
<< region.front().back().getName();

if (yield.getNumOperands() != 0) {
return (op.emitOpError("expected each region to return ")
<< "0 values, but " << name << " returns "
<< yield.getNumOperands())
.attachNote(yield.getLoc())
<< "see yield operation here";
}

return success();
EtoAndruwa marked this conversation as resolved.
Show resolved Hide resolved
}

LogicalResult emitc::SwitchOp::verify() {
if (!isIntegerIndexOrOpaqueType(getArg().getType()))
return emitOpError("unsupported type ") << getArg().getType();

if (getCases().size() != getCaseRegions().size()) {
return emitOpError("has ")
<< getCaseRegions().size() << " case regions but "
<< getCases().size() << " case values";
}

DenseSet<int64_t> valueSet;
for (int64_t value : getCases())
if (!valueSet.insert(value).second)
return emitOpError("has duplicate case value: ") << value;

if (failed(verifyRegion(*this, getDefaultRegion(), "default region")))
return failure();

for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
if (failed(verifyRegion(*this, caseRegion, "case region #" + Twine(idx))))
return failure();

return success();
}

unsigned emitc::SwitchOp::getNumCases() { return getCases().size(); }

Block &emitc::SwitchOp::getDefaultBlock() { return getDefaultRegion().front(); }

Block &emitc::SwitchOp::getCaseBlock(unsigned idx) {
assert(idx < getNumCases() && "case index out-of-bounds");
return getCaseRegions()[idx].front();
}

void SwitchOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
llvm::copy(getRegions(), std::back_inserter(successors));
}

void SwitchOp::getEntrySuccessorRegions(
ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &successors) {
FoldAdaptor adaptor(operands, *this);

// If a constant was not provided, all regions are possible successors.
auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
if (!arg) {
llvm::copy(getRegions(), std::back_inserter(successors));
return;
}

// Otherwise, try to find a case with a matching value. If not, the
// default region is the only successor.
for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
if (caseValue == arg.getInt()) {
successors.emplace_back(&caseRegion);
return;
}
}
successors.emplace_back(&getDefaultRegion());
}

void SwitchOp::getRegionInvocationBounds(
ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
if (!operandValue) {
// All regions are invoked at most once.
bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
return;
}

unsigned liveIndex = getNumRegions() - 1;
const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt());

liveIndex = iteratorToInt != getCases().end()
? std::distance(getCases().begin(), iteratorToInt)
: liveIndex;

for (unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum;
++regIndex)
bounds.emplace_back(/*lb=*/0, /*ub=*/regIndex == liveIndex);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
Loading