[Linalg][Vectorization] Add support for linalg vectorization of a tensor.extract case #107922

Merged

Conversation

nirvedhmeshram
Contributor

@nirvedhmeshram nirvedhmeshram commented Sep 9, 2024

In #102321 we relaxed the vectorizer so that, when checking for contiguous loads, we no longer require a non-unit trailing dim. For example, in the test case added here we have tensor<8x1xf32>, which is now a valid candidate for a contiguous load. However, the logic that checks for a contiguous load assumed that only the trailing dim would be non-unit, so this PR updates that logic to find the actual non-unit dim.
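As a rough illustration of that dim-finding step (a minimal standalone C++ sketch under the assumption of a single static non-unit dim, not the pass code itself):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>

int main() {
  // Static result shape of the candidate contiguous load, e.g. tensor<8x1xf32>.
  std::vector<int64_t> shape = {8, 1};

  // The old logic implicitly took shape.back() as the non-unit dim. Locate it
  // explicitly instead; with a single non-unit dim this is the largest extent.
  auto maxIt = std::max_element(shape.begin(), shape.end());
  auto nonUnitDim = std::distance(shape.begin(), maxIt);

  std::cout << "non-unit dim: " << nonUnitDim << " (size " << *maxIt << ")\n";
  // Prints "non-unit dim: 0 (size 8)" for 8x1.
}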

@llvmbot
Collaborator

llvmbot commented Sep 9, 2024

@llvm/pr-subscribers-mlir-linalg

Author: Nirvedh Meshram (nirvedhmeshram)

Changes

There is a case shown in #107476 that the current vectorization patterns can't handle. This PR provides a way of handling it by adding an extra transpose op, which should get canceled with the existing transpose.
Fixes: #107476
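
Roughly, for the tensor<8x1xf32> case the index vector now goes through a pair of transposes that a later canonicalization can fold away. A trimmed-down sketch of the IR in the new test below (illustrative, not verbatim output):

%idx  = vector.broadcast %cst : vector<8xindex> to vector<1x8xindex>
%t0   = vector.transpose %idx, [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// The extra transpose added by this patch; it cancels with %t0.
%t1   = vector.transpose %t0, [1, 0] : vector<8x1xindex> to vector<1x8xindex>
%flat = vector.shape_cast %t1 : vector<1x8xindex> to vector<8xindex>
%i0   = vector.extractelement %flat[%c0_i32 : i32] : vector<8xindex>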


Full diff: https://github.com/llvm/llvm-project/pull/107922.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+25)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+48)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 63dcda78d0f2be..16d1b1d6e0d0d4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1079,6 +1079,31 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
       continue;
     }
 
+    auto idxType = dyn_cast<VectorType>(idx.getType());
+
+    if (idxType && idxType.getShape().size() == resultType.getShape().size()) {
+      auto maxElement = std::max_element(resultType.getShape().begin(),
+                                         resultType.getShape().end());
+      auto maxElementDim =
+          std::distance(resultType.getShape().begin(), maxElement);
+      // This means that the result type of the index is non trailing and we
+      // insert transpose op in this case to match it to the extract type.
+      if (maxElementDim != resultType.getShape().size() - 1) {
+        SmallVector<int64_t> transposition = llvm::to_vector<16>(
+            llvm::seq<int64_t>(0, resultType.getShape().size()));
+        std::swap(transposition.back(), transposition[maxElementDim]);
+        auto transposeOp =
+            rewriter.create<vector::TransposeOp>(loc, idx, transposition);
+        auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
+            loc,
+            VectorType::get(*maxElement, rewriter.getIndexType(),
+                            resultType.getScalableDims().back()),
+            transposeOp);
+        transferReadIdxs.push_back(rewriter.create<vector::ExtractElementOp>(
+            loc, indexAs1dVector, zero));
+        continue;
+      }
+    }
     auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
         loc,
         VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index bdaa20c3bf971e..b66a0c4e4093b0 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -253,6 +253,54 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
+func.func @vectorize_nd_tensor_extract_transfer_without_outer_unit_dim(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<8x1xf32>
+  %1 = linalg.generic {
+    indexing_maps = [#map], 
+    iterator_types = ["parallel", "parallel"]
+  } outs(%0 : tensor<8x1xf32>) {
+  ^bb0(%arg5: f32):
+      %2 = linalg.index 0 : index
+      %3 = linalg.index 1 : index
+      %4 = affine.apply #map1(%arg1, %3, %arg1)
+    %extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<8x1xf32>
+  return %1 : tensor<8x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_without_outer_unit_dim
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK-DAG: %[[C0_i32:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:     %[[IDX0:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK:     %[[IDX1:.*]] = vector.broadcast %[[CST_0]] : vector<8xindex> to vector<1x8xindex
+// CHECK:    %[[IDX2:.*]] = vector.transpose %[[IDX1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK:    %[[IDX3:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
+// CHECK:    %[[IDX4:.*]] = vector.transpose %[[IDX2]], [1, 0] : vector<8x1xindex> to vector<1x8xindex>
+// CHECK:    %[[IDX5:.*]] = vector.shape_cast %[[IDX4]] : vector<1x8xindex> to vector<8xindex>
+// CHECK:    %[[IDX6:.*]] = vector.extractelement %[[IDX5]][%[[C0_i32]] : i32] : vector<8xindex>
+// CHECK:    %[[IDX7:.*]] = vector.transfer_read %[[ARG0]][%[[IDX6]], %[[C0]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true]} : tensor<8x128x768xf32>, vector<8x1xf32>
+// CHECK:    vector.transfer_write %[[IDX7]], %[[IDX0]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+
 // -----
 
 #map = affine_map<(d0) -> (d0)>

@dcaballe
Contributor

Hi, thanks for the fix! Could you please elaborate a bit more on "a way of handling it by adding an extra transpose op"? I'm not sure I can infer what that means. An IR example in the description would help.

@nirvedhmeshram
Contributor Author

Hi, thanks for the fix! Could you please elaborate a bit more on "a way of handling it by adding an extra transpose op"? I'm not sure I can infer what that means. An IR example in the description would help.

Sounds good. It seems this is creating wrong IR as is, but once we have a resolution with help from @banach-space, I can post the IR we generate for this case in the PR description.

@nirvedhmeshram
Contributor Author

@banach-space could you please take another look at this when you have a chance?

Contributor

@banach-space banach-space left a comment


This makes sense, thank you and sorry for the delay. A few more minor comments/suggestions.

@nirvedhmeshram
Contributor Author

nirvedhmeshram commented Sep 19, 2024

@banach-space #100582 breaks the logic here as one of the test cases is

%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%arg1 : tensor<?x?xf32>) {
^bb0(%out: f32):
  %6 = linalg.index 1 : index
  %7 = arith.addi %6, %arg2 : index
  %extracted = tensor.extract %arg0[%c79, %7] : tensor<?x?xf32>
  linalg.yield %extracted : f32
} -> tensor<?x?xf32>

where it will find two non-unit dims and hit the assert.

@nirvedhmeshram
Contributor Author

nirvedhmeshram commented Sep 19, 2024

@banach-space #100582 breaks the logic here as one of the test cases is

%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%arg1 : tensor<?x?xf32>) {
^bb0(%out: f32):
  %6 = linalg.index 1 : index
  %7 = arith.addi %6, %arg2 : index
  %extracted = tensor.extract %arg0[%c79, %7] : tensor<?x?xf32>
  linalg.yield %extracted : f32
} -> tensor<?x?xf32>

where it will find two non-unit dims and hit the assert.

I adapted the logic to check for dynamic dims and to return the trailing dim if no non-dynamic, non-unit dim is found. That seems to work, although it feels somewhat ad hoc.
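
For what it's worth, a standalone sketch of that adapted selection (plain C++ with a hypothetical helper, not the actual pass code): skip dynamic and unit dims, and fall back to the trailing dim when nothing else qualifies.

#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Stand-in sentinel for a dynamic dim in a shape.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical helper: index of the first static non-unit dim, falling back to
// the trailing dim when every dim is dynamic or unit.
static size_t pickNonUnitDim(const std::vector<int64_t> &shape) {
  for (size_t i = 0; i < shape.size(); ++i)
    if (shape[i] != kDynamic && shape[i] != 1)
      return i;
  return shape.size() - 1;
}

int main() {
  std::cout << pickNonUnitDim({8, 1}) << "\n";               // 0
  std::cout << pickNonUnitDim({kDynamic, kDynamic}) << "\n"; // 1 (fallback)
}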

@banach-space
Contributor

@banach-space #100582 breaks the logic here as one of the test cases is

%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} outs(%arg1 : tensor<?x?xf32>) {
^bb0(%out: f32):
  %6 = linalg.index 1 : index
  %7 = arith.addi %6, %arg2 : index
  %extracted = tensor.extract %arg0[%c79, %7] : tensor<?x?xf32>
  linalg.yield %extracted : f32
} -> tensor<?x?xf32>

where it will find two non-unit dims and hit the assert.

Apologies, I didn't expect that to be such a curveball. But I am really glad that we are hitting these cases; it really helps to stress-test this logic. Hopefully my explanation makes sense. I've also deleted a couple of my comments - #100582 invalidated them 😅

@llvm llvm deleted a comment from github-actions bot Sep 20, 2024
Contributor

@banach-space banach-space left a comment


LGTM, thanks for fixing this and for bearing with me 🙏🏻

I will try to prepare a follow-on to clarify the differences between loading from statically and dynamically shaped tensors. That's two cases. There's also masked vectorisation for statically shaped tensors, so that's 3 cases in total 😅

@nirvedhmeshram
Contributor Author

nirvedhmeshram commented Sep 21, 2024

LGTM, thanks for fixing this and for bearing with me 🙏🏻

I will try to prepare a follow-on to clarify the differences between loading from statically and dynamically shaped tensors. That's two cases. There's also masked vectorisation for statically shaped tensors, so that's 3 cases in total 😅

Thank you so much for the help on this and for taking up support for those other cases. This back and forth gave me a lot of new insight into vectorization, so no complaints here :)

@nirvedhmeshram nirvedhmeshram merged commit e45fc51 into llvm:main Sep 21, 2024
8 checks passed