From 3c8dc440d7114bf62769b2aaac2c5ad2da083ada Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Mon, 15 Apr 2024 09:57:07 +0100 Subject: [PATCH] Modify CmpFOpConversion Restrict "cmpf [true|false]" code to only allow returning i1 and nothing else. --- .../lib/Conversion/ArithToEmitC/ArithToEmitC.cpp | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index a58a9645401a353..ba279aee16fb7bf 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -19,7 +19,6 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/APInt.h" using namespace mlir; @@ -80,13 +79,6 @@ Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc, firstIsNaN, secondIsNaN); } -emitc::ConstantOp getConstant(OpBuilder rewriter, Location loc, - llvm::APInt val) { - auto type = rewriter.getIntegerType(val.getBitWidth()); - return rewriter.create(loc, type, - rewriter.getIntegerAttr(type, val)); -} - class CmpFOpConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -105,7 +97,9 @@ class CmpFOpConversion : public OpConversionPattern { emitc::CmpPredicate predicate; switch (op.getPredicate()) { case arith::CmpFPredicate::AlwaysFalse: { - auto constant = getConstant(rewriter, op->getLoc(), llvm::APInt(1, 0)); + auto constant = rewriter.create( + op.getLoc(), rewriter.getI1Type(), + rewriter.getBoolAttr(/*value=*/false)); rewriter.replaceOp(op, constant); return success(); } @@ -174,7 +168,9 @@ class CmpFOpConversion : public OpConversionPattern { return success(); } case arith::CmpFPredicate::AlwaysTrue: { - auto constant = getConstant(rewriter, op->getLoc(), llvm::APInt(1, 1)); + auto constant = rewriter.create( + op.getLoc(), rewriter.getI1Type(), + rewriter.getBoolAttr(/*value=*/true)); rewriter.replaceOp(op, constant); return success(); }