diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b139f1ef58b3a9..309573a562872f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -21,10 +21,10 @@ include "mlir/IR/RegionKindInterface.td"
 
 // This is roughly similar to OpFoldResult assuming the handle produces a single
 // value in the payload IR.
-def TransformParamTypeOrAnyHandle : Type<
+def TransformAnyParamTypeOrAnyHandle : Type<
     Or<[TransformHandleTypeInterface.predicate,
-        Transform_ParamType.predicate]>,
-    "transform 'param' type or any handle type">;
+        TransformParamTypeInterface.predicate]>,
+    "transform any param type or any handle type">;
 
 //===----------------------------------------------------------------------===//
 // Apply...PatternsOp
@@ -691,9 +691,9 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
                        ConfinedAttr<I64Attr, [IntPositive]>:$divisor);
-  let results = (outs TransformParamTypeOrAnyHandle:$low_size,
-                      TransformParamTypeOrAnyHandle:$high_size,
-                      TransformParamTypeOrAnyHandle:$split_point);
+  let results = (outs TransformAnyParamTypeOrAnyHandle:$low_size,
+                      TransformAnyParamTypeOrAnyHandle:$high_size,
+                      TransformAnyParamTypeOrAnyHandle:$split_point);
   let hasVerifier = 1;
   let assemblyFormat =
     "$target attr-dict `:` custom<MultitileSizesTypes>("
@@ -1408,7 +1408,7 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       Optional<TransformParamTypeOrAnyHandle>:$dynamic_split_point,
+                       Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_split_point,
                        I64Attr:$static_split_point);
   let results = (outs TransformHandleTypeInterface:$first,
                       TransformHandleTypeInterface:$second);
@@ -1857,7 +1857,7 @@ def TileUsingForOp : Op<Transform_Dialect, "structured.tile_using_for",
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   Variadic<TransformParamTypeOrAnyHandle>:$dynamic_sizes,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$dynamic_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange,
                    DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
@@ -1968,10 +1968,10 @@ def TileUsingForallOp :
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                   Variadic<TransformParamTypeOrAnyHandle>:$num_threads,
-                   Variadic<TransformParamTypeOrAnyHandle>:$tile_sizes,
-                   Optional<TransformParamTypeOrAnyHandle>:$packed_num_threads,
-                   Optional<TransformParamTypeOrAnyHandle>:$packed_tile_sizes,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$num_threads,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$tile_sizes,
+                   Optional<TransformAnyParamTypeOrAnyHandle>:$packed_num_threads,
+                   Optional<TransformAnyParamTypeOrAnyHandle>:$packed_tile_sizes,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_num_threads,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes,
                    OptionalAttr<DeviceMappingArrayAttr>:$mapping);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index df9e613e04aed3..6431bbd25396a5 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -86,8 +86,9 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
   return cast<LinalgOp>(result->getOperation());
 }
 
-/// Assuming that `ofr` is an index attr or a transform dialect handle mapped
-/// to exactly one op with one index result, return that value.
+/// Assuming that `ofr` is an index attr or a param of index type
+/// or a transform dialect handle mapped to exactly one op
+/// with one index result, return that value.
 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
     transform::TransformState &state, TransformOpInterface transformOp,
     SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
@@ -98,12 +99,23 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
       result.push_back(ofr);
       continue;
     }
-    auto payloadOps = state.getPayloadOps(ofr.get<Value>());
+
+    Value transformValue = ofr.get<Value>();
+    if (isa<TransformParamTypeInterface>(transformValue.getType())) {
+      ArrayRef<Attribute> params = state.getParams(transformValue);
+      if (params.size() != 1)
+        return transformOp.emitDefiniteFailure()
+               << "requires exactly one parameter associated";
+      result.push_back(params[0]);
+      continue;
+    }
+
+    auto payloadOps = state.getPayloadOps(transformValue);
     if (!llvm::hasSingleElement(payloadOps)) {
       DiagnosedSilenceableFailure diag =
           transformOp.emitSilenceableError()
           << "handle must be mapped to exactly one payload op";
-      diag.attachNote(ofr.get<Value>().getLoc())
+      diag.attachNote(transformValue.getLoc())
           << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
       return diag;
     }
@@ -123,14 +135,27 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
   return DiagnosedSilenceableFailure::success();
 }
 
-// Given a list of OpFoldResults that are either index attrs or op
-// handles, return a list of OpFoldResults where all op handles are
-// replaced with the first (and only) OpResult of that payload op. (There
-// must be exactly one mapped payload op and it must have exactly one
-// index result.)
+// Given a list of params that are index attrs or a list of OpFoldResults
+// that are either index attrs or op handles, return a list of OpFoldResults
+// of index attrs or a list of OpFoldResults where all op handles are
+// replaced with the first (and only) OpResult of that payload op.
+// (There must be exactly one parameter associated with the AnyParamType or
+// one mapped payload op which must have exactly one index result.)
 static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations(
     transform::TransformState &state, TransformOpInterface transformOp,
     SmallVector<OpFoldResult> &result, Value packedHandle) {
+  if (isa<TransformParamTypeInterface>(packedHandle.getType())) {
+    ArrayRef<Attribute> params = state.getParams(packedHandle);
+    for (auto param : params) {
+      if (!isa<IntegerAttr>(param))
+        return transformOp.emitDefiniteFailure()
+               << "expected the parameter to be associated with an integer "
+                  "attribute";
+      result.push_back(param);
+    }
+    return DiagnosedSilenceableFailure::success();
+  }
+
   for (Operation *op : state.getPayloadOps(packedHandle)) {
     if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
       DiagnosedSilenceableFailure diag =
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index 2192d160b1150f..abd807b3e4d3e1 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --transform-interpreter -canonicalize -cse -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s
 
 // Offset per thread:
 // CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
@@ -451,3 +451,138 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic(
+//  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: %[[c1:.*]] = arith.constant 1 : index
+  // CHECK: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK: %[[M:.+]] = tensor.dim %[[A]], %[[c0]] :
+  // CHECK: %[[N:.+]] = tensor.dim %[[B]], %[[c1]] :
+  // CHECK: %[[NT0:.+]] = affine.apply #map()[%[[M]]]
+  // CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
+  // CHECK: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
+  // CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+  // CHECK:   %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+  // CHECK:   %[[TS1:.+]] = affine.min #[[$map3]](%[[IV1]])[%[[N]]]
+  // CHECK:   %[[LB0:.+]] = affine.apply #[[$map4]](%[[IV0]])
+  // CHECK:   %[[LB1:.+]] = affine.apply #[[$map5]](%[[IV1]])
+  // CHECK:   tensor.extract_slice %[[A]][%[[LB0]], 0] [%[[TS0]], %[[K]]] [1, 1] :
+  // CHECK:   tensor.extract_slice %[[B]][0, %[[LB1]]] [%[[K]], %[[TS1]]] [1, 1] :
+  // CHECK:   tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [%[[TS0]], %[[TS1]]] [1, 1] :
+  // CHECK:   linalg.matmul
+  // CHECK:   scf.forall.in_parallel
+  // CHECK-NEXT:   tensor.parallel_insert_slice
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %sz = transform.param.constant 10 : i64 -> !transform.param<i64>
+    %1:2 = transform.structured.tile_using_forall %0 tile_sizes [%sz : !transform.param<i64>, 20]
+           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul_transpose_b"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %c10 = transform.param.constant 10 : i64 -> !transform.param<i64>
+    %c20 = transform.param.constant 20 : i64 -> !transform.param<i64>
+    %sz = transform.merge_handles %c10, %c20 : !transform.param<i64>
+    // expected-error @below {{requires exactly one parameter associated}}
+    %1:2 = transform.structured.tile_using_forall %0 tile_sizes [%sz : !transform.param<i64>, 20]
+           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic(
+//  CHECK-SAME:   %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: %[[c1:.*]] = arith.constant 1 : index
+  // CHECK: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK: %[[M:.+]] = tensor.dim %[[A]], %[[c0]] :
+  // CHECK: %[[N:.+]] = tensor.dim %[[B]], %[[c1]] :
+  // CHECK: %[[NT0:.+]] = affine.apply #map()[%[[M]]]
+  // CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
+  // CHECK: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
+  // CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+  // CHECK:   %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+  // CHECK:   %[[TS1:.+]] = affine.min #[[$map3]](%[[IV1]])[%[[N]]]
+  // CHECK:   %[[LB0:.+]] = affine.apply #[[$map4]](%[[IV0]])
+  // CHECK:   %[[LB1:.+]] = affine.apply #[[$map5]](%[[IV1]])
+  // CHECK:   tensor.extract_slice %[[A]][%[[LB0]], 0] [%[[TS0]], %[[K]]] [1, 1] :
+  // CHECK:   tensor.extract_slice %[[B]][0, %[[LB1]]] [%[[K]], %[[TS1]]] [1, 1] :
+  // CHECK:   tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [%[[TS0]], %[[TS1]]] [1, 1] :
+  // CHECK:   linalg.matmul
+  // CHECK:   scf.forall.in_parallel
+  // CHECK-NEXT:   tensor.parallel_insert_slice
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %c10 = transform.param.constant 10 : i64 -> !transform.any_param
+    %c20 = transform.param.constant 20 : i64 -> !transform.any_param
+    %sz = transform.merge_handles %c10, %c20 : !transform.any_param
+    %1:2 = transform.structured.tile_using_forall %0 tile_sizes *(%sz : !transform.any_param)
+           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %sz = transform.param.constant "[10 : i64, 20 : i64]" -> !transform.any_param
+    // expected-error @below {{expected the parameter to be associated with an integer attribute}}
+    %1:2 = transform.structured.tile_using_forall %0 tile_sizes *(%sz : !transform.any_param)
+           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}