Skip to content

Commit

Permalink
[mlir][Transforms] Add missing check in applyPermutation
Browse files Browse the repository at this point in the history
The applyPermutation() utility should make sure
that the permutation numbers are within the size
of the input array. Otherwise it will cause a
cryptic array out of bound assertion later.
  • Loading branch information
DarshanRamakant committed Aug 11, 2024
1 parent 4ac42af commit 054856d
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 27 deletions.
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Utils/IndexingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ SmallVector<T> applyPermutation(ArrayRef<T> input,
ArrayRef<int64_t> permutation) {
assert(input.size() == permutation.size() &&
"expected input rank to equal permutation rank");
assert(
llvm::all_of(permutation, [&](size_t s) { return s < input.size(); }) &&
"permutation must be within input bounds");
auto permutationRange = llvm::map_range(
llvm::seq<unsigned>(0, input.size()),
[&](int64_t idx) -> T { return input[permutation[idx]]; });
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1116,6 +1116,12 @@ LogicalResult tosa::TransposeOp::verify() {
"Unexpectedly found permutation tensor without rank");
if (!isPermutationVector(constantPerms))
return emitOpError() << "expected valid permutation tensor";

if (inputType.hasRank() && !llvm::all_of(constantPerms, [&](int64_t s) {
return s < inputType.getRank();
})) {
return emitOpError() << "permutation must be within input bounds";
}
}
return success();
}
Expand Down
35 changes: 35 additions & 0 deletions mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -413,3 +413,38 @@ func.func @test_tile_invalid_multiples() {
%1 = tosa.tile %0 {multiples = array<i64>} : (tensor<4x31x31xf32>) -> tensor<4x31x31xf32>
return
}

// -----

// CHECK-LABEL: @test_invalid_constant_permutation
func.func @test_invalid_constant_permutation() {
// expected-error@+3 {{permutation must be within input bounds}}
%0 = tensor.empty() : tensor<3x4x5xi32>
%1 = arith.constant dense<[3, 0, 1]> : tensor<3xi32>
%2 = tosa.transpose %0, %1 : (tensor<3x4x5xi32>, tensor<3xi32>) -> tensor<3x4x5xi32>
return
}

// -----

// CHECK-LABEL: test_rank_size_constant_permutation
func.func @test_rank_size_constant_permutation() {
// expected-error@+4 {{permutation must be within input bounds}}
%0 = arith.constant 6 : index
%1 = arith.constant dense<[0, 2]> : tensor<2xi32>
%2 = tensor.empty(%0) : tensor<?x27xi64>
%3 = tosa.transpose %2, %1 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
return
}

// -----

// CHECK-LABEL: test_large_constant_permutation
func.func @test_large_constant_permutation() {
// expected-error@+4 {{permutation must be within input bounds}}
%0 = arith.constant 6 : index
%1 = arith.constant dense<[1185677355, 332462212]> : tensor<2xi32>
%2 = tensor.empty(%0) : tensor<?x27xi64>
%3 = tosa.transpose %2, %1 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
return
}
28 changes: 1 addition & 27 deletions mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1373,30 +1373,4 @@ func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<1
// CHECK: (tensor<1x32x32x16xf32>) -> tensor<1x16x16x16xf32>
%1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32>
return %1 : tensor<?x16x16x16xf32>
}

// -----

// CHECK-LABEL: test_rank_size_constant_permutation
func.func @test_rank_size_constant_permutation() {
%c6 = arith.constant 6 : index
%cst_26 = arith.constant dense<[0, 2]> : tensor<2xi32>
%14 = tensor.empty(%c6) : tensor<?x27xi64>
// Fail to infer the shape but not crash.
// CHECK: (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
%72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
return
}

// -----

// CHECK-LABEL: test_large_constant_permutation
func.func @test_large_constant_permutation() {
%c6 = arith.constant 6 : index
%cst_26 = arith.constant dense<[1185677355, 332462212]> : tensor<2xi32>
%14 = tensor.empty(%c6) : tensor<?x27xi64>
// Fail to infer the shape but not crash.
// CHECK: (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
%72 = tosa.transpose %14, %cst_26 : (tensor<?x27xi64>, tensor<2xi32>) -> tensor<?x27xi64>
return
}
}

0 comments on commit 054856d

Please sign in to comment.