diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h index 32d039e9c89185f..9cb43689d1ce64d 100644 --- a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h +++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h @@ -14,8 +14,7 @@ class RewritePatternSet; class TypeConverter; void populateArithToEmitCPatterns(TypeConverter &typeConverter, - RewritePatternSet &patterns, - bool optionFloatToIntTruncates); + RewritePatternSet &patterns); } // namespace mlir #endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 7e0cf63a9805ed4..c01c20e747f03ae 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -139,24 +139,7 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> { def ConvertArithToEmitC : Pass<"convert-arith-to-emitc"> { let summary = "Convert Arith dialect to EmitC dialect"; - let description = [{ - This pass converts `arith` dialect operations to `emitc`. - - The semantics of floating-point to integer conversions `arith.fptosi`, - `arith.fptoui` require rounding towards zero. Typical C++ implementations - use this behavior for float-to-integer casts, but that is not mandated by - C++ and there are implementation-defined means to change the default behavior. - - If casts can be guaranteed to use round-to-zero, use the - `float-to-int-truncates` flag to allow conversion of `arith.fptosi` and - `arith.fptoui` operations. - }]; let dependentDialects = ["emitc::EmitCDialect"]; - let options = [ - Option<"floatToIntTruncates", "float-to-int-truncates", "bool", - /*default=*/"false", - "Whether the behavior of float-to-int cast in emitc is truncation">, - ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index cb2d02ca1408781..0337314ce7f3487 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -366,14 +366,9 @@ class SelectOpConversion : public OpConversionPattern { // Floating-point to integer conversions. template class FtoICastOpConversion : public OpConversionPattern { -private: - bool floatToIntTruncates; - public: - FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context, - bool optionFloatToIntTruncates) - : OpConversionPattern(typeConverter, context), - floatToIntTruncates(optionFloatToIntTruncates) {} + FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} LogicalResult matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor, @@ -384,16 +379,13 @@ class FtoICastOpConversion : public OpConversionPattern { return rewriter.notifyMatchFailure(castOp, "unsupported cast source type"); - if (!floatToIntTruncates) - return rewriter.notifyMatchFailure( - castOp, "conversion currently requires EmitC casts to use truncation " - "as rounding mode"); - Type dstType = this->getTypeConverter()->convertType(castOp.getType()); if (!dstType) return rewriter.notifyMatchFailure(castOp, "type conversion failed"); - if (!emitc::isSupportedIntegerType(dstType)) + // Float-to-i1 casts are not supported: any value with 0 < value < 1 must be + // truncated to 0, whereas a boolean conversion would return true. + if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1)) return rewriter.notifyMatchFailure(castOp, "unsupported cast destination type"); @@ -468,8 +460,7 @@ class ItoFCastOpConversion : public OpConversionPattern { //===----------------------------------------------------------------------===// void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, - RewritePatternSet &patterns, - bool optionFloatToIntTruncates) { + RewritePatternSet &patterns) { MLIRContext *ctx = patterns.getContext(); // clang-format off @@ -488,11 +479,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, CmpIOpConversion, SelectOpConversion, ItoFCastOpConversion, - ItoFCastOpConversion - >(typeConverter, ctx) - .add< + ItoFCastOpConversion, FtoICastOpConversion, FtoICastOpConversion - >(typeConverter, ctx, optionFloatToIntTruncates); + >(typeConverter, ctx); // clang-format on } diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp index 10ea0648823a20e..76e7707ce7109e9 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp @@ -45,7 +45,7 @@ void ConvertArithToEmitC::runOnOperation() { TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); - populateArithToEmitCPatterns(typeConverter, patterns, floatToIntTruncates); + populateArithToEmitCPatterns(typeConverter, patterns); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir deleted file mode 100644 index 26f9261183144ec..000000000000000 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir +++ /dev/null @@ -1,21 +0,0 @@ -// RUN: mlir-opt -split-input-file --pass-pipeline="builtin.module(convert-arith-to-emitc{float-to-int-truncates})" %s | FileCheck %s - -func.func @arith_float_to_int_cast_ops(%arg0: f32, %arg1: f64) { - // CHECK: emitc.cast %arg0 : f32 to i32 - %0 = arith.fptosi %arg0 : f32 to i32 - - // CHECK: emitc.cast %arg1 : f64 to i32 - %1 = arith.fptosi %arg1 : f64 to i32 - - // CHECK: emitc.cast %arg0 : f32 to i16 - %2 = arith.fptosi %arg0 : f32 to i16 - - // CHECK: emitc.cast %arg1 : f64 to i16 - %3 = arith.fptosi %arg1 : f64 to i16 - - // CHECK: %[[CAST0:.*]] = emitc.cast %arg0 : f32 to ui32 - // CHECK: emitc.cast %[[CAST0]] : ui32 to i32 - %4 = arith.fptoui %arg0 : f32 to i32 - - return -} diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-unsupported.mlir deleted file mode 100644 index 34fc9f3dffc0c81..000000000000000 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-unsupported.mlir +++ /dev/null @@ -1,48 +0,0 @@ -// RUN: mlir-opt -split-input-file --pass-pipeline="builtin.module(convert-arith-to-emitc{float-to-int-truncates})" -verify-diagnostics %s - -func.func @arith_cast_tensor(%arg0: tensor<5xf32>) -> tensor<5xi32> { - // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}} - %t = arith.fptosi %arg0 : tensor<5xf32> to tensor<5xi32> - return %t: tensor<5xi32> -} - -// ----- - -func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> { - // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}} - %t = arith.fptosi %arg0 : vector<5xf32> to vector<5xi32> - return %t: vector<5xi32> -} - -// ----- - -func.func @arith_cast_bf16(%arg0: bf16) -> i32 { - // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}} - %t = arith.fptosi %arg0 : bf16 to i32 - return %t: i32 -} - -// ----- - -func.func @arith_cast_f16(%arg0: f16) -> i32 { - // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}} - %t = arith.fptosi %arg0 : f16 to i32 - return %t: i32 -} - - -// ----- - -func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 { - // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}} - %t = arith.sitofp %arg0 : i32 to bf16 - return %t: bf16 -} - -// ----- - -func.func @arith_cast_to_f16(%arg0: i32) -> f16 { - // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}} - %t = arith.sitofp %arg0 : i32 to f16 - return %t: f16 -} diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir index 37427ef70cacc2e..32c0c0381d326a2 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir @@ -16,8 +16,66 @@ func.func @arith_cmpf_vector(%arg0: vector<5xf32>, %arg1: vector<5xf32>) -> vect // ----- -func.func @arith_cast_f32(%arg0: f32) -> i32 { +func.func @arith_cast_tensor(%arg0: tensor<5xf32>) -> tensor<5xi32> { // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}} - %t = arith.fptosi %arg0 : f32 to i32 + %t = arith.fptosi %arg0 : tensor<5xf32> to tensor<5xi32> + return %t: tensor<5xi32> +} + +// ----- + +func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> { + // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}} + %t = arith.fptosi %arg0 : vector<5xf32> to vector<5xi32> + return %t: vector<5xi32> +} + +// ----- + +func.func @arith_cast_bf16(%arg0: bf16) -> i32 { + // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}} + %t = arith.fptosi %arg0 : bf16 to i32 return %t: i32 } + +// ----- + +func.func @arith_cast_f16(%arg0: f16) -> i32 { + // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}} + %t = arith.fptosi %arg0 : f16 to i32 + return %t: i32 +} + + +// ----- + +func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 { + // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}} + %t = arith.sitofp %arg0 : i32 to bf16 + return %t: bf16 +} + +// ----- + +func.func @arith_cast_to_f16(%arg0: i32) -> f16 { + // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}} + %t = arith.sitofp %arg0 : i32 to f16 + return %t: f16 +} + +// ----- + +func.func @arith_cast_fptosi_i1(%arg0: f32) -> i1 { + // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}} + %t = arith.fptosi %arg0 : f32 to i1 + return %t: i1 +} + +// ----- + +func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 { + // expected-error @+1 {{failed to legalize operation 'arith.fptoui'}} + %t = arith.fptoui %arg0 : f32 to i1 + return %t: i1 +} + diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir index 43f03111fc7b5bc..ed63d4080897325 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir @@ -309,22 +309,6 @@ func.func @arith_cmpf_true(%arg0: f32, %arg1: f32) -> i1 { // ----- -func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) { - // CHECK: emitc.cast %arg0 : i8 to f32 - %0 = arith.sitofp %arg0 : i8 to f32 - - // CHECK: emitc.cast %arg1 : i64 to f32 - %1 = arith.sitofp %arg1 : i64 to f32 - - // CHECK: %[[CAST_UNS:.*]] = emitc.cast %arg0 : i8 to ui8 - // CHECK: emitc.cast %[[CAST_UNS]] : ui8 to f32 - %2 = arith.uitofp %arg0 : i8 to f32 - - return -} - -// ----- - func.func @arith_cmpi_eq(%arg0: i32, %arg1: i32) -> i1 { // CHECK-LABEL: arith_cmpi_eq // CHECK-SAME: ([[Arg0:[^ ]*]]: i32, [[Arg1:[^ ]*]]: i32) @@ -370,3 +354,39 @@ func.func @arith_cmpi_predicates(%arg0: i32, %arg1: i32) { return } + +// ----- + +func.func @arith_float_to_int_cast_ops(%arg0: f32, %arg1: f64) { + // CHECK: emitc.cast %arg0 : f32 to i32 + %0 = arith.fptosi %arg0 : f32 to i32 + + // CHECK: emitc.cast %arg1 : f64 to i32 + %1 = arith.fptosi %arg1 : f64 to i32 + + // CHECK: emitc.cast %arg0 : f32 to i16 + %2 = arith.fptosi %arg0 : f32 to i16 + + // CHECK: emitc.cast %arg1 : f64 to i16 + %3 = arith.fptosi %arg1 : f64 to i16 + + // CHECK: %[[CAST0:.*]] = emitc.cast %arg0 : f32 to ui32 + // CHECK: emitc.cast %[[CAST0]] : ui32 to i32 + %4 = arith.fptoui %arg0 : f32 to i32 + + return +} + +func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) { + // CHECK: emitc.cast %arg0 : i8 to f32 + %0 = arith.sitofp %arg0 : i8 to f32 + + // CHECK: emitc.cast %arg1 : i64 to f32 + %1 = arith.sitofp %arg1 : i64 to f32 + + // CHECK: %[[CAST_UNS:.*]] = emitc.cast %arg0 : i8 to ui8 + // CHECK: emitc.cast %[[CAST_UNS]] : ui8 to f32 + %2 = arith.uitofp %arg0 : i8 to f32 + + return +}