From 7630379156ec08c9d7b1ea3c03c09e7dc89ef4ee Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Wed, 22 May 2024 16:33:37 +0200 Subject: [PATCH] [mlir][emitc] Add EmitC lowering for arith.trunci, arith.extsi, arith.extui This commit adds conversion to EmitC for arith dialect casts between integer types (trunc, extsi, extui), excluding indexes for now. --- .../Conversion/ArithToEmitC/ArithToEmitC.cpp | 92 +++++++++++++++++++ .../arith-to-emitc-unsupported.mlir | 7 ++ .../ArithToEmitC/arith-to-emitc.mlir | 63 +++++++++++++ 3 files changed, 162 insertions(+) diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 1447b182ccfdbc..0be3d76f556de9 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Tools/PDLL/AST/Types.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; @@ -112,6 +113,93 @@ class CmpIOpConversion : public OpConversionPattern { } }; +template +class CastConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Type opReturnType = this->getTypeConverter()->convertType(op.getType()); + if (!isa_and_nonnull(opReturnType)) + return rewriter.notifyMatchFailure(op, "expected integer result type"); + + if (adaptor.getOperands().size() != 1) { + return rewriter.notifyMatchFailure( + op, "CastConversion only supports unary ops"); + } + + Type operandType = adaptor.getIn().getType(); + if (!isa_and_nonnull(operandType)) + return rewriter.notifyMatchFailure(op, "expected integer operand type"); + + // Signed (sign-extending) casts from i1 are not supported. + if (operandType.isInteger(1) && !castToUnsigned) + return rewriter.notifyMatchFailure(op, + "operation not supported on i1 type"); + + // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is + // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives + // truncation. + if (opReturnType.isInteger(1)) { + auto constOne = rewriter.create( + op.getLoc(), operandType, rewriter.getIntegerAttr(operandType, 1)); + auto oneAndOperand = rewriter.create( + op.getLoc(), operandType, adaptor.getIn(), constOne); + rewriter.replaceOpWithNewOp(op, opReturnType, + oneAndOperand); + return success(); + } + + bool isTruncation = operandType.getIntOrFloatBitWidth() > + opReturnType.getIntOrFloatBitWidth(); + bool doUnsigned = castToUnsigned || isTruncation; + + Type castType = opReturnType; + // If the op is a ui variant and the type wanted as + // return type isn't unsigned, we need to issue an unsigned type to do + // the conversion. + if (castType.isUnsignedInteger() != doUnsigned) { + castType = rewriter.getIntegerType(opReturnType.getIntOrFloatBitWidth(), + /*isSigned=*/!doUnsigned); + } + + Value actualOp = adaptor.getIn(); + // Adapt the signedness of the operand if necessary + if (operandType.isUnsignedInteger() != doUnsigned) { + Type correctSignednessType = + rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(), + /*isSigned=*/!doUnsigned); + actualOp = rewriter.template create( + op.getLoc(), correctSignednessType, actualOp); + } + + auto result = rewriter.template create(op.getLoc(), castType, + actualOp); + + // Cast to the expected output type + if (castType != opReturnType) { + result = rewriter.template create(op.getLoc(), + opReturnType, result); + } + + rewriter.replaceOp(op, result); + return success(); + } +}; + +template +class UnsignedCastConversion : public CastConversion { + using CastConversion::CastConversion; +}; + +template +class SignedCastConversion : public CastConversion { + using CastConversion::CastConversion; +}; + template class ArithOpConversion final : public OpConversionPattern { public: @@ -313,6 +401,10 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, IntegerOpConversion, CmpIOpConversion, SelectOpConversion, + // Truncation is guaranteed for unsigned types. + UnsignedCastConversion, + SignedCastConversion, + UnsignedCastConversion, ItoFCastOpConversion, ItoFCastOpConversion, FtoICastOpConversion, diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir index 66dfa8fa3e157e..97e4593f97b903 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir @@ -63,3 +63,10 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 { return %t: i1 } +// ----- + +func.func @arith_extsi_i1_to_i32(%arg0: i1) { + // expected-error @+1 {{failed to legalize operation 'arith.extsi'}} + %idx = arith.extsi %arg0 : i1 to i32 + return +} diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index 79fecd61494d0d..b453b69a214e86 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -177,3 +177,66 @@ func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) { return } + +// ----- + +func.func @arith_trunci(%arg0: i32) -> i8 { + // CHECK-LABEL: arith_trunci + // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32) + // CHECK: %[[CastUI:.*]] = emitc.cast %[[Arg0]] : i32 to ui32 + // CHECK: %[[Trunc:.*]] = emitc.cast %[[CastUI]] : ui32 to ui8 + // CHECK: emitc.cast %[[Trunc]] : ui8 to i8 + %truncd = arith.trunci %arg0 : i32 to i8 + + return %truncd : i8 +} + +// ----- + +func.func @arith_trunci_to_i1(%arg0: i32) -> i1 { + // CHECK-LABEL: arith_trunci_to_i1 + // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32) + // CHECK: %[[Const:.*]] = "emitc.constant" + // CHECK-SAME: value = 1 + // CHECK: %[[And:.*]] = emitc.bitwise_and %[[Arg0]], %[[Const]] : (i32, i32) -> i32 + // CHECK: emitc.cast %[[And]] : i32 to i1 + %truncd = arith.trunci %arg0 : i32 to i1 + + return %truncd : i1 +} + +// ----- + +func.func @arith_extsi(%arg0: i32) { + // CHECK-LABEL: arith_extsi + // CHECK-SAME: ([[Arg0:[^ ]*]]: i32) + // CHECK: emitc.cast [[Arg0]] : i32 to i64 + %extd = arith.extsi %arg0 : i32 to i64 + + return +} + +// ----- + +func.func @arith_extui(%arg0: i32) { + // CHECK-LABEL: arith_extui + // CHECK-SAME: (%[[Arg0:[^ ]*]]: i32) + // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i32 to ui32 + // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui32 to ui64 + // CHECK: emitc.cast %[[Conv1]] : ui64 to i64 + %extd = arith.extui %arg0 : i32 to i64 + + return +} + +// ----- + +func.func @arith_extui_i1_to_i32(%arg0: i1) { + // CHECK-LABEL: arith_extui_i1_to_i32 + // CHECK-SAME: (%[[Arg0:[^ ]*]]: i1) + // CHECK: %[[Conv0:.*]] = emitc.cast %[[Arg0]] : i1 to ui1 + // CHECK: %[[Conv1:.*]] = emitc.cast %[[Conv0]] : ui1 to ui32 + // CHECK: emitc.cast %[[Conv1]] : ui32 to i32 + %idx = arith.extui %arg0 : i1 to i32 + return +}