Skip to content

Commit

Permalink
[mlir][emitc] Switching to custom assembly in tablegen similar to the…
Browse files Browse the repository at this point in the history
… scf dialect
  • Loading branch information
EtoAndruwa committed Aug 13, 2024
1 parent f6ba7a9 commit ead4764
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 161 deletions.
12 changes: 8 additions & 4 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -1320,8 +1320,9 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
not equal any of the case values, control-flow transfer to the "default"
region.

The operation does not return any value. Moreover, case regions and
default region must be explicitly terminated using the `emitc.yield` operation.
The operation does not return any value. Moreover, case regions must be
explicitly terminated using the `emitc.yield` operation. Default region is
yielded implicitly.

Example:

Expand All @@ -1339,7 +1340,6 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
default: {
%3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
emitc.call_opaque "func2" (%3) : (f32) -> ()
emitc.yield
}

// Output:
Expand All @@ -1365,6 +1365,11 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
let regions = (region SizedRegion<1>:$defaultRegion,
VariadicRegion<SizedRegion<1>>:$caseRegions);

let assemblyFormat = [{
$arg `:` type($arg) attr-dict custom<SwitchCases>($cases, $caseRegions) `\n`
`` `default` $defaultRegion
}];

let extraClassDeclaration = [{
/// Get the number of cases.
unsigned getNumCases();
Expand All @@ -1376,7 +1381,6 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
Block &getCaseBlock(unsigned idx);
}];

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

Expand Down
60 changes: 1 addition & 59 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1109,7 +1109,7 @@ parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases,
int64_t value;
Region &region = *caseRegions.emplace_back(std::make_unique<Region>());

if (parser.parseInteger(value) || parser.parseColon() ||
if (parser.parseInteger(value) ||
parser.parseRegion(region, /*arguments=*/{}))
return failure();
caseValues.push_back(value);
Expand All @@ -1128,64 +1128,6 @@ static void printSwitchCases(OpAsmPrinter &p, Operation *op,
}
}

ParseResult SwitchOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::UnresolvedOperand arg;
DenseI64ArrayAttr casesAttr;
SmallVector<std::unique_ptr<Region>, 2> caseRegionsRegions;
std::unique_ptr<Region> defaultRegionRegion = std::make_unique<Region>();

if (parser.parseOperand(arg))
return failure();

Type argType;
// Parse the case's type.
if (parser.parseColon() || parser.parseType(argType))
return failure();

auto loc = parser.getCurrentLocation();
if (parser.parseOptionalAttrDict(result.attributes))
return failure();

if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
return parser.emitError(loc)
<< "'" << result.name.getStringRef() << "' op ";
})))
return failure();

auto odsResult = parseSwitchCases(parser, casesAttr, caseRegionsRegions);
if (odsResult)
return failure();

result.getOrAddProperties<SwitchOp::Properties>().cases = casesAttr;

if (parser.parseKeyword("default") || parser.parseColon())
return failure();

if (parser.parseRegion(*defaultRegionRegion))
return failure();

result.addRegion(std::move(defaultRegionRegion));
result.addRegions(caseRegionsRegions);

if (parser.resolveOperand(arg, argType, result.operands))
return failure();

return success();
}

void SwitchOp::print(OpAsmPrinter &p) {
p << ' ' << getArg();
SmallVector<StringRef, 2> elidedAttrs;
elidedAttrs.push_back("cases");
p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
p << ' ';
printSwitchCases(p, *this, getCasesAttr(), getCaseRegions());
p.printNewline();
p << "default ";
p.printRegion(getDefaultRegion(), /*printEntryBlockArgs=*/true,
/*printBlockTerminators=*/true);
}

static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region,
const Twine &name) {
auto yield = dyn_cast<emitc::YieldOp>(region.front().back());
Expand Down
22 changes: 9 additions & 13 deletions mlir/test/Conversion/SCFToEmitC/switch.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@
// CHECK-LABEL: func.func @switch_no_result(
// CHECK-SAME: %[[VAL_0:.*]]: index) {
// CHECK: emitc.switch %[[VAL_0]]
// CHECK: case 2: {
// CHECK: case 2 {
// CHECK: %[[VAL_1:.*]] = arith.constant 10 : i32
// CHECK: emitc.yield
// CHECK: }
// CHECK: case 5: {
// CHECK: case 5 {
// CHECK: %[[VAL_2:.*]] = arith.constant 20 : i32
// CHECK: emitc.yield
// CHECK: }
// CHECK: default {
// CHECK: %[[VAL_3:.*]] = arith.constant 30 : i32
// CHECK: emitc.yield
// CHECK: }
// CHECK: return
// CHECK: }
Expand All @@ -29,29 +28,27 @@ func.func @switch_no_result(%arg0 : index) {
}
default {
%3 = arith.constant 30 : i32
scf.yield
}
return
}

// CHECK-LABEL: func.func @switch_one_result(
// CHECK-SAME: %[[VAL_0:.*]]: index) {
// CHECK: %[[VAL_1:.*]] = "emitc.variable"() <{value = #[[?]]<"">}> : () -> i32
// CHECK: %[[VAL_1:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
// CHECK: emitc.switch %[[VAL_0]]
// CHECK: case 2: {
// CHECK: case 2 {
// CHECK: %[[VAL_2:.*]] = arith.constant 10 : i32
// CHECK: emitc.assign %[[VAL_2]] : i32 to %[[VAL_1]] : i32
// CHECK: emitc.yield
// CHECK: }
// CHECK: case 5: {
// CHECK: case 5 {
// CHECK: %[[VAL_3:.*]] = arith.constant 20 : i32
// CHECK: emitc.assign %[[VAL_3]] : i32 to %[[VAL_1]] : i32
// CHECK: emitc.yield
// CHECK: }
// CHECK: default {
// CHECK: %[[VAL_4:.*]] = arith.constant 30 : i32
// CHECK: emitc.assign %[[VAL_4]] : i32 to %[[VAL_1]] : i32
// CHECK: emitc.yield
// CHECK: }
// CHECK: return
// CHECK: }
Expand All @@ -74,17 +71,17 @@ func.func @switch_one_result(%arg0 : index) {

// CHECK-LABEL: func.func @switch_two_results(
// CHECK-SAME: %[[VAL_0:.*]]: index) {
// CHECK: %[[VAL_1:.*]] = "emitc.variable"() <{value = #[[?]]<"">}> : () -> i32
// CHECK: %[[VAL_2:.*]] = "emitc.variable"() <{value = #[[?]]<"">}> : () -> f32
// CHECK: %[[VAL_1:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
// CHECK: %[[VAL_2:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
// CHECK: emitc.switch %[[VAL_0]]
// CHECK: case 2: {
// CHECK: case 2 {
// CHECK: %[[VAL_3:.*]] = arith.constant 10 : i32
// CHECK: %[[VAL_4:.*]] = arith.constant 1.200000e+00 : f32
// CHECK: emitc.assign %[[VAL_3]] : i32 to %[[VAL_1]] : i32
// CHECK: emitc.assign %[[VAL_4]] : f32 to %[[VAL_2]] : f32
// CHECK: emitc.yield
// CHECK: }
// CHECK: case 5: {
// CHECK: case 5 {
// CHECK: %[[VAL_5:.*]] = arith.constant 20 : i32
// CHECK: %[[VAL_6:.*]] = arith.constant 2.400000e+00 : f32
// CHECK: emitc.assign %[[VAL_5]] : i32 to %[[VAL_1]] : i32
Expand All @@ -96,7 +93,6 @@ func.func @switch_one_result(%arg0 : index) {
// CHECK: %[[VAL_8:.*]] = arith.constant 3.600000e+00 : f32
// CHECK: emitc.assign %[[VAL_7]] : i32 to %[[VAL_1]] : i32
// CHECK: emitc.assign %[[VAL_8]] : f32 to %[[VAL_2]] : f32
// CHECK: emitc.yield
// CHECK: }
// CHECK: return
// CHECK: }
Expand Down
60 changes: 19 additions & 41 deletions mlir/test/Dialect/EmitC/invalid_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -470,18 +470,18 @@ func.func @member_of_ptr(%arg0: i32) {
// -----

func.func @emitc_switch() {
%0 = "emitc.variable"(){value = 1 : ui16} : () -> ui16
%0 = "emitc.variable"(){value = 1 : i16} : () -> i16

// expected-error@+1 {{'emitc.switch' op expected region to end with emitc.yield, but got emitc.call_opaque}}
emitc.switch %0 : ui16
case 2: {
emitc.switch %0 : i16
case 2 {
%1 = emitc.call_opaque "func_b" () : () -> i32
}
case 5: {
case 5 {
%2 = emitc.call_opaque "func_a" () : () -> i32
emitc.yield
}
default: {
default {
%3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
emitc.call_opaque "func2" (%3) : (f32) -> ()
emitc.yield
Expand All @@ -492,41 +492,19 @@ func.func @emitc_switch() {
// -----

func.func @emitc_switch() {
%0 = "emitc.variable"(){value = 1 : ui16} : () -> ui16
%0 = "emitc.variable"(){value = 1 : i32} : () -> i32

// expected-error@+1 {{'emitc.switch' op expected region to end with emitc.yield, but got emitc.call_opaque}}
emitc.switch %0 : ui16
case 2: {
%1 = emitc.call_opaque "func_b" () : () -> i32
emitc.yield
}
case 5: {
%2 = emitc.call_opaque "func_a" () : () -> i32
emitc.yield
}
default: {
%3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
emitc.call_opaque "func2" (%3) : (f32) -> ()
}
return
}

// -----

func.func @emitc_switch() {
%0 = "emitc.variable"(){value = 1 : i16} : () -> i16

emitc.switch %0 : i16
case 2: {
emitc.switch %0 : i32
case 2 {
%1 = emitc.call_opaque "func_b" () : () -> i32
emitc.yield
}
// expected-error@+1 {{custom op 'emitc.switch' expected integer value}}
case : {
case {
%2 = emitc.call_opaque "func_a" () : () -> i32
emitc.yield
}
default: {
default {
%3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
emitc.call_opaque "func2" (%3) : (f32) -> ()
emitc.yield
Expand All @@ -537,14 +515,14 @@ func.func @emitc_switch() {
// -----

func.func @emitc_switch() {
%0 = "emitc.variable"(){value = 1 : i16} : () -> i16
%0 = "emitc.variable"(){value = 1 : i8} : () -> i8

emitc.switch %0 : i16
case 2: {
emitc.switch %0 : i8
case 2 {
%1 = emitc.call_opaque "func_b" () : () -> i32
emitc.yield
}
case 3: {
case 3 {
%2 = emitc.call_opaque "func_a" () : () -> i32
emitc.yield
}
Expand All @@ -555,19 +533,19 @@ func.func @emitc_switch() {
// -----

func.func @emitc_switch() {
%0 = "emitc.variable"(){value = 1 : i16} : () -> i16
%0 = "emitc.variable"(){value = 1 : i64} : () -> i64

// expected-error@+1 {{'emitc.switch' op has duplicate case value: 2}}
emitc.switch %0 : i16
case 2: {
emitc.switch %0 : i64
case 2 {
%1 = emitc.call_opaque "func_b" () : () -> i32
emitc.yield
}
case 2: {
case 2 {
%2 = emitc.call_opaque "func_a" () : () -> i32
emitc.yield
}
default: {
default {
%3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
emitc.call_opaque "func2" (%3) : (f32) -> ()
emitc.yield
Expand Down
20 changes: 20 additions & 0 deletions mlir/test/Dialect/EmitC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,23 @@ func.func @member_access(%arg0: !emitc.opaque<"mystruct">, %arg1: !emitc.opaque<
%2 = "emitc.member_of_ptr" (%arg2) {member = "a"} : (!emitc.ptr<!emitc.opaque<"mystruct">>) -> i32
return
}

func.func @switch() {
%0 = "emitc.variable"(){value = 1 : index} : () -> !emitc.ptrdiff_t

emitc.switch %0 : !emitc.ptrdiff_t
case 1 {
%1 = emitc.call_opaque "func_b" () : () -> i32
emitc.yield
}
case 2 {
%2 = emitc.call_opaque "func_a" () : () -> i32
emitc.yield
}
default {
%3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
emitc.call_opaque "func2" (%3) : (f32) -> ()
}

return
}
Loading

0 comments on commit ead4764

Please sign in to comment.