diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index bcdd001528c46d6..fc00fe0c6dfe1ec 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -908,6 +908,36 @@ def EmitC_SubOp : EmitC_BinaryOp<"sub", [CExpression]> { let hasVerifier = 1; } +def EmitC_ConditionalOp : EmitC_Op<"conditional", + [AllTypesMatch<["true_value", "false_value", "result"]>, CExpression]> { + let summary = "Conditional (ternary) operation"; + let description = [{ + With the `conditional` operation the ternary conditional operator can + be applied. + + Example: + + ```mlir + %0 = emitc.cmp gt, %arg0, %arg1 : (i32, i32) -> i1 + + %c0 = "emitc.constant"() {value = 10 : i32} : () -> i32 + %c1 = "emitc.constant"() {value = 11 : i32} : () -> i32 + + %1 = emitc.conditional %0, %c0, %c1 : i32 + ``` + ```c++ + // Code emitted for the operations above. + bool v3 = v1 > v2; + int32_t v4 = 10; + int32_t v5 = 11; + int32_t v6 = v3 ? v4 : v5; + ``` + }]; + let arguments = (ins I1:$condition, AnyType:$true_value, AnyType:$false_value); + let results = (outs AnyType:$result); + let assemblyFormat = "operands attr-dict `:` type($result)"; +} + def EmitC_UnaryMinusOp : EmitC_UnaryOp<"unary_minus", [CExpression]> { let summary = "Unary minus operation"; let description = [{ diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 40dce001a3b2242..3532785c31b9396 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -54,6 +54,31 @@ class ArithOpConversion final : public OpConversionPattern { return success(); } }; + +class SelectOpConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Type dstType = getTypeConverter()->convertType(selectOp.getType()); + if (!dstType) + return rewriter.notifyMatchFailure(selectOp, "type conversion failed"); + + if (!adaptor.getCondition().getType().isInteger(1)) + return rewriter.notifyMatchFailure( + selectOp, + "can only be converted if condition is a scalar of type i1"); + + rewriter.replaceOpWithNewOp(selectOp, dstType, + adaptor.getOperands()); + + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -70,7 +95,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, ArithOpConversion, ArithOpConversion, ArithOpConversion, - ArithOpConversion + ArithOpConversion, + SelectOpConversion >(typeConverter, ctx); // clang-format on } diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 6e477a34fc4ba9c..f19ab031096e2f3 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -96,6 +96,7 @@ static FailureOr getOperatorPrecedence(Operation *operation) { } return op->emitError("unsupported cmp predicate"); }) + .Case([&](auto op) { return 2; }) .Case([&](auto op) { return 13; }) .Case([&](auto op) { return 4; }) .Case([&](auto op) { return 15; }) @@ -455,6 +456,29 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) { return printBinaryOperation(emitter, operation, binaryOperator); } +static LogicalResult printOperation(CppEmitter &emitter, + emitc::ConditionalOp conditionalOp) { + raw_ostream &os = emitter.ostream(); + + if (failed(emitter.emitAssignPrefix(*conditionalOp))) + return failure(); + + if (failed(emitter.emitOperand(conditionalOp.getCondition()))) + return failure(); + + os << " ? "; + + if (failed(emitter.emitOperand(conditionalOp.getTrueValue()))) + return failure(); + + os << " : "; + + if (failed(emitter.emitOperand(conditionalOp.getFalseValue()))) + return failure(); + + return success(); +} + static LogicalResult printOperation(CppEmitter &emitter, emitc::VerbatimOp verbatimOp) { raw_ostream &os = emitter.ostream(); @@ -1410,12 +1434,13 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { emitc::BitwiseNotOp, emitc::BitwiseOrOp, emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp, emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp, - emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp, - emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp, - emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp, - emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp, - emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp, - emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>( + emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp, + emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, + emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp, + emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp, + emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp, + emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp, + emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. .Case( diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index 2886810c01e917d..022530ef4db84b3 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -34,3 +34,11 @@ func.func @arith_ops(%arg0: f32, %arg1: f32) { return } + +// ----- + +func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () { + // CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32> + %0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32> + return +} diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 02294d13cef7638..74ac826eace7c71 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -71,6 +71,11 @@ func.func @bitwise(%arg0: i32, %arg1: i32) -> () { return } +func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () { + %0 = emitc.conditional %cond, %arg0, %arg1 : i32 + return +} + func.func @div_int(%arg0: i32, %arg1: i32) { %1 = "emitc.div" (%arg0, %arg1) : (i32, i32) -> i32 return diff --git a/mlir/test/Target/Cpp/conditional.mlir b/mlir/test/Target/Cpp/conditional.mlir new file mode 100644 index 000000000000000..2470fbeb33adaec --- /dev/null +++ b/mlir/test/Target/Cpp/conditional.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s + +func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () { + %0 = emitc.conditional %cond, %arg0, %arg1 : i32 + return +} + +// CHECK-LABEL: void cond +// CHECK-NEXT: int32_t [[V3:[^ ]*]] = [[V0:[^ ]*]] ? [[V1:[^ ]*]] : [[V2:[^ ]*]];