[Linalg][Vectorization] Add support for linalg vectorization of a tensor.extract case #107922
Conversation
@llvm/pr-subscribers-mlir-linalg

Author: Nirvedh Meshram (nirvedhmeshram)

Changes: There is a case shown in #107476 that the current vectorization patterns can't handle. This PR provides a way of handling it by adding an extra transpose op, which should get canceled with the existing transpose. Full diff: https://github.com/llvm/llvm-project/pull/107922.diff (2 Files Affected):
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 63dcda78d0f2be..16d1b1d6e0d0d4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1079,6 +1079,31 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
continue;
}
+ auto idxType = dyn_cast<VectorType>(idx.getType());
+
+ if (idxType && idxType.getShape().size() == resultType.getShape().size()) {
+ auto maxElement = std::max_element(resultType.getShape().begin(),
+ resultType.getShape().end());
+ auto maxElementDim =
+ std::distance(resultType.getShape().begin(), maxElement);
+ // This means that the result type of the index is non trailing and we
+ // insert transpose op in this case to match it to the extract type.
+ if (maxElementDim != resultType.getShape().size() - 1) {
+ SmallVector<int64_t> transposition = llvm::to_vector<16>(
+ llvm::seq<int64_t>(0, resultType.getShape().size()));
+ std::swap(transposition.back(), transposition[maxElementDim]);
+ auto transposeOp =
+ rewriter.create<vector::TransposeOp>(loc, idx, transposition);
+ auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
+ loc,
+ VectorType::get(*maxElement, rewriter.getIndexType(),
+ resultType.getScalableDims().back()),
+ transposeOp);
+ transferReadIdxs.push_back(rewriter.create<vector::ExtractElementOp>(
+ loc, indexAs1dVector, zero));
+ continue;
+ }
+ }
auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
loc,
VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index bdaa20c3bf971e..b66a0c4e4093b0 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -253,6 +253,54 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
+func.func @vectorize_nd_tensor_extract_transfer_without_outer_unit_dim(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = tensor.empty() : tensor<8x1xf32>
+ %1 = linalg.generic {
+ indexing_maps = [#map],
+ iterator_types = ["parallel", "parallel"]
+ } outs(%0 : tensor<8x1xf32>) {
+ ^bb0(%arg5: f32):
+ %2 = linalg.index 0 : index
+ %3 = linalg.index 1 : index
+ %4 = affine.apply #map1(%arg1, %3, %arg1)
+ %extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32>
+ linalg.yield %extracted : f32
+ } -> tensor<8x1xf32>
+ return %1 : tensor<8x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_without_outer_unit_dim
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK-DAG: %[[C0_i32:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[IDX0:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK: %[[IDX1:.*]] = vector.broadcast %[[CST_0]] : vector<8xindex> to vector<1x8xindex
+// CHECK: %[[IDX2:.*]] = vector.transpose %[[IDX1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[IDX3:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
+// CHECK: %[[IDX4:.*]] = vector.transpose %[[IDX2]], [1, 0] : vector<8x1xindex> to vector<1x8xindex>
+// CHECK: %[[IDX5:.*]] = vector.shape_cast %[[IDX4]] : vector<1x8xindex> to vector<8xindex>
+// CHECK: %[[IDX6:.*]] = vector.extractelement %[[IDX5]][%[[C0_i32]] : i32] : vector<8xindex>
+// CHECK: %[[IDX7:.*]] = vector.transfer_read %[[ARG0]][%[[IDX6]], %[[C0]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true]} : tensor<8x128x768xf32>, vector<8x1xf32>
+// CHECK: vector.transfer_write %[[IDX7]], %[[IDX0]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+
// -----
#map = affine_map<(d0) -> (d0)>
|
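To make the new code path above easier to follow, here is a minimal Python sketch (not the actual MLIR C++ code) of the permutation it builds: find the largest result dimension, and if it is not trailing, construct an identity permutation and swap the trailing position with it. This mirrors the `std::max_element`, `std::distance`, and `std::swap` calls in the Vectorization.cpp hunk; the function name is made up for illustration.

```python
def index_transposition(result_shape):
    """Return the vector.transpose permutation the patch would insert,
    or None when the largest dim is already trailing (no transpose needed)."""
    # First maximal element, like std::max_element.
    max_dim = max(range(len(result_shape)), key=lambda d: result_shape[d])
    if max_dim == len(result_shape) - 1:
        return None
    # Identity permutation with the trailing and max dims swapped,
    # like llvm::seq + std::swap.
    perm = list(range(len(result_shape)))
    perm[-1], perm[max_dim] = perm[max_dim], perm[-1]
    return perm

# For the vector<8x1xindex> case in the new test, the large dim (8) is
# non-trailing, so the inserted vector.transpose uses permutation [1, 0].
print(index_transposition([8, 1]))   # [1, 0]
print(index_transposition([1, 8]))   # None: already trailing
```

This is why the CHECK lines in the test show a `vector.transpose ... [1, 0]` immediately before the shape cast to a 1-D index vector.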
Hi, thanks for the fix! Could you please elaborate a bit more on "a way of handling it by adding an extra transpose op"? I'm not sure I can infer what that means. An IR example in the description would help.
Sounds good. It seems like this is creating wrong IR as is, but once we have a resolution with help from @banach-space, I can post in the PR description the IR we are generating for this case.
@banach-space could you please take another look at this when you have a chance?
This makes sense, thank you and sorry for the delay. A few more minor comments/suggestions.
@banach-space #100582 breaks the logic here, as one of the test cases is one where it will find two non-unit dims and hit the assert.
I adapted the logic to check for dynamic dims and return the trailing dim if no static, non-unit dim is found. That seems to work, although it feels ad hoc.
Apologies, I didn't expect that to be such a curve ball. But I am really glad that we are hitting these cases; it really helps to stress-test this logic. Hopefully my explanation makes sense. I've also deleted a couple of my comments, since #100582 invalidated them 😅
LGTM, thanks for fixing this and for bearing with me 🙏🏻
I will try to prepare a follow-on to clarify the differences between loading from statically and dynamically shaped tensors. That's two cases. There's also masked vectorisation for statically shaped tensors, so that's three cases in total 😅
Thank you so much for the help on this and for taking up support for those other cases. This back and forth gave me a lot of new insight into vectorization, so no complaints here :)
In #102321 we relaxed the vectorizer so that, when checking for contiguous loads, we don't always require a trailing non-unit dim. For example, the test case added here uses `tensor<8x1xf32>`, which is now a valid candidate for a contiguous load. However, the logic that checks for a contiguous load assumed that only the trailing dim would be non-unit, so this PR just updates that logic to find the actual non-unit dim.
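The relaxed check described above, combined with the dynamic-dim fallback mentioned earlier in the thread, can be sketched in Python as follows. This is a hypothetical illustration, not the actual MLIR helper: scan the shape for the single static non-unit dim, and fall back to the trailing dim when every dim is unit-sized or dynamic.

```python
DYNAMIC = -1  # stand-in for MLIR's ShapedType::kDynamic sentinel

def non_unit_dim(shape):
    """Return the index of the single static non-unit dim, or the
    trailing dim when no static non-unit dim exists."""
    non_unit = [d for d, s in enumerate(shape) if s != 1 and s != DYNAMIC]
    # The vectorizer's contiguity logic expects at most one such dim;
    # two static non-unit dims would hit the assert discussed above.
    assert len(non_unit) <= 1, "expected at most one static non-unit dim"
    return non_unit[0] if non_unit else len(shape) - 1

print(non_unit_dim([8, 1]))        # 0: the tensor<8x1xf32> case from the test
print(non_unit_dim([1, 4]))        # 1: non-unit dim happens to be trailing
print(non_unit_dim([1, DYNAMIC]))  # 1: all unit/dynamic, fall back to trailing
```

Under this sketch, `tensor<8x1xf32>` yields dim 0 rather than the trailing dim the old logic assumed, which is exactly the mismatch this PR fixes.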