diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index f453366c53c907f..ba279aee16fb7bf 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -96,6 +96,41 @@ class CmpFOpConversion : public OpConversionPattern { bool unordered = false; emitc::CmpPredicate predicate; switch (op.getPredicate()) { + case arith::CmpFPredicate::AlwaysFalse: { + auto constant = rewriter.create( + op.getLoc(), rewriter.getI1Type(), + rewriter.getBoolAttr(/*value=*/false)); + rewriter.replaceOp(op, constant); + return success(); + } + case arith::CmpFPredicate::OEQ: + unordered = false; + predicate = emitc::CmpPredicate::eq; + break; + case arith::CmpFPredicate::OGT: + // ordered and greater than + unordered = false; + predicate = emitc::CmpPredicate::gt; + break; + case arith::CmpFPredicate::OGE: + unordered = false; + predicate = emitc::CmpPredicate::ge; + break; + case arith::CmpFPredicate::OLT: + unordered = false; + predicate = emitc::CmpPredicate::lt; + break; + case arith::CmpFPredicate::ONE: + unordered = false; + predicate = emitc::CmpPredicate::ne; + break; + case arith::CmpFPredicate::ORD: { + // ordered, i.e. none of the operands is NaN + auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(), + adaptor.getRhs()); + rewriter.replaceOp(op, cmp); + return success(); + } case arith::CmpFPredicate::UEQ: // unordered or equal unordered = true; @@ -121,6 +156,10 @@ class CmpFOpConversion : public OpConversionPattern { unordered = true; predicate = emitc::CmpPredicate::le; break; + case arith::CmpFPredicate::UNE: + unordered = true; + predicate = emitc::CmpPredicate::ne; + break; case arith::CmpFPredicate::UNO: { // unordered, i.e. either operand is nan auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(), @@ -128,11 +167,13 @@ class CmpFOpConversion : public OpConversionPattern { rewriter.replaceOp(op, cmp); return success(); } - case arith::CmpFPredicate::OGT: - // ordered and greater than - unordered = false; - predicate = emitc::CmpPredicate::gt; - break; + case arith::CmpFPredicate::AlwaysTrue: { + auto constant = rewriter.create( + op.getLoc(), rewriter.getI1Type(), + rewriter.getBoolAttr(/*value=*/true)); + rewriter.replaceOp(op, constant); + return success(); + } default: return rewriter.notifyMatchFailure(op.getLoc(), "cannot match predicate "); diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir index 123d3cf0fa77d6e..37427ef70cacc2e 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir @@ -16,14 +16,6 @@ func.func @arith_cmpf_vector(%arg0: vector<5xf32>, %arg1: vector<5xf32>) -> vect // ----- -func.func @arith_cmpf_ordered(%arg0: f32, %arg1: f32) -> i1 { - // expected-error @+1 {{failed to legalize operation 'arith.cmpf'}} - %oge = arith.cmpf oge, %arg0, %arg1 : f32 - return %oge: i1 -} - -// ----- - func.func @arith_cast_f32(%arg0: f32) -> i32 { // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}} %t = arith.fptosi %arg0 : f32 to i32 diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index 75ca89407fe5fd0..9fdbeeaa21c5c09 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -58,6 +58,105 @@ func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) - // ----- +func.func @arith_cmpf_false(%arg0: f32, %arg1: f32) -> i1 { + // CHECK-LABEL: arith_cmpf_false + // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) + // CHECK-DAG: [[False:[^ ]*]] = "emitc.constant"() <{value = false}> : () -> i1 + %ueq = arith.cmpf false, %arg0, %arg1 : f32 + // CHECK: return [[False]] + return %ueq: i1 +} + +// ----- + +func.func @arith_cmpf_oeq(%arg0: f32, %arg1: f32) -> i1 { + // CHECK-LABEL: arith_cmpf_oeq + // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) + // CHECK-DAG: [[EQ:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1 + // CHECK-DAG: [[OEQ:[^ ]*]] = emitc.logical_and [[Ordered]], [[EQ]] : i1, i1 + %ueq = arith.cmpf oeq, %arg0, %arg1 : f32 + // CHECK: return [[OEQ]] + return %ueq: i1 +} + +// ----- + +func.func @arith_cmpf_ogt(%arg0: f32, %arg1: f32) -> i1 { + // CHECK-LABEL: arith_cmpf_ogt + // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) + // CHECK-DAG: [[GT:[^ ]*]] = emitc.cmp gt, [[Arg0]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[OrderedArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1 + // CHECK-DAG: [[OrderedArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[OrderedArg0]], [[OrderedArg1]] : i1, i1 + // CHECK-DAG: [[OGT:[^ ]*]] = emitc.logical_and [[Ordered]], [[GT]] : i1, i1 + %ogt = arith.cmpf ogt, %arg0, %arg1 : f32 + // CHECK: return [[OGT]] + return %ogt: i1 +} + +// ----- + +func.func @arith_cmpf_oge(%arg0: f32, %arg1: f32) -> i1 { + // CHECK-LABEL: arith_cmpf_oge + // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) + // CHECK-DAG: [[GE:[^ ]*]] = emitc.cmp ge, [[Arg0]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1 + // CHECK-DAG: [[OGE:[^ ]*]] = emitc.logical_and [[Ordered]], [[GE]] : i1, i1 + %ueq = arith.cmpf oge, %arg0, %arg1 : f32 + // CHECK: return [[OGE]] + return %ueq: i1 +} + +// ----- + +func.func @arith_cmpf_olt(%arg0: f32, %arg1: f32) -> i1 { + // CHECK-LABEL: arith_cmpf_olt + // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) + // CHECK-DAG: [[LT:[^ ]*]] = emitc.cmp lt, [[Arg0]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1 + // CHECK-DAG: [[UEQ:[^ ]*]] = emitc.logical_and [[Ordered]], [[LT]] : i1, i1 + %ueq = arith.cmpf olt, %arg0, %arg1 : f32 + // CHECK: return [[UEQ]] + return %ueq: i1 +} + +// ----- + +func.func @arith_cmpf_one(%arg0: f32, %arg1: f32) -> i1 { + // CHECK-LABEL: arith_cmpf_one + // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) + // CHECK-DAG: [[NEQ:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1 + // CHECK-DAG: [[ONE:[^ ]*]] = emitc.logical_and [[Ordered]], [[NEQ]] : i1, i1 + %ueq = arith.cmpf one, %arg0, %arg1 : f32 + // CHECK: return [[ONE]] + return %ueq: i1 +} + +// ----- + +func.func @arith_cmpf_ord(%arg0: f32, %arg1: f32) -> i1 { + // CHECK-LABEL: arith_cmpf_ord + // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) + // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1 + %ueq = arith.cmpf ord, %arg0, %arg1 : f32 + // CHECK: return [[Ordered]] + return %ueq: i1 +} + +// ----- + func.func @arith_cmpf_ueq(%arg0: f32, %arg1: f32) -> i1 { // CHECK-LABEL: arith_cmpf_ueq // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) @@ -133,30 +232,41 @@ func.func @arith_cmpf_ule(%arg0: f32, %arg1: f32) -> i1 { // ----- +func.func @arith_cmpf_une(%arg0: f32, %arg1: f32) -> i1 { + // CHECK-LABEL: arith_cmpf_une + // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) + // CHECK-DAG: [[NEQ:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg0]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp ne, [[Arg1]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[Unordered:[^ ]*]] = emitc.logical_or [[NaNArg0]], [[NaNArg1]] : i1, i1 + // CHECK-DAG: [[UNE:[^ ]*]] = emitc.logical_or [[Unordered]], [[NEQ]] : i1, i1 + %une = arith.cmpf une, %arg0, %arg1 : f32 + // CHECK: return [[UNE]] + return %une: i1 +} + +// ----- + func.func @arith_cmpf_uno(%arg0: f32, %arg1: f32) -> i1 { // CHECK-LABEL: arith_cmpf_uno // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg0]] : (f32, f32) -> i1 // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp ne, [[Arg1]], [[Arg1]] : (f32, f32) -> i1 // CHECK-DAG: [[Unordered:[^ ]*]] = emitc.logical_or [[NaNArg0]], [[NaNArg1]] : i1, i1 - %2 = arith.cmpf uno, %arg0, %arg1 : f32 + %uno = arith.cmpf uno, %arg0, %arg1 : f32 // CHECK: return [[Unordered]] - return %2: i1 + return %uno: i1 } // ----- -func.func @arith_cmpf_ogt(%arg0: f32, %arg1: f32) -> i1 { - // CHECK-LABEL: arith_cmpf_ogt +func.func @arith_cmpf_true(%arg0: f32, %arg1: f32) -> i1 { + // CHECK-LABEL: arith_cmpf_true // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) - // CHECK-DAG: [[GT:[^ ]*]] = emitc.cmp gt, [[Arg0]], [[Arg1]] : (f32, f32) -> i1 - // CHECK-DAG: [[OrderedArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1 - // CHECK-DAG: [[OrderedArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1 - // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[OrderedArg0]], [[OrderedArg1]] : i1, i1 - // CHECK-DAG: [[OGT:[^ ]*]] = emitc.logical_and [[Ordered]], [[GT]] : i1, i1 - %ule = arith.cmpf ogt, %arg0, %arg1 : f32 - // CHECK: return [[OGT]] - return %ule: i1 + // CHECK-DAG: [[True:[^ ]*]] = "emitc.constant"() <{value = true}> : () -> i1 + %ueq = arith.cmpf true, %arg0, %arg1 : f32 + // CHECK: return [[True]] + return %ueq: i1 } // -----