Skip to content

Commit

Permalink
[mlir][EmitC] Add an emitc.conditional operator (llvm#84883)
Browse files Browse the repository at this point in the history
This adds an `emitc.conditional` operation for the ternary conditional
operator. Furthermore, this adds a converion from `arith.select` to the
new op.
  • Loading branch information
marbre authored and mgehre-amd committed Mar 14, 2024
1 parent 8bc91d3 commit 2acc99e
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 7 deletions.
30 changes: 30 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,36 @@ def EmitC_SubOp : EmitC_BinaryOp<"sub", [CExpression]> {
let hasVerifier = 1;
}

def EmitC_ConditionalOp : EmitC_Op<"conditional",
[AllTypesMatch<["true_value", "false_value", "result"]>, CExpression]> {
let summary = "Conditional (ternary) operation";
let description = [{
With the `conditional` operation the ternary conditional operator can
be applied.

Example:

```mlir
%0 = emitc.cmp gt, %arg0, %arg1 : (i32, i32) -> i1

%c0 = "emitc.constant"() {value = 10 : i32} : () -> i32
%c1 = "emitc.constant"() {value = 11 : i32} : () -> i32

%1 = emitc.conditional %0, %c0, %c1 : i32
```
```c++
// Code emitted for the operations above.
bool v3 = v1 > v2;
int32_t v4 = 10;
int32_t v5 = 11;
int32_t v6 = v3 ? v4 : v5;
```
}];
let arguments = (ins I1:$condition, AnyType:$true_value, AnyType:$false_value);
let results = (outs AnyType:$result);
let assemblyFormat = "operands attr-dict `:` type($result)";
}

def EmitC_UnaryMinusOp : EmitC_UnaryOp<"unary_minus", [CExpression]> {
let summary = "Unary minus operation";
let description = [{
Expand Down
28 changes: 27 additions & 1 deletion mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,31 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
return success();
}
};

class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
public:
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Type dstType = getTypeConverter()->convertType(selectOp.getType());
if (!dstType)
return rewriter.notifyMatchFailure(selectOp, "type conversion failed");

if (!adaptor.getCondition().getType().isInteger(1))
return rewriter.notifyMatchFailure(
selectOp,
"can only be converted if condition is a scalar of type i1");

rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
adaptor.getOperands());

return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -70,7 +95,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
ArithOpConversion<arith::AddFOp, emitc::AddOp>,
ArithOpConversion<arith::DivFOp, emitc::DivOp>,
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
ArithOpConversion<arith::SubFOp, emitc::SubOp>
ArithOpConversion<arith::SubFOp, emitc::SubOp>,
SelectOpConversion
>(typeConverter, ctx);
// clang-format on
}
37 changes: 31 additions & 6 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ static FailureOr<int> getOperatorPrecedence(Operation *operation) {
}
return op->emitError("unsupported cmp predicate");
})
.Case<emitc::ConditionalOp>([&](auto op) { return 2; })
.Case<emitc::DivOp>([&](auto op) { return 13; })
.Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
.Case<emitc::LogicalNotOp>([&](auto op) { return 15; })
Expand Down Expand Up @@ -455,6 +456,29 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
return printBinaryOperation(emitter, operation, binaryOperator);
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::ConditionalOp conditionalOp) {
raw_ostream &os = emitter.ostream();

if (failed(emitter.emitAssignPrefix(*conditionalOp)))
return failure();

if (failed(emitter.emitOperand(conditionalOp.getCondition())))
return failure();

os << " ? ";

if (failed(emitter.emitOperand(conditionalOp.getTrueValue())))
return failure();

os << " : ";

if (failed(emitter.emitOperand(conditionalOp.getFalseValue())))
return failure();

return success();
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::VerbatimOp verbatimOp) {
raw_ostream &os = emitter.ostream();
Expand Down Expand Up @@ -1410,12 +1434,13 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
emitc::BitwiseNotOp, emitc::BitwiseOrOp,
emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp,
emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp,
emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>(
emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp,
emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,11 @@ func.func @arith_ops(%arg0: f32, %arg1: f32) {

return
}

// -----

func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
// CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
%0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
return
}
5 changes: 5 additions & 0 deletions mlir/test/Dialect/EmitC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ func.func @bitwise(%arg0: i32, %arg1: i32) -> () {
return
}

func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () {
%0 = emitc.conditional %cond, %arg0, %arg1 : i32
return
}

func.func @div_int(%arg0: i32, %arg1: i32) {
%1 = "emitc.div" (%arg0, %arg1) : (i32, i32) -> i32
return
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Target/Cpp/conditional.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s

func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () {
%0 = emitc.conditional %cond, %arg0, %arg1 : i32
return
}

// CHECK-LABEL: void cond
// CHECK-NEXT: int32_t [[V3:[^ ]*]] = [[V0:[^ ]*]] ? [[V1:[^ ]*]] : [[V2:[^ ]*]];

0 comments on commit 2acc99e

Please sign in to comment.