From 7a17cb73c383106d4f66fb6787d33cb335ad7096 Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Thu, 11 Apr 2024 14:34:22 +0100 Subject: [PATCH 1/2] Review --- .../Conversion/ArithToEmitC/ArithToEmitC.cpp | 55 ++++++- .../arith-to-emitc-unsupported.mlir | 8 -- .../ArithToEmitC/arith-to-emitc.mlir | 134 ++++++++++++++++-- 3 files changed, 172 insertions(+), 25 deletions(-) diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index f453366c53c907..a58a9645401a35 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/APInt.h" using namespace mlir; @@ -79,6 +80,13 @@ Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc, firstIsNaN, secondIsNaN); } +emitc::ConstantOp getConstant(OpBuilder rewriter, Location loc, + llvm::APInt val) { + auto type = rewriter.getIntegerType(val.getBitWidth()); + return rewriter.create(loc, type, + rewriter.getIntegerAttr(type, val)); +} + class CmpFOpConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -96,6 +104,39 @@ class CmpFOpConversion : public OpConversionPattern { bool unordered = false; emitc::CmpPredicate predicate; switch (op.getPredicate()) { + case arith::CmpFPredicate::AlwaysFalse: { + auto constant = getConstant(rewriter, op->getLoc(), llvm::APInt(1, 0)); + rewriter.replaceOp(op, constant); + return success(); + } + case arith::CmpFPredicate::OEQ: + unordered = false; + predicate = emitc::CmpPredicate::eq; + break; + case arith::CmpFPredicate::OGT: + // ordered and greater than + unordered = false; + predicate = emitc::CmpPredicate::gt; + break; + case arith::CmpFPredicate::OGE: + unordered = false; + predicate = emitc::CmpPredicate::ge; + break; + case arith::CmpFPredicate::OLT: + unordered = false; + predicate = emitc::CmpPredicate::lt; + break; + case arith::CmpFPredicate::ONE: + unordered = false; + predicate = emitc::CmpPredicate::ne; + break; + case arith::CmpFPredicate::ORD: { + // ordered, i.e. none of the operands is NaN + auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(), + adaptor.getRhs()); + rewriter.replaceOp(op, cmp); + return success(); + } case arith::CmpFPredicate::UEQ: // unordered or equal unordered = true; @@ -121,6 +162,10 @@ class CmpFOpConversion : public OpConversionPattern { unordered = true; predicate = emitc::CmpPredicate::le; break; + case arith::CmpFPredicate::UNE: + unordered = true; + predicate = emitc::CmpPredicate::ne; + break; case arith::CmpFPredicate::UNO: { // unordered, i.e. either operand is nan auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(), @@ -128,11 +173,11 @@ class CmpFOpConversion : public OpConversionPattern { rewriter.replaceOp(op, cmp); return success(); } - case arith::CmpFPredicate::OGT: - // ordered and greater than - unordered = false; - predicate = emitc::CmpPredicate::gt; - break; + case arith::CmpFPredicate::AlwaysTrue: { + auto constant = getConstant(rewriter, op->getLoc(), llvm::APInt(1, 1)); + rewriter.replaceOp(op, constant); + return success(); + } default: return rewriter.notifyMatchFailure(op.getLoc(), "cannot match predicate "); diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir index 123d3cf0fa77d6..37427ef70cacc2 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir @@ -16,14 +16,6 @@ func.func @arith_cmpf_vector(%arg0: vector<5xf32>, %arg1: vector<5xf32>) -> vect // ----- -func.func @arith_cmpf_ordered(%arg0: f32, %arg1: f32) -> i1 { - // expected-error @+1 {{failed to legalize operation 'arith.cmpf'}} - %oge = arith.cmpf oge, %arg0, %arg1 : f32 - return %oge: i1 -} - -// ----- - func.func @arith_cast_f32(%arg0: f32) -> i32 { // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}} %t = arith.fptosi %arg0 : f32 to i32 diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index 75ca89407fe5fd..9fdbeeaa21c5c0 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -58,6 +58,105 @@ func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) - // ----- +func.func @arith_cmpf_false(%arg0: f32, %arg1: f32) -> i1 { + // CHECK-LABEL: arith_cmpf_false + // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) + // CHECK-DAG: [[False:[^ ]*]] = "emitc.constant"() <{value = false}> : () -> i1 + %ueq = arith.cmpf false, %arg0, %arg1 : f32 + // CHECK: return [[False]] + return %ueq: i1 +} + +// ----- + +func.func @arith_cmpf_oeq(%arg0: f32, %arg1: f32) -> i1 { + // CHECK-LABEL: arith_cmpf_oeq + // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) + // CHECK-DAG: [[EQ:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1 + // CHECK-DAG: [[OEQ:[^ ]*]] = emitc.logical_and [[Ordered]], [[EQ]] : i1, i1 + %ueq = arith.cmpf oeq, %arg0, %arg1 : f32 + // CHECK: return [[OEQ]] + return %ueq: i1 +} + +// ----- + +func.func @arith_cmpf_ogt(%arg0: f32, %arg1: f32) -> i1 { + // CHECK-LABEL: arith_cmpf_ogt + // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) + // CHECK-DAG: [[GT:[^ ]*]] = emitc.cmp gt, [[Arg0]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[OrderedArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1 + // CHECK-DAG: [[OrderedArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[OrderedArg0]], [[OrderedArg1]] : i1, i1 + // CHECK-DAG: [[OGT:[^ ]*]] = emitc.logical_and [[Ordered]], [[GT]] : i1, i1 + %ogt = arith.cmpf ogt, %arg0, %arg1 : f32 + // CHECK: return [[OGT]] + return %ogt: i1 +} + +// ----- + +func.func @arith_cmpf_oge(%arg0: f32, %arg1: f32) -> i1 { + // CHECK-LABEL: arith_cmpf_oge + // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) + // CHECK-DAG: [[GE:[^ ]*]] = emitc.cmp ge, [[Arg0]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1 + // CHECK-DAG: [[OGE:[^ ]*]] = emitc.logical_and [[Ordered]], [[GE]] : i1, i1 + %ueq = arith.cmpf oge, %arg0, %arg1 : f32 + // CHECK: return [[OGE]] + return %ueq: i1 +} + +// ----- + +func.func @arith_cmpf_olt(%arg0: f32, %arg1: f32) -> i1 { + // CHECK-LABEL: arith_cmpf_olt + // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) + // CHECK-DAG: [[LT:[^ ]*]] = emitc.cmp lt, [[Arg0]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1 + // CHECK-DAG: [[UEQ:[^ ]*]] = emitc.logical_and [[Ordered]], [[LT]] : i1, i1 + %ueq = arith.cmpf olt, %arg0, %arg1 : f32 + // CHECK: return [[UEQ]] + return %ueq: i1 +} + +// ----- + +func.func @arith_cmpf_one(%arg0: f32, %arg1: f32) -> i1 { + // CHECK-LABEL: arith_cmpf_one + // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) + // CHECK-DAG: [[NEQ:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1 + // CHECK-DAG: [[ONE:[^ ]*]] = emitc.logical_and [[Ordered]], [[NEQ]] : i1, i1 + %ueq = arith.cmpf one, %arg0, %arg1 : f32 + // CHECK: return [[ONE]] + return %ueq: i1 +} + +// ----- + +func.func @arith_cmpf_ord(%arg0: f32, %arg1: f32) -> i1 { + // CHECK-LABEL: arith_cmpf_ord + // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) + // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[NaNArg0]], [[NaNArg1]] : i1, i1 + %ueq = arith.cmpf ord, %arg0, %arg1 : f32 + // CHECK: return [[Ordered]] + return %ueq: i1 +} + +// ----- + func.func @arith_cmpf_ueq(%arg0: f32, %arg1: f32) -> i1 { // CHECK-LABEL: arith_cmpf_ueq // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) @@ -133,30 +232,41 @@ func.func @arith_cmpf_ule(%arg0: f32, %arg1: f32) -> i1 { // ----- +func.func @arith_cmpf_une(%arg0: f32, %arg1: f32) -> i1 { + // CHECK-LABEL: arith_cmpf_une + // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) + // CHECK-DAG: [[NEQ:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg0]] : (f32, f32) -> i1 + // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp ne, [[Arg1]], [[Arg1]] : (f32, f32) -> i1 + // CHECK-DAG: [[Unordered:[^ ]*]] = emitc.logical_or [[NaNArg0]], [[NaNArg1]] : i1, i1 + // CHECK-DAG: [[UNE:[^ ]*]] = emitc.logical_or [[Unordered]], [[NEQ]] : i1, i1 + %une = arith.cmpf une, %arg0, %arg1 : f32 + // CHECK: return [[UNE]] + return %une: i1 +} + +// ----- + func.func @arith_cmpf_uno(%arg0: f32, %arg1: f32) -> i1 { // CHECK-LABEL: arith_cmpf_uno // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) // CHECK-DAG: [[NaNArg0:[^ ]*]] = emitc.cmp ne, [[Arg0]], [[Arg0]] : (f32, f32) -> i1 // CHECK-DAG: [[NaNArg1:[^ ]*]] = emitc.cmp ne, [[Arg1]], [[Arg1]] : (f32, f32) -> i1 // CHECK-DAG: [[Unordered:[^ ]*]] = emitc.logical_or [[NaNArg0]], [[NaNArg1]] : i1, i1 - %2 = arith.cmpf uno, %arg0, %arg1 : f32 + %uno = arith.cmpf uno, %arg0, %arg1 : f32 // CHECK: return [[Unordered]] - return %2: i1 + return %uno: i1 } // ----- -func.func @arith_cmpf_ogt(%arg0: f32, %arg1: f32) -> i1 { - // CHECK-LABEL: arith_cmpf_ogt +func.func @arith_cmpf_true(%arg0: f32, %arg1: f32) -> i1 { + // CHECK-LABEL: arith_cmpf_true // CHECK-SAME: ([[Arg0:[^ ]*]]: f32, [[Arg1:[^ ]*]]: f32) - // CHECK-DAG: [[GT:[^ ]*]] = emitc.cmp gt, [[Arg0]], [[Arg1]] : (f32, f32) -> i1 - // CHECK-DAG: [[OrderedArg0:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg0]] : (f32, f32) -> i1 - // CHECK-DAG: [[OrderedArg1:[^ ]*]] = emitc.cmp eq, [[Arg1]], [[Arg1]] : (f32, f32) -> i1 - // CHECK-DAG: [[Ordered:[^ ]*]] = emitc.logical_and [[OrderedArg0]], [[OrderedArg1]] : i1, i1 - // CHECK-DAG: [[OGT:[^ ]*]] = emitc.logical_and [[Ordered]], [[GT]] : i1, i1 - %ule = arith.cmpf ogt, %arg0, %arg1 : f32 - // CHECK: return [[OGT]] - return %ule: i1 + // CHECK-DAG: [[True:[^ ]*]] = "emitc.constant"() <{value = true}> : () -> i1 + %ueq = arith.cmpf true, %arg0, %arg1 : f32 + // CHECK: return [[True]] + return %ueq: i1 } // ----- From 3c8dc440d7114bf62769b2aaac2c5ad2da083ada Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Mon, 15 Apr 2024 09:57:07 +0100 Subject: [PATCH 2/2] Modify CmpFOpConversion Restrict "cmpf [true|false]" code to only allow returning i1 and nothing else. --- .../lib/Conversion/ArithToEmitC/ArithToEmitC.cpp | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index a58a9645401a35..ba279aee16fb7b 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -19,7 +19,6 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/APInt.h" using namespace mlir; @@ -80,13 +79,6 @@ Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc, firstIsNaN, secondIsNaN); } -emitc::ConstantOp getConstant(OpBuilder rewriter, Location loc, - llvm::APInt val) { - auto type = rewriter.getIntegerType(val.getBitWidth()); - return rewriter.create(loc, type, - rewriter.getIntegerAttr(type, val)); -} - class CmpFOpConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -105,7 +97,9 @@ class CmpFOpConversion : public OpConversionPattern { emitc::CmpPredicate predicate; switch (op.getPredicate()) { case arith::CmpFPredicate::AlwaysFalse: { - auto constant = getConstant(rewriter, op->getLoc(), llvm::APInt(1, 0)); + auto constant = rewriter.create( + op.getLoc(), rewriter.getI1Type(), + rewriter.getBoolAttr(/*value=*/false)); rewriter.replaceOp(op, constant); return success(); } @@ -174,7 +168,9 @@ class CmpFOpConversion : public OpConversionPattern { return success(); } case arith::CmpFPredicate::AlwaysTrue: { - auto constant = getConstant(rewriter, op->getLoc(), llvm::APInt(1, 1)); + auto constant = rewriter.create( + op.getLoc(), rewriter.getI1Type(), + rewriter.getBoolAttr(/*value=*/true)); rewriter.replaceOp(op, constant); return success(); }