From bf83b3f4668eb440b185c059becc73a523668283 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Wed, 27 Mar 2024 08:58:15 +0000 Subject: [PATCH] [mlir][emitc] Arith to EmitC: handle FP<->Integer conversions --- .../Conversion/ArithToEmitC/ArithToEmitC.cpp | 94 ++++++++++++++++++- .../arith-to-emitc-unsupported.mlir | 48 ++++++++++ .../ArithToEmitC/arith-to-emitc.mlir | 36 +++++++ 3 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 9b2544276ce474..195d4d39cbdbe7 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -201,6 +201,94 @@ class SelectOpConversion : public OpConversionPattern { } }; +// Floating-point to integer conversions. +template +class FtoICastOpConversion : public OpConversionPattern { +public: + FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + + LogicalResult + matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Type operandType = adaptor.getIn().getType(); + if (!emitc::isSupportedFloatType(operandType)) + return rewriter.notifyMatchFailure(castOp, + "unsupported cast source type"); + + Type dstType = this->getTypeConverter()->convertType(castOp.getType()); + if (!dstType) + return rewriter.notifyMatchFailure(castOp, "type conversion failed"); + + if (!emitc::isSupportedIntegerType(dstType)) + return rewriter.notifyMatchFailure(castOp, + "unsupported cast destination type"); + + // Convert to unsigned if it's the "ui" variant + // Signless is interpreted as signed, so no need to cast for "si" + Type actualResultType = dstType; + if (isa(castOp)) { + actualResultType = + rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(), + /*isSigned=*/false); + } + + Value result = rewriter.create( + castOp.getLoc(), actualResultType, adaptor.getOperands()); + + if (isa(castOp)) { + result = rewriter.create(castOp.getLoc(), dstType, result); + } + rewriter.replaceOp(castOp, result); + + return success(); + } +}; + +// Integer to floating-point conversions. +template +class ItoFCastOpConversion : public OpConversionPattern { +public: + ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} + + LogicalResult + matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Vectors in particular are not supported + Type operandType = adaptor.getIn().getType(); + if (!emitc::isSupportedIntegerType(operandType)) + return rewriter.notifyMatchFailure(castOp, + "unsupported cast source type"); + + Type dstType = this->getTypeConverter()->convertType(castOp.getType()); + if (!dstType) + return rewriter.notifyMatchFailure(castOp, "type conversion failed"); + + if (!emitc::isSupportedFloatType(dstType)) + return rewriter.notifyMatchFailure(castOp, + "unsupported cast destination type"); + + // Convert to unsigned if it's the "ui" variant + // Signless is interpreted as signed, so no need to cast for "si" + Type actualOperandType = operandType; + if (isa(castOp)) { + actualOperandType = + rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(), + /*isSigned=*/false); + } + Value fpCastOperand = adaptor.getIn(); + if (actualOperandType != operandType) { + fpCastOperand = rewriter.template create( + castOp.getLoc(), actualOperandType, fpCastOperand); + } + rewriter.replaceOpWithNewOp(castOp, dstType, fpCastOperand); + + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -222,7 +310,11 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, IntegerOpConversion, IntegerOpConversion, CmpIOpConversion, - SelectOpConversion + SelectOpConversion, + ItoFCastOpConversion, + ItoFCastOpConversion, + FtoICastOpConversion, + FtoICastOpConversion >(typeConverter, ctx); // clang-format on } diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir new file mode 100644 index 00000000000000..39b56882853a77 --- /dev/null +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir @@ -0,0 +1,48 @@ +// RUN: mlir-opt -split-input-file -convert-arith-to-emitc -verify-diagnostics %s + +func.func @arith_cast_tensor(%arg0: tensor<5xf32>) -> tensor<5xi32> { + // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}} + %t = arith.fptosi %arg0 : tensor<5xf32> to tensor<5xi32> + return %t: tensor<5xi32> +} + +// ----- + +func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> { + // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}} + %t = arith.fptosi %arg0 : vector<5xf32> to vector<5xi32> + return %t: vector<5xi32> +} + +// ----- + +func.func @arith_cast_bf16(%arg0: bf16) -> i32 { + // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}} + %t = arith.fptosi %arg0 : bf16 to i32 + return %t: i32 +} + +// ----- + +func.func @arith_cast_f16(%arg0: f16) -> i32 { + // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}} + %t = arith.fptosi %arg0 : f16 to i32 + return %t: i32 +} + + +// ----- + +func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 { + // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}} + %t = arith.sitofp %arg0 : i32 to bf16 + return %t: bf16 +} + +// ----- + +func.func @arith_cast_to_f16(%arg0: i32) -> f16 { + // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}} + %t = arith.sitofp %arg0 : i32 to f16 + return %t: f16 +} diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index 46b407177b46aa..245cf7e1ddd647 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -141,3 +141,39 @@ func.func @arith_cmpi_predicates(%arg0: i32, %arg1: i32) { return } + +// ----- + +func.func @arith_float_to_int_cast_ops(%arg0: f32, %arg1: f64) { + // CHECK: emitc.cast %arg0 : f32 to i32 + %0 = arith.fptosi %arg0 : f32 to i32 + + // CHECK: emitc.cast %arg1 : f64 to i32 + %1 = arith.fptosi %arg1 : f64 to i32 + + // CHECK: emitc.cast %arg0 : f32 to i16 + %2 = arith.fptosi %arg0 : f32 to i16 + + // CHECK: emitc.cast %arg1 : f64 to i16 + %3 = arith.fptosi %arg1 : f64 to i16 + + // CHECK: %[[CAST0:.*]] = emitc.cast %arg0 : f32 to ui32 + // CHECK: emitc.cast %[[CAST0]] : ui32 to i32 + %4 = arith.fptoui %arg0 : f32 to i32 + + return +} + +func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) { + // CHECK: emitc.cast %arg0 : i8 to f32 + %0 = arith.sitofp %arg0 : i8 to f32 + + // CHECK: emitc.cast %arg1 : i64 to f32 + %1 = arith.sitofp %arg1 : i64 to f32 + + // CHECK: %[[CAST_UNS:.*]] = emitc.cast %arg0 : i8 to ui8 + // CHECK: emitc.cast %[[CAST_UNS]] : ui8 to f32 + %2 = arith.uitofp %arg0 : i8 to f32 + + return +}