Skip to content

Commit

Permalink
[mlir][emitc] Add 'emitc.switch' op to the dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
EtoAndruwa committed Aug 7, 2024
1 parent ad8a2e4 commit 98b086a
Show file tree
Hide file tree
Showing 6 changed files with 600 additions and 4 deletions.
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ bool isSupportedFloatType(mlir::Type type);
/// Determines whether \p type is a emitc.size_t/ssize_t type.
bool isPointerWideType(mlir::Type type);

/// Determines whether \p type is a valid integer type for SwitchOp.
bool isSwitchOperandType(Type type);

} // namespace emitc
} // namespace mlir

Expand Down
89 changes: 88 additions & 1 deletion mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def CExpression : NativeOpTrait<"emitc::CExpression">;
def IntegerIndexOrOpaqueType : Type<CPred<"emitc::isIntegerIndexOrOpaqueType($_self)">,
"integer, index or opaque type supported by EmitC">;
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[EmitCFloatType, IntegerIndexOrOpaqueType]>;
def SwitchOperandType : Type<CPred<"emitc::isSwitchOperandType($_self)">,
"integer type for switch operation">;

def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
let summary = "Addition operation";
Expand Down Expand Up @@ -1188,7 +1190,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
}

def EmitC_YieldOp : EmitC_Op<"yield",
[Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp"]>]> {
[Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp", "SwitchOp"]>]> {
let summary = "Block termination operation";
let description = [{
The `emitc.yield` terminates its parent EmitC op's region, optionally yielding
Expand Down Expand Up @@ -1302,5 +1304,90 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
let assemblyFormat = "$value `[` $indices `]` attr-dict `:` functional-type(operands, results)";
}

def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
SingleBlockImplicitTerminator<"emitc::YieldOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getRegionInvocationBounds",
"getEntrySuccessorRegions"]>]> {
let summary = "Switch operation";
let description = [{
The `emitc.switch` is a control-flow operation that branches to one of
the given regions based on the values of the argument and the cases. The
argument is always of type integer (singed or unsigned), excluding i8 and i1.
If the type is not specified, then i32 will be used by default.

The operation always has a "default" region and any number of case regions
denoted by integer constants. Control-flow transfers to the case region
whose constant value equals the value of the argument. If the argument does
not equal any of the case values, control-flow transfer to the "default"
region.

The operation does not return any value. Moreover, case regions and
default region must be terminated using the `emitc.yield` operation.

Example:

```mlir
// Cases with i32 type.
%0 = "emitc.variable"(){value = 42 : i32} : () -> i32
emitc.switch %0 : i32
case 2: {
%1 = emitc.call_opaque "func_b" () : () -> i32
emitc.yield
}
case 5: {
%2 = emitc.call_opaque "func_a" () : () -> i32
emitc.yield
}
default : {
%3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
%4 = "emitc.variable"(){value = 42.0 : f32} : () -> f32

emitc.call_opaque "func2" (%3) : (f32) -> ()
emitc.call_opaque "func3" (%3, %4) { args = [1 : index, 0 : index] } : (f32, f32) -> ()
emitc.yield
}
...
// Cases with i16 type.
%0 = "emitc.variable"(){value = 42 : i16} : () -> i16
emitc.switch %0 : i16
case 2: {
%1 = emitc.call_opaque "func_b" () : () -> i32
emitc.yield
}
case 5: {
%2 = emitc.call_opaque "func_a" () : () -> i32
emitc.yield
}
default : {
%3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
%4 = "emitc.variable"(){value = 42.0 : f32} : () -> f32

emitc.call_opaque "func2" (%3) : (f32) -> ()
emitc.call_opaque "func3" (%3, %4) { args = [1 : index, 0 : index] } : (f32, f32) -> ()
emitc.yield
}
```
}];

let arguments = (ins SwitchOperandType:$arg, DenseI64ArrayAttr:$cases);
let results = (outs);
let regions = (region SizedRegion<1>:$defaultRegion,
VariadicRegion<SizedRegion<1>>:$caseRegions);

let extraClassDeclaration = [{
/// Get the number of cases.
unsigned getNumCases();

/// Get the default region body.
Block &getDefaultBlock();

/// Get the body of a case region.
Block &getCaseBlock(unsigned idx);
}];

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

#endif // MLIR_DIALECT_EMITC_IR_EMITC
205 changes: 205 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ bool mlir::emitc::isPointerWideType(Type type) {
type);
}

bool mlir::emitc::isSwitchOperandType(Type type) {
auto intType = llvm::dyn_cast<IntegerType>(type);
return isSupportedIntegerType(type) && intType.getWidth() != 1 &&
intType.getWidth() != 8;
}

/// Check that the type of the initial value is compatible with the operations
/// result type.
static LogicalResult verifyInitializationAttribute(Operation *op,
Expand Down Expand Up @@ -1096,6 +1102,205 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}

//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//

/// Parse the case regions and values.
static ParseResult
parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases,
SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
SmallVector<int64_t> caseValues;
while (succeeded(parser.parseOptionalKeyword("case"))) {
int64_t value;
Region &region = *caseRegions.emplace_back(std::make_unique<Region>());

if (parser.parseInteger(value) || parser.parseColon() ||
parser.parseRegion(region, /*arguments=*/{}))
return failure();
caseValues.push_back(value);
}
cases = parser.getBuilder().getDenseI64ArrayAttr(caseValues);
return success();
}

/// Print the case regions and values.
static void printSwitchCases(OpAsmPrinter &parser, Operation *op,
DenseI64ArrayAttr cases, RegionRange caseRegions) {
for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
parser.printNewline();
parser << "case " << value << ": ";
parser.printRegion(*region, /*printEntryBlockArgs=*/false);
}
return;
}

ParseResult SwitchOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::UnresolvedOperand arg;
DenseI64ArrayAttr casesAttr;
SmallVector<std::unique_ptr<Region>, 2> caseRegionsRegions;
std::unique_ptr<Region> defaultRegionRegion = std::make_unique<Region>();

if (parser.parseOperand(arg))
return failure();

Type argType = parser.getBuilder().getI32Type();
// Parse optional type, else assume i32.
if (!parser.parseOptionalColon() && parser.parseType(argType))
return failure();

auto loc = parser.getCurrentLocation();
if (parser.parseOptionalAttrDict(result.attributes))
return failure();

if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
return parser.emitError(loc)
<< "'" << result.name.getStringRef() << "' op ";
})))
return failure();

auto odsResult = parseSwitchCases(parser, casesAttr, caseRegionsRegions);
if (odsResult)
return failure();

result.getOrAddProperties<SwitchOp::Properties>().cases = casesAttr;

if (parser.parseKeyword("default") || parser.parseColon())
return failure();

if (parser.parseRegion(*defaultRegionRegion))
return failure();

result.addRegion(std::move(defaultRegionRegion));
result.addRegions(caseRegionsRegions);

if (parser.resolveOperand(arg, argType, result.operands))
return failure();

return success();
}

void SwitchOp::print(OpAsmPrinter &parser) {
parser << ' ';
parser << getArg();
SmallVector<StringRef, 2> elidedAttrs;
elidedAttrs.push_back("cases");
parser.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
parser << ' ';
printSwitchCases(parser, *this, getCasesAttr(), getCaseRegions());
parser.printNewline();
parser << "default";
parser << ' ';
parser.printRegion(getDefaultRegion(), /*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/true);

return;
}

static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region,
const Twine &name) {
auto yield = dyn_cast<emitc::YieldOp>(region.front().back());
if (!yield)
return op.emitOpError("expected region to end with emitc.yield, but got ")
<< region.front().back().getName();

if (yield.getNumOperands() != 0) {
return (op.emitOpError("expected each region to return ")
<< "0 values, but " << name << " returns "
<< yield.getNumOperands())
.attachNote(yield.getLoc())
<< "see yield operation here";
}
return success();
}

LogicalResult emitc::SwitchOp::verify() {
if (!isSwitchOperandType(getArg().getType()))
return emitOpError("unsupported type ") << getArg().getType();

if (getCases().size() != getCaseRegions().size()) {
return emitOpError("has ")
<< getCaseRegions().size() << " case regions but "
<< getCases().size() << " case values";
}

DenseSet<int64_t> valueSet;
for (int64_t value : getCases())
if (!valueSet.insert(value).second)
return emitOpError("has duplicate case value: ") << value;

if (failed(verifyRegion(*this, getDefaultRegion(), "default region")))
return failure();

for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
if (failed(verifyRegion(*this, caseRegion, "case region #" + Twine(idx))))
return failure();

return success();
}

unsigned emitc::SwitchOp::getNumCases() { return getCases().size(); }

Block &emitc::SwitchOp::getDefaultBlock() { return getDefaultRegion().front(); }

Block &emitc::SwitchOp::getCaseBlock(unsigned idx) {
assert(idx < getNumCases() && "case index out-of-bounds");
return getCaseRegions()[idx].front();
}

void SwitchOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
llvm::copy(getRegions(), std::back_inserter(successors));
return;
}

void SwitchOp::getEntrySuccessorRegions(
ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &successors) {
FoldAdaptor adaptor(operands, *this);

// If a constant was not provided, all regions are possible successors.
auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
if (!arg) {
llvm::copy(getRegions(), std::back_inserter(successors));
return;
}

// Otherwise, try to find a case with a matching value. If not, the
// default region is the only successor.
for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
if (caseValue == arg.getInt()) {
successors.emplace_back(&caseRegion);
return;
}
}
successors.emplace_back(&getDefaultRegion());
return;
}

void SwitchOp::getRegionInvocationBounds(
ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
if (!operandValue) {
// All regions are invoked at most once.
bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
return;
}

unsigned liveIndex = getNumRegions() - 1;
const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt());

liveIndex = iteratorToInt != getCases().end()
? std::distance(getCases().begin(), iteratorToInt)
: liveIndex;

for (unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum;
++regIndex)
bounds.emplace_back(/*lb=*/0, /*ub=*/regIndex == liveIndex);

return;
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 98b086a

Please sign in to comment.