Skip to content

Commit

Permalink
Allow for negative axis
Browse files Browse the repository at this point in the history
  • Loading branch information
josel-amd committed Aug 26, 2024
1 parent d7ba85f commit 7cc1cdb
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 4 deletions.
16 changes: 13 additions & 3 deletions lib/Dialect/XTenNN/IR/XTenNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,11 +552,21 @@ LogicalResult TopK::inferReturnTypeComponents(

auto inTy = cast<RankedTensorType>(adaptor.getInput().getType());

auto axis = adaptor.getAxis();
if (axis >= (uint64_t)inTy.getRank()) {
return emitOptionalError(location, "expected axis <= rank of input");
auto axis = (int64_t)adaptor.getAxis();
// onnx spec: axis: [-r, r-1]
if (!(axis >= -inTy.getRank()) || !(axis < inTy.getRank())) {
return emitOptionalError(
location,
"expected axis to be within \"rank < axis <= rank - 1\" of input");
}

// normalize axis: [0, r)
if (axis < 0) {
axis += inTy.getRank();
}

assert((axis >= 0 && axis < inTy.getRank()) && "axis with wrong value");

auto dimSize = inTy.getDimSize(axis);
auto k = getConstantK(adaptor.getK());
// If both k and dim are known statically, we can check that k <= dim
Expand Down
10 changes: 10 additions & 0 deletions test/Dialect/XTenNN/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,13 @@ func.func @topk_arg_dyn_in(%arg0: tensor<?x?xf32>, %k: i64) {
// CHECK: xten_nn.topk(%arg0 : tensor<?x?xf32>, %arg1 : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<?x?xf32>, tensor<?x?xi64>
return
}


// -----

// CHECK-LABEL: topk_neg_axis
func.func @topk_neg_axis(%arg0: tensor<10x8xf32>, %k: i64) {
xten_nn.topk(%arg0 : tensor<10x8xf32>, %k : i64) {axis = -1 : i64, largest = true, sorted = true} -> tensor<10x?xf32>, tensor<10x?xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>, %arg1 : i64) {axis = -1 : i64, largest = true, sorted = true} -> tensor<10x?xf32>, tensor<10x?xi64>
return
}
12 changes: 11 additions & 1 deletion test/Dialect/XTenNN/ops_invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func.func @topk_wrong_indices_shape(%arg0: tensor<10x10xf32>) {
func.func @topk_wrong_axis(%arg0: tensor<10x10xf32>) {
%k = arith.constant 7 : i64
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{expected axis <= rank of input}}
// expected-error@+1 {{expected axis to be within "rank < axis <= rank - 1" of input}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 3 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
return
}
Expand All @@ -107,3 +107,13 @@ func.func @topk_large_k(%arg0: tensor<10x10xf32>) {
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
return
}

// -----

func.func @topk_negative_axis(%arg0: tensor<10x10xf32>) {
%k = arith.constant 100 : i64
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{expected axis to be within "rank < axis <= rank - 1" of input}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = -3 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
return
}

0 comments on commit 7cc1cdb

Please sign in to comment.