Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][EmitC] Add an emitc.conditional operator #84883

Merged
merged 3 commits into from
Mar 12, 2024

Conversation

marbre
Copy link
Member

@marbre marbre commented Mar 12, 2024

This adds an emitc.conditional operation for the ternary conditional operator. Furthermore, this adds a converion from arith.select to the new op.

This adds an `emitc.conditional` operation for the ternary conditional
operator. Furthermore, this adds a converion from `arith.select` to the
new op.
@llvmbot
Copy link
Member

llvmbot commented Mar 12, 2024

@llvm/pr-subscribers-mlir-emitc

@llvm/pr-subscribers-mlir

Author: Marius Brehler (marbre)

Changes

This adds an emitc.conditional operation for the ternary conditional operator. Furthermore, this adds a converion from arith.select to the new op.


Full diff: https://github.com/llvm/llvm-project/pull/84883.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/EmitC/IR/EmitC.td (+30)
  • (modified) mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp (+22-1)
  • (modified) mlir/lib/Target/Cpp/TranslateToCpp.cpp (+31-6)
  • (modified) mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir (+8)
  • (modified) mlir/test/Dialect/EmitC/ops.mlir (+5)
  • (added) mlir/test/Target/Cpp/conditional.mlir (+9)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index ac1e38a5506da0..ec842f76628c08 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -908,6 +908,36 @@ def EmitC_SubOp : EmitC_BinaryOp<"sub", [CExpression]> {
   let hasVerifier = 1;
 }
 
+def EmitC_ConditionalOp : EmitC_Op<"conditional",
+    [AllTypesMatch<["true_value", "false_value", "result"]>, CExpression]> {
+  let summary = "Conditional (ternary) operation";
+  let description = [{
+    With the `conditional` operation the ternary conditional operator can
+    be applied.
+
+    Example:
+
+    ```mlir
+    %0 = emitc.cmp gt, %arg0, %arg1 : (i32, i32) -> i1
+
+    %c0 = "emitc.constant"() {value = 10 : i32} : () -> i32
+    %c1 = "emitc.constant"() {value = 11 : i32} : () -> i32
+
+    %1 = emitc.conditional %0, %c0, %c1 : i32
+    ```
+    ```c++
+    // Code emitted for the operations above.
+    bool v3 = v1 > v2;
+    int32_t v4 = 10;
+    int32_t v5 = 11;
+    int32_t v6 = v3 ? v4 : v5;
+    ```
+  }];
+  let arguments = (ins I1:$condition, AnyType:$true_value, AnyType:$false_value);
+  let results = (outs AnyType:$result);
+  let assemblyFormat = "operands attr-dict `:` type($result)";
+}
+
 def EmitC_UnaryMinusOp : EmitC_UnaryOp<"unary_minus", [CExpression]> {
   let summary = "Unary minus operation";
   let description = [{
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 40dce001a3b224..afef4ac1bfc5f9 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -54,6 +54,26 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
     return success();
   }
 };
+
+class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
+public:
+  using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!selectOp.getCondition().getType().isInteger())
+      return rewriter.notifyMatchFailure(
+          selectOp, "can only converted if condition is a scalar of type i1");
+
+    rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(
+        selectOp, selectOp.getType(), adaptor.getOperands());
+
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -70,7 +90,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
     ArithOpConversion<arith::AddFOp, emitc::AddOp>,
     ArithOpConversion<arith::DivFOp, emitc::DivOp>,
     ArithOpConversion<arith::MulFOp, emitc::MulOp>,
-    ArithOpConversion<arith::SubFOp, emitc::SubOp>
+    ArithOpConversion<arith::SubFOp, emitc::SubOp>,
+    SelectOpConversion
   >(typeConverter, ctx);
   // clang-format on
 }
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 3cf137c1d07c0e..7cbb1e9265e174 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -96,6 +96,7 @@ static FailureOr<int> getOperatorPrecedence(Operation *operation) {
         }
         return op->emitError("unsupported cmp predicate");
       })
+      .Case<emitc::ConditionalOp>([&](auto op) { return 2; })
       .Case<emitc::DivOp>([&](auto op) { return 13; })
       .Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
       .Case<emitc::LogicalNotOp>([&](auto op) { return 15; })
@@ -446,6 +447,29 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
   return printBinaryOperation(emitter, operation, binaryOperator);
 }
 
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::ConditionalOp conditionalOp) {
+  raw_ostream &os = emitter.ostream();
+
+  if (failed(emitter.emitAssignPrefix(*conditionalOp)))
+    return failure();
+
+  if (failed(emitter.emitOperand(conditionalOp.getCondition())))
+    return failure();
+
+  os << " ? ";
+
+  if (failed(emitter.emitOperand(conditionalOp.getTrueValue())))
+    return failure();
+
+  os << " : ";
+
+  if (failed(emitter.emitOperand(conditionalOp.getFalseValue())))
+    return failure();
+
+  return success();
+}
+
 static LogicalResult printOperation(CppEmitter &emitter,
                                     emitc::VerbatimOp verbatimOp) {
   raw_ostream &os = emitter.ostream();
@@ -1383,12 +1407,13 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
                 emitc::BitwiseNotOp, emitc::BitwiseOrOp,
                 emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
                 emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
-                emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp,
-                emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
-                emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
-                emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
-                emitc::SubOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
-                emitc::VariableOp, emitc::VerbatimOp>(
+                emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
+                emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
+                emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
+                emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
+                emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
+                emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
+                emitc::VerbatimOp>(
               [&](auto op) { return printOperation(*this, op); })
           // Func ops.
           .Case<func::CallOp, func::FuncOp, func::ReturnOp>(
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 2886810c01e917..022530ef4db84b 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -34,3 +34,11 @@ func.func @arith_ops(%arg0: f32, %arg1: f32) {
 
   return
 }
+
+// -----
+
+func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
+  // CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
+  %0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
+  return
+}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 122b1d9ef1059f..5f00a295ed740e 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -71,6 +71,11 @@ func.func @bitwise(%arg0: i32, %arg1: i32) -> () {
   return
 }
 
+func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () {
+  %0 = emitc.conditional %cond, %arg0, %arg1 : i32
+  return
+}
+
 func.func @div_int(%arg0: i32, %arg1: i32) {
   %1 = "emitc.div" (%arg0, %arg1) : (i32, i32) -> i32
   return
diff --git a/mlir/test/Target/Cpp/conditional.mlir b/mlir/test/Target/Cpp/conditional.mlir
new file mode 100644
index 00000000000000..2470fbeb33adae
--- /dev/null
+++ b/mlir/test/Target/Cpp/conditional.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+
+func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () {
+  %0 = emitc.conditional %cond, %arg0, %arg1 : i32
+  return
+}
+
+// CHECK-LABEL: void cond
+// CHECK-NEXT:  int32_t [[V3:[^ ]*]] = [[V0:[^ ]*]] ? [[V1:[^ ]*]] : [[V2:[^ ]*]];

Copy link

@simon-camp simon-camp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd replace the auto with Type, otherwise looks good.

@@ -70,7 +95,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
ArithOpConversion<arith::AddFOp, emitc::AddOp>,
ArithOpConversion<arith::DivFOp, emitc::DivOp>,
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
ArithOpConversion<arith::SubFOp, emitc::SubOp>
ArithOpConversion<arith::SubFOp, emitc::SubOp>,
SelectOpConversion
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
SelectOpConversion
ArithSelectOpConversion

Maybe rename it such that is more consistent to the rest?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intend behind the naming ArithOp is that this is a conversion for a "genric" arith op. The ArithToEmitC already indicates that this file only contains conversions from the Arith to the EmitC dialect.

@marbre marbre merged commit 19266ca into llvm:main Mar 12, 2024
3 of 4 checks passed
@marbre marbre deleted the emitc.conditional branch March 12, 2024 10:27
mgehre-amd pushed a commit to Xilinx/llvm-project that referenced this pull request Mar 14, 2024
This adds an `emitc.conditional` operation for the ternary conditional
operator. Furthermore, this adds a converion from `arith.select` to the
new op.
mgehre-amd pushed a commit to Xilinx/llvm-project that referenced this pull request Mar 14, 2024
This adds an `emitc.conditional` operation for the ternary conditional
operator. Furthermore, this adds a converion from `arith.select` to the
new op.
mgehre-amd added a commit to Xilinx/llvm-project that referenced this pull request Mar 15, 2024
[mlir][EmitC] Add an `emitc.conditional` operator (llvm#84883)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants