diff --git a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td
index 02a12a92..d622a77f 100644
--- a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td
+++ b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td
@@ -520,15 +520,17 @@ def XtenNN_TopK: XTenNN_Op<"topk", [
     Follows the specification of ONNX TopK at opset 11
   }];
   let arguments = (ins
-    AnyRankedTensor:$input,
-    I64Attr:$k,
+    AnyTensor:$input,
+    I64:$k,
     I64Attr:$axis,
     I1Attr:$largest,
     I1Attr:$sorted
   );
   let results = (outs AnyRankedTensor:$output, AnyRankedTensor:$indices);
 
-  let assemblyFormat = [{ `(`$input `:` type($input)`)` attr-dict `->` type($output) `,` type($indices) }];
+  let assemblyFormat = [{ `(`$input `:` type($input) `,` $k `:` type($k)`)` attr-dict `->` type($output) `,` type($indices) }];
+
+  let hasVerifier = 1;
 }
 
 
diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp
index 53e8d8c7..67cb6ca3 100644
--- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp
+++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/OpImplementation.h"
@@ -529,6 +530,21 @@ LogicalResult amd::xten_nn::ResizeOp::verify() {
   return success();
 }
 
+/// Returns the value of `k` when `op` is an `arith.constant` carrying an
+/// integer attribute. Returns std::nullopt otherwise, including when `op` is
+/// null (`k` is a block argument and has no defining op).
+std::optional<uint64_t> getConstantK(Operation *op) {
+  auto constantOp = dyn_cast_or_null<arith::ConstantOp>(op);
+  if (!constantOp)
+    return std::nullopt;
+  auto intAttr = dyn_cast<IntegerAttr>(constantOp.getValue());
+  if (!intAttr)
+    return std::nullopt;
+  // onnx.topk defines k as positive. A negative constant wraps to a huge
+  // unsigned value here, which the `k <= dimension size` check rejects.
+  return static_cast<uint64_t>(intAttr.getInt());
+}
+
 LogicalResult TopK::inferReturnTypeComponents(
     MLIRContext *context, std::optional<Location> location,
     TopK::Adaptor adaptor,
@@ -540,12 +556,25 @@ LogicalResult TopK::inferReturnTypeComponents(
     return emitOptionalError(location, "expected axis <= rank of input");
   }
   auto dimSize = inTy.getDimSize(axis);
-  if ((uint64_t)dimSize < adaptor.getK()) {
+  // k is an SSA operand, so its value is only known when it is defined by an
+  // integer arith.constant; report anything else instead of dereferencing an
+  // empty optional.
+  std::optional<uint64_t> k = getConstantK(adaptor.getK().getDefiningOp());
+  if (!k) {
+    return emitOptionalError(location, "expected k to be a constant integer");
+  }
+
+  if (dimSize < 0) {
+    // TODO: Support dynamic dimension sizes
+    return emitOptionalError(location, "expected static dimension size");
+  }
+
+  if (static_cast<uint64_t>(dimSize) < *k) {
     return emitOptionalError(location, "expected k <= dimension size");
   }
 
   SmallVector<int64_t> resultShape{inTy.getShape()};
-  resultShape[axis] = adaptor.getK();
+  resultShape[axis] = static_cast<int64_t>(*k);
   inferredReturnShapes.push_back(
       ShapedTypeComponents(resultShape, inTy.getElementType()));
 
@@ -553,3 +582,15 @@ LogicalResult TopK::inferReturnTypeComponents(
       ShapedTypeComponents(resultShape, IntegerType::get(context, 64)));
   return success();
 }
+
+/// Verifies that `k` is produced by an `arith.constant`, the only form whose
+/// value shape inference can read.
+LogicalResult amd::xten_nn::TopK::verify() {
+  // getDefiningOp<OpTy>() yields a null op when k is a block argument or is
+  // produced by a different op, so no null Operation* is ever cast.
+  if (!getK().getDefiningOp<arith::ConstantOp>()) {
+    return emitOpError("expected k to be defined by an arith.constant");
+  }
+
+  return success();
+}
diff --git a/test/Dialect/XTenNN/ops.mlir b/test/Dialect/XTenNN/ops.mlir
index 412abf3f..072d686e 100644
--- a/test/Dialect/XTenNN/ops.mlir
+++ b/test/Dialect/XTenNN/ops.mlir
@@ -41,9 +41,11 @@ func.func @kernel(%arg0: tensor<2xi64>, %arg1 : tensor<4xi64>) {
 
 // CHECK-LABEL: topk
 func.func @topk(%arg0: tensor<10x8xf32>) {
-  xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64>
-  // CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64>
-  xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 1 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64>
-  // CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 1 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64>
+  %k = arith.constant 7 : i64
+  // CHECK: %[[C7:.*]] = arith.constant 7 : i64
+  xten_nn.topk(%arg0 : tensor<10x8xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64>
+  // CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>, %[[C7]] : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64>
+  xten_nn.topk(%arg0 : tensor<10x8xf32>, %k : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64>
+  // CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>, %[[C7]] : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64>
   return
 }
diff --git a/test/Dialect/XTenNN/ops_invalid.mlir b/test/Dialect/XTenNN/ops_invalid.mlir
index f272e6c3..b20fd1f0 100644
--- a/test/Dialect/XTenNN/ops_invalid.mlir
+++ b/test/Dialect/XTenNN/ops_invalid.mlir
@@ -71,35 +71,48 @@ func.func @kernel_missing_result(%arg0: i8, %arg1: i8) {
 // -----
 
 func.func @topk_wrong_output_shape(%arg0: tensor<10x10xf32>) {
+  %k = arith.constant 7 : i64
   // expected-error@+2 {{failed to infer returned types}}
   // expected-error@+1 {{'xten_nn.topk' op inferred type(s) 'tensor<7x10xf32>', 'tensor<7x10xi64>' are incompatible with return type(s) of operation 'tensor<1xf32>', 'tensor<1xi64>'}}
-  %a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<1xf32>, tensor<1xi64>
+  %a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<1xf32>, tensor<1xi64>
   return
 }
 
 // -----
 
 func.func @topk_wrong_indices_shape(%arg0: tensor<10x10xf32>) {
+  %k = arith.constant 7 : i64
   // expected-error@+2 {{failed to infer returned types}}
   // expected-error@+1 {{'xten_nn.topk' op inferred type(s) 'tensor<7x10xf32>', 'tensor<7x10xi64>' are incompatible with return type(s) of operation 'tensor<7x10xf32>', 'tensor<7x10xf32>'}}
-  %a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x10xf32>, tensor<7x10xf32>
+  %a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<7x10xf32>, tensor<7x10xf32>
   return
 }
 
 // -----
 
 func.func @topk_wrong_axis(%arg0: tensor<10x10xf32>) {
+  %k = arith.constant 7 : i64
   // expected-error@+2 {{failed to infer returned types}}
   // expected-error@+1 {{expected axis <= rank of input}}
-  %a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 3 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
+  %a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 3 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
   return
 }
 
 // -----
 
 func.func @topk_large_k(%arg0: tensor<10x10xf32>) {
+  %k = arith.constant 100 : i64
   // expected-error@+2 {{failed to infer returned types}}
   // expected-error@+1 {{expected k <= dimension size}}
-  %a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 100 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
+  %a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
   return
 }
+
+// -----
+
+func.func @topk_non_constant_k(%arg0: tensor<10x10xf32>, %k: i64) {
+  // expected-error@+2 {{failed to infer returned types}}
+  // expected-error@+1 {{expected k to be a constant integer}}
+  %a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<7x10xf32>, tensor<7x10xi64>
+  return
+}