From b91ce7bc0eda8ec91c7ba724fbe78607843e0e06 Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Fri, 25 Oct 2024 05:51:09 -0400 Subject: [PATCH] [Tosa] : Fix integer overflow for computing intmax+1 in tosa.cast to linalg. (#112455) This PR fixes an issue related to integer overflow when computing `(intmax+1)` for `i64` during `tosa-to-linalg` pass for `tosa.cast`. Found this issue while debugging a numerical mismatch for `deeplabv3` model from `torchvision` represented in `tosa` dialect using the `TorchToTosa` pipeline in `torch-mlir` repository. `torch.aten.to.dtype` is converted to `tosa.cast` that casts `f32` to `i64` type. Technically by the specification, `tosa.cast` doesn't handle casting `f32` to `i64`. So it's possible to add a verifier to error out for such tosa ops instead of producing incorrect code. However, I chose to fix the overflow issue to still be able to represent the `deeplabv3` model with `tosa` ops in the above-mentioned pipeline. Open to suggestions if adding the verifier is more appropriate instead. --- .../Conversion/TosaToLinalg/TosaToLinalg.cpp | 7 ++--- .../TosaToLinalg/tosa-to-linalg.mlir | 27 +++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 495f1b4f10b0286..251c48859d2de04 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -555,9 +555,10 @@ static Value createLinalgBodyCalculationForElementwiseOp( auto intMaxPlusOneFP = rewriter.create( loc, rewriter.getFloatAttr( getElementTypeOrSelf(srcTy), - APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) - .getSExtValue() + - 1)); + static_cast( + APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) + .getSExtValue()) + + 1.0f)); auto intMax = rewriter.create( loc, rewriter.getIntegerAttr( diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index f9d37f9427d4f44..1a29d3f9f3507c0 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1929,3 +1929,30 @@ func.func @test_dynamic_fft2d(%arg0: tensor, %arg1: tensor %output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = true} : (tensor, tensor) -> (tensor, tensor) return %output_real, %output_imag : tensor, tensor } + + +// ----- + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: func.func @test_cast_fp32_i64( +// CHECK-SAME: %[[ARG0:.*]]: tensor<1xf32>) -> tensor<1xi64> { +// CHECK: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<1xi64> +// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<1xi64>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: i64): +// CHECK: %[[ROUND_EVEN:.*]] = math.roundeven %[[IN]] : f32 +// CHECK: %[[FP_INT_MIN:.*]] = arith.constant -9.22337203E+18 : f32 +// CHECK: %[[FP_INT_MAX_PLUS_ONE:.*]] = arith.constant 9.22337203E+18 : f32 +// CHECK: %[[INT_MAX:.*]] = arith.constant 9223372036854775807 : i64 +// CHECK: %[[MAX:.*]] = arith.maximumf %[[ROUND_EVEN]], %[[FP_INT_MIN]] : f32 +// CHECK: %[[FPTOSI:.*]] = arith.fptosi %[[MAX]] : f32 to i64 +// CHECK: %[[CMPF:.*]] = arith.cmpf uge, %[[ROUND_EVEN]], %[[FP_INT_MAX_PLUS_ONE]] : f32 +// CHECK: %[[SELECT:.*]] = arith.select %[[CMPF]], %[[INT_MAX]], %[[FPTOSI]] : i64 +// CHECK: linalg.yield %[[SELECT]] : i64 +// CHECK: } -> tensor<1xi64> +// CHECK: return %[[RESULT]] : tensor<1xi64> +func.func @test_cast_fp32_i64(%arg0: tensor<1xf32>) -> (tensor<1xi64>) { + %0 = tosa.cast %arg0 : (tensor<1xf32>) -> tensor<1xi64> + return %0: tensor<1xi64> +}