Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][tensor] Add TilingInterface support for fusing tensor.pad #105892

Merged
merged 1 commit into from
Aug 23, 2024

Conversation

qedawkins
Copy link
Contributor

This adds implementations for the two TilingInterface methods required for fusion to tensor.pad: getIterationDomainTileFromResultTile and generateResultTileValue, allowing fusion of pad with a tiled consumer.

This adds implementations for the two TilingInterface methods required
for fusion to `tensor.pad`: `getIterationDomainTileFromResultTile` and
`generateResultTileValue`, allowing fusion of pad with a tiled
consumer.
@llvmbot
Copy link
Member

llvmbot commented Aug 23, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Quinn Dawkins (qedawkins)

Changes

This adds implementations for the two TilingInterface methods required for fusion to tensor.pad: getIterationDomainTileFromResultTile and generateResultTileValue, allowing fusion of pad with a tiled consumer.


Full diff: https://github.com/llvm/llvm-project/pull/105892.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (+17)
  • (modified) mlir/test/Dialect/Tensor/tiling.mlir (+41)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index dec678de6d1c27..f35a9cd4cb9275 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -67,6 +67,23 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
     resultSizes.assign(sizes.begin(), sizes.end());
     return success();
   }
+
+  LogicalResult getIterationDomainTileFromResultTile(
+      Operation *op, OpBuilder &b, unsigned resultNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
+      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
+    iterDomainOffsets.assign(offsets.begin(), offsets.end());
+    iterDomainSizes.assign(sizes.begin(), sizes.end());
+    return success();
+  }
+
+  FailureOr<TilingResult>
+  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
+                          ArrayRef<OpFoldResult> offsets,
+                          ArrayRef<OpFoldResult> sizes) const {
+    return getTiledImplementation(op, b, offsets, sizes);
+  }
 };
 
 template <typename OpTy>
diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir
index e02ab06a9d5337..193fbe93e0f9ee 100644
--- a/mlir/test/Dialect/Tensor/tiling.mlir
+++ b/mlir/test/Dialect/Tensor/tiling.mlir
@@ -116,6 +116,47 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func @fuse_static_pad_tensor_3_4(
+//  CHECK-SAME:     %[[IN:.*]]: tensor<7x9xf32>
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[C15:.*]] = arith.constant 15 : index
+//   CHECK-DAG:   %[[C16:.*]] = arith.constant 16 : index
+//       CHECK:   %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[C15]] step %[[C2]]
+//       CHECK:     scf.for {{.*}} = %[[C0]] to %[[C16]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
+//       CHECK:       %[[SWAP_RESULT:.*]] = scf.if
+//       CHECK:         tensor.generate
+//       CHECK:       else
+//       CHECK:         %[[SLICE:.*]] = tensor.extract_slice %[[IN]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+//       CHECK:         %[[PAD:.*]] = tensor.pad %[[SLICE]]
+//       CHECK:       %[[COPY:.*]] = linalg.copy ins(%[[SWAP_RESULT:.*]]
+//       CHECK:       tensor.insert_slice %[[COPY]] into %[[INNER_OUT]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+//       CHECK:   return %[[RESULT]]
+
+func.func @fuse_static_pad_tensor_3_4(%input_tensor: tensor<7x9xf32>,
+                        %pad_value: f32) -> tensor<15x16xf32> {
+  %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } : tensor<7x9xf32> to tensor<15x16xf32>
+  %empty = tensor.empty() : tensor<15x16xf32>
+  %1 = linalg.copy ins(%0 : tensor<15x16xf32>) outs(%empty : tensor<15x16xf32>) -> tensor<15x16xf32>
+  return %1 : tensor<15x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %copy = transform.structured.match ops{["linalg.copy"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b, %c = transform.structured.fuse %copy [2, 3]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @static_pad_tensor_0_3(
 //  CHECK-SAME:     %[[IN:.*]]: tensor<7x9xf32>
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index

@qedawkins qedawkins merged commit 91e57c6 into llvm:main Aug 23, 2024
11 checks passed
@qedawkins qedawkins deleted the pad_fusion branch August 23, 2024 23:10
5chmidti pushed a commit that referenced this pull request Aug 24, 2024
…5892)

This adds implementations for the two TilingInterface methods required
for fusion to `tensor.pad`: `getIterationDomainTileFromResultTile` and
`generateResultTileValue`, allowing fusion of pad with a tiled consumer.
dmpolukhin pushed a commit to dmpolukhin/llvm-project that referenced this pull request Sep 2, 2024
…m#105892)

This adds implementations for the two TilingInterface methods required
for fusion to `tensor.pad`: `getIterationDomainTileFromResultTile` and
`generateResultTileValue`, allowing fusion of pad with a tiled consumer.
josel-amd added a commit to Xilinx/llvm-project that referenced this pull request Dec 11, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants