Skip to content

Commit

Permalink
[Tosa] : Fix integer overflow for computing intmax+1 in tosa.cast to …
Browse files Browse the repository at this point in the history
…linalg. (llvm#112455)

This PR fixes an issue related to integer overflow when computing
`(intmax+1)` for `i64` during `tosa-to-linalg` pass for `tosa.cast`.

Found this issue while debugging a numerical mismatch for `deeplabv3`
model from `torchvision` represented in `tosa` dialect using the
`TorchToTosa` pipeline in `torch-mlir` repository. `torch.aten.to.dtype`
is converted to `tosa.cast` that casts `f32` to `i64` type. Technically
by the specification, `tosa.cast` doesn't handle casting `f32` to `i64`.
So it's possible to add a verifier to error out for such tosa ops
instead of producing incorrect code. However, I chose to fix the
overflow issue to still be able to represent the `deeplabv3` model with
`tosa` ops in the above-mentioned pipeline. Open to suggestions if
adding the verifier is more appropriate instead.
  • Loading branch information
sahas3 authored Oct 25, 2024
1 parent 2e43a30 commit b91ce7b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
7 changes: 4 additions & 3 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -555,9 +555,10 @@ static Value createLinalgBodyCalculationForElementwiseOp(
auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(
getElementTypeOrSelf(srcTy),
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
.getSExtValue() +
1));
static_cast<double>(
APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
.getSExtValue()) +
1.0f));

auto intMax = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(
Expand Down
27 changes: 27 additions & 0 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1929,3 +1929,30 @@ func.func @test_dynamic_fft2d(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>
%output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = true} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
return %output_real, %output_imag : tensor<?x?x?xf32>, tensor<?x?x?xf32>
}


// -----

// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (0)>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0)>

// CHECK-LABEL: func.func @test_cast_fp32_i64(
// CHECK-SAME: %[[ARG0:.*]]: tensor<1xf32>) -> tensor<1xi64> {
// CHECK: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<1xi64>
// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<1xi64>) {
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: i64):
// CHECK: %[[ROUND_EVEN:.*]] = math.roundeven %[[IN]] : f32
// CHECK: %[[FP_INT_MIN:.*]] = arith.constant -9.22337203E+18 : f32
// CHECK: %[[FP_INT_MAX_PLUS_ONE:.*]] = arith.constant 9.22337203E+18 : f32
// CHECK: %[[INT_MAX:.*]] = arith.constant 9223372036854775807 : i64
// CHECK: %[[MAX:.*]] = arith.maximumf %[[ROUND_EVEN]], %[[FP_INT_MIN]] : f32
// CHECK: %[[FPTOSI:.*]] = arith.fptosi %[[MAX]] : f32 to i64
// CHECK: %[[CMPF:.*]] = arith.cmpf uge, %[[ROUND_EVEN]], %[[FP_INT_MAX_PLUS_ONE]] : f32
// CHECK: %[[SELECT:.*]] = arith.select %[[CMPF]], %[[INT_MAX]], %[[FPTOSI]] : i64
// CHECK: linalg.yield %[[SELECT]] : i64
// CHECK: } -> tensor<1xi64>
// CHECK: return %[[RESULT]] : tensor<1xi64>
func.func @test_cast_fp32_i64(%arg0: tensor<1xf32>) -> (tensor<1xi64>) {
%0 = tosa.cast %arg0 : (tensor<1xf32>) -> tensor<1xi64>
return %0: tensor<1xi64>
}

0 comments on commit b91ce7b

Please sign in to comment.