-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][EmitC] Add an emitc.conditional
operator
#84883
Conversation
This adds an `emitc.conditional` operation for the ternary conditional operator. Furthermore, this adds a converion from `arith.select` to the new op.
@llvm/pr-subscribers-mlir-emitc @llvm/pr-subscribers-mlir Author: Marius Brehler (marbre) ChangesThis adds an Full diff: https://github.com/llvm/llvm-project/pull/84883.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index ac1e38a5506da0..ec842f76628c08 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -908,6 +908,36 @@ def EmitC_SubOp : EmitC_BinaryOp<"sub", [CExpression]> {
let hasVerifier = 1;
}
+def EmitC_ConditionalOp : EmitC_Op<"conditional",
+ [AllTypesMatch<["true_value", "false_value", "result"]>, CExpression]> {
+ let summary = "Conditional (ternary) operation";
+ let description = [{
+ With the `conditional` operation the ternary conditional operator can
+ be applied.
+
+ Example:
+
+ ```mlir
+ %0 = emitc.cmp gt, %arg0, %arg1 : (i32, i32) -> i1
+
+ %c0 = "emitc.constant"() {value = 10 : i32} : () -> i32
+ %c1 = "emitc.constant"() {value = 11 : i32} : () -> i32
+
+ %1 = emitc.conditional %0, %c0, %c1 : i32
+ ```
+ ```c++
+ // Code emitted for the operations above.
+ bool v3 = v1 > v2;
+ int32_t v4 = 10;
+ int32_t v5 = 11;
+ int32_t v6 = v3 ? v4 : v5;
+ ```
+ }];
+ let arguments = (ins I1:$condition, AnyType:$true_value, AnyType:$false_value);
+ let results = (outs AnyType:$result);
+ let assemblyFormat = "operands attr-dict `:` type($result)";
+}
+
def EmitC_UnaryMinusOp : EmitC_UnaryOp<"unary_minus", [CExpression]> {
let summary = "Unary minus operation";
let description = [{
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 40dce001a3b224..afef4ac1bfc5f9 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -54,6 +54,26 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
return success();
}
};
+
+class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
+public:
+ using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ if (!selectOp.getCondition().getType().isInteger())
+ return rewriter.notifyMatchFailure(
+ selectOp, "can only converted if condition is a scalar of type i1");
+
+ rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(
+ selectOp, selectOp.getType(), adaptor.getOperands());
+
+ return success();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -70,7 +90,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
ArithOpConversion<arith::AddFOp, emitc::AddOp>,
ArithOpConversion<arith::DivFOp, emitc::DivOp>,
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
- ArithOpConversion<arith::SubFOp, emitc::SubOp>
+ ArithOpConversion<arith::SubFOp, emitc::SubOp>,
+ SelectOpConversion
>(typeConverter, ctx);
// clang-format on
}
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 3cf137c1d07c0e..7cbb1e9265e174 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -96,6 +96,7 @@ static FailureOr<int> getOperatorPrecedence(Operation *operation) {
}
return op->emitError("unsupported cmp predicate");
})
+ .Case<emitc::ConditionalOp>([&](auto op) { return 2; })
.Case<emitc::DivOp>([&](auto op) { return 13; })
.Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
.Case<emitc::LogicalNotOp>([&](auto op) { return 15; })
@@ -446,6 +447,29 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
return printBinaryOperation(emitter, operation, binaryOperator);
}
+static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::ConditionalOp conditionalOp) {
+ raw_ostream &os = emitter.ostream();
+
+ if (failed(emitter.emitAssignPrefix(*conditionalOp)))
+ return failure();
+
+ if (failed(emitter.emitOperand(conditionalOp.getCondition())))
+ return failure();
+
+ os << " ? ";
+
+ if (failed(emitter.emitOperand(conditionalOp.getTrueValue())))
+ return failure();
+
+ os << " : ";
+
+ if (failed(emitter.emitOperand(conditionalOp.getFalseValue())))
+ return failure();
+
+ return success();
+}
+
static LogicalResult printOperation(CppEmitter &emitter,
emitc::VerbatimOp verbatimOp) {
raw_ostream &os = emitter.ostream();
@@ -1383,12 +1407,13 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
emitc::BitwiseNotOp, emitc::BitwiseOrOp,
emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
- emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp,
- emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
- emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
- emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
- emitc::SubOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
- emitc::VariableOp, emitc::VerbatimOp>(
+ emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
+ emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
+ emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
+ emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
+ emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
+ emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
+ emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index 2886810c01e917..022530ef4db84b 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -34,3 +34,11 @@ func.func @arith_ops(%arg0: f32, %arg1: f32) {
return
}
+
+// -----
+
+func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
+ // CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
+ %0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
+ return
+}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 122b1d9ef1059f..5f00a295ed740e 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -71,6 +71,11 @@ func.func @bitwise(%arg0: i32, %arg1: i32) -> () {
return
}
+func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () {
+ %0 = emitc.conditional %cond, %arg0, %arg1 : i32
+ return
+}
+
func.func @div_int(%arg0: i32, %arg1: i32) {
%1 = "emitc.div" (%arg0, %arg1) : (i32, i32) -> i32
return
diff --git a/mlir/test/Target/Cpp/conditional.mlir b/mlir/test/Target/Cpp/conditional.mlir
new file mode 100644
index 00000000000000..2470fbeb33adae
--- /dev/null
+++ b/mlir/test/Target/Cpp/conditional.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+
+func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () {
+ %0 = emitc.conditional %cond, %arg0, %arg1 : i32
+ return
+}
+
+// CHECK-LABEL: void cond
+// CHECK-NEXT: int32_t [[V3:[^ ]*]] = [[V0:[^ ]*]] ? [[V1:[^ ]*]] : [[V2:[^ ]*]];
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd replace the auto
with Type
, otherwise looks good.
@@ -70,7 +95,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, | |||
ArithOpConversion<arith::AddFOp, emitc::AddOp>, | |||
ArithOpConversion<arith::DivFOp, emitc::DivOp>, | |||
ArithOpConversion<arith::MulFOp, emitc::MulOp>, | |||
ArithOpConversion<arith::SubFOp, emitc::SubOp> | |||
ArithOpConversion<arith::SubFOp, emitc::SubOp>, | |||
SelectOpConversion |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SelectOpConversion | |
ArithSelectOpConversion |
Maybe rename it such that is more consistent to the rest?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The intend behind the naming ArithOp
is that this is a conversion for a "genric" arith op. The ArithToEmitC
already indicates that this file only contains conversions from the Arith to the EmitC dialect.
This adds an `emitc.conditional` operation for the ternary conditional operator. Furthermore, this adds a converion from `arith.select` to the new op.
This adds an `emitc.conditional` operation for the ternary conditional operator. Furthermore, this adds a converion from `arith.select` to the new op.
[mlir][EmitC] Add an `emitc.conditional` operator (llvm#84883)
This adds an
emitc.conditional
operation for the ternary conditional operator. Furthermore, this adds a converion fromarith.select
to the new op.