Skip to content

Commit

Permalink
Use K as an operand instead of attribute
Browse files Browse the repository at this point in the history
Set k as an operand
  • Loading branch information
josel-amd committed Aug 20, 2024
1 parent c9dfb61 commit 51363b6
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 13 deletions.
8 changes: 5 additions & 3 deletions include/xten/Dialect/XTenNN/IR/XTenNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -520,15 +520,17 @@ def XtenNN_TopK: XTenNN_Op<"topk", [
Follows the specification of ONNX TopK at opset 11
}];
let arguments = (ins
AnyRankedTensor:$input,
I64Attr:$k,
AnyTensor:$input,
I64:$k,
I64Attr:$axis,
I1Attr:$largest,
I1Attr:$sorted
);
let results = (outs AnyRankedTensor:$output, AnyRankedTensor:$indices);

let assemblyFormat = [{ `(`$input `:` type($input)`)` attr-dict `->` type($output) `,` type($indices) }];
let assemblyFormat = [{ `(`$input `:` type($input) `,` $k `:` type($k)`)` attr-dict `->` type($output) `,` type($indices) }];

let hasVerifier = 1;
}


Expand Down
31 changes: 29 additions & 2 deletions lib/Dialect/XTenNN/IR/XTenNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpImplementation.h"
Expand Down Expand Up @@ -529,6 +530,17 @@ LogicalResult amd::xten_nn::ResizeOp::verify() {
return success();
}

std::optional<uint64_t> getConstantK(Operation *op) {
auto constantOp = dyn_cast<arith::ConstantOp>(op);
if (!constantOp)
return {};
auto intAttr = dyn_cast<IntegerAttr>(constantOp.getValue());
if (!intAttr)
return {};
return (uint64_t)
intAttr.getInt(); // Always positive by definition of onnx.topk
}

LogicalResult TopK::inferReturnTypeComponents(
MLIRContext *context, std::optional<Location> location,
TopK::Adaptor adaptor,
Expand All @@ -540,16 +552,31 @@ LogicalResult TopK::inferReturnTypeComponents(
return emitOptionalError(location, "expected axis <= rank of input");
}
auto dimSize = inTy.getDimSize(axis);
if ((uint64_t)dimSize < adaptor.getK()) {
uint64_t k = *getConstantK(adaptor.getK().getDefiningOp());

if (dimSize < 0) {
// TODO: Support negative dimSize
return emitOptionalError(location, "expected positive k");
}

if ((uint64_t)dimSize < k) {
return emitOptionalError(location, "expected k <= dimension size");
}

SmallVector<int64_t> resultShape{inTy.getShape()};
resultShape[axis] = adaptor.getK();
resultShape[axis] = k;

inferredReturnShapes.push_back(
ShapedTypeComponents(resultShape, inTy.getElementType()));
inferredReturnShapes.push_back(
ShapedTypeComponents(resultShape, IntegerType::get(context, 64)));
return success();
}

LogicalResult amd::xten_nn::TopK::verify() {
if (!isa<arith::ConstantOp>(getK().getDefiningOp())) {
return failure();
}

return success();
}
10 changes: 6 additions & 4 deletions test/Dialect/XTenNN/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ func.func @kernel(%arg0: tensor<2xi64>, %arg1 : tensor<4xi64>) {

// CHECK-LABEL: topk
func.func @topk(%arg0: tensor<10x8xf32>) {
xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64>
xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 1 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 1 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64>
%k = arith.constant 7 : i64
// CHECK: %[[C7:.*]] = arith.constant 7 : i64
xten_nn.topk(%arg0 : tensor<10x8xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>, %[[C7]] : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64>
xten_nn.topk(%arg0 : tensor<10x8xf32>, %k : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>, %[[C7]] : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64>
return
}
12 changes: 8 additions & 4 deletions test/Dialect/XTenNN/ops_invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -71,35 +71,39 @@ func.func @kernel_missing_result(%arg0: i8, %arg1: i8) {
// -----

func.func @topk_wrong_output_shape(%arg0: tensor<10x10xf32>) {
%k = arith.constant 7 : i64
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'xten_nn.topk' op inferred type(s) 'tensor<7x10xf32>', 'tensor<7x10xi64>' are incompatible with return type(s) of operation 'tensor<1xf32>', 'tensor<1xi64>'}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<1xf32>, tensor<1xi64>
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<1xf32>, tensor<1xi64>
return
}

// -----

func.func @topk_wrong_indices_shape(%arg0: tensor<10x10xf32>) {
%k = arith.constant 7 : i64
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'xten_nn.topk' op inferred type(s) 'tensor<7x10xf32>', 'tensor<7x10xi64>' are incompatible with return type(s) of operation 'tensor<7x10xf32>', 'tensor<7x10xf32>'}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x10xf32>, tensor<7x10xf32>
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<7x10xf32>, tensor<7x10xf32>
return
}

// -----

func.func @topk_wrong_axis(%arg0: tensor<10x10xf32>) {
%k = arith.constant 7 : i64
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{expected axis <= rank of input}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 3 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 3 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
return
}

// -----

func.func @topk_large_k(%arg0: tensor<10x10xf32>) {
%k = arith.constant 100 : i64
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{expected k <= dimension size}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 100 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
return
}

0 comments on commit 51363b6

Please sign in to comment.