-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][linalg] Vectorisation of tensor.extract - dynamic shapes #100582
[mlir][linalg] Vectorisation of tensor.extract - dynamic shapes #100582
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) ChangesThis PR removes the assumption that reading from a dynamic tensor is %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32> That assumption was originally introduced to simplify the implementation This is a relatively small change - rather than using the parent linalg As expected, the following test required updating (
Similar test for scalable vectors is also added. Full diff: https://github.com/llvm/llvm-project/pull/100582.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 9185663799e52..5c8d4a00bc35f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -808,14 +808,13 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
-/// Checks whether /p val can be used for calculating a loop invariant index.
-static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
+/// Checks whether `val` can be used for calculating a loop invariant index.
+static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, VectorType resType) {
- auto targetShape = linalgOp.getStaticLoopRanges();
- assert(((llvm::count_if(targetShape,
+ assert(((llvm::count_if(resType.getShape(),
[](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
"n-D vectors are not yet supported");
- assert(targetShape.back() != 1 &&
+ assert(resType.getShape().back() != 1 &&
"1-D vectors with the trailing dim eqaual 1 are not yet supported");
// Blocks outside _this_ linalg.generic are effectively loop invariant.
@@ -849,7 +848,7 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
bool result = true;
for (auto op : ancestor->getOperands())
- result &= isLoopInvariantIdx(linalgOp, op);
+ result &= isLoopInvariantIdx(linalgOp, op, resType);
return result;
}
@@ -871,13 +870,12 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
/// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
/// updated to `true` when such an op is found.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
- bool &foundIndexOp) {
+ bool &foundIndexOp, VectorType resType) {
- auto targetShape = linalgOp.getStaticLoopRanges();
- assert(((llvm::count_if(targetShape,
+ assert(((llvm::count_if(resType.getShape(),
[](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
"n-D vectors are not yet supported");
- assert(targetShape.back() != 1 &&
+ assert(resType.getShape().back() != 1 &&
"1-D vectors with the trailing dim 1 are not yet supported");
// Blocks outside _this_ linalg.generic are effectively loop invariant.
@@ -912,46 +910,40 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
bool result = false;
for (auto op : ancestor->getOperands())
- result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
+ result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
return result;
}
/// Infer the memory access pattern for the input ExtractOp
///
-/// Based on the operation shapes and indices (usually based on the iteration
-/// space of the parent `linalgOp` operation), decides whether the input
-/// ExtractOp is a contiguous load (including a broadcast of a scalar) or a
-/// gather load.
+/// Based on the ExtratOp result shape and the access indices, decides whether
+/// this Op corresponds to a contiguous load (including a broadcast of a scalar)
+/// or a gather load. When analysing the ExtractOp indices (to identify
+/// contiguous laods), this method looks for "loop" invariant indices (e.g.
+/// block arguments) and indices that change linearly (e.g. via `linalg.index`
+/// Op).
///
/// Note that it is always safe to use gather load operations for contiguous
/// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
/// that `extractOp` is a gather load.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
- LinalgOp &linalgOp) {
+ LinalgOp &linalgOp, VectorType resType) {
- auto targetShape = linalgOp.getStaticLoopRanges();
auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
- // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
+ // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
if (inputShape.getShape().empty())
return VectorMemoryAccessKind::ScalarBroadcast;
- // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
- // gather load.
- // TODO: Relax this condition.
- if (linalgOp.hasDynamicShape())
- return VectorMemoryAccessKind::Gather;
-
// 1. Assume that it's a gather load when reading _into_:
- // * an n-D "vector", like `tensor<1x2x4xi32` or `tensor<2x1x4xi32>`, or
- // * a 1-D "vector" with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
+ // * an n-D "vector", like `vector<1x2x4xi32` or `vector<2x1x4xi32>`, or
+ // * a 1-D "vector" with the trailing dim equal 1, e.g. `vector<1x4x1xi32>`.
// TODO: Relax these conditions.
- // FIXME: This condition assumes non-dynamic sizes.
- if ((llvm::count_if(targetShape,
+ if ((llvm::count_if(resType.getShape(),
[](int64_t dimSize) { return dimSize > 1; }) != 1) ||
- targetShape.back() == 1)
+ resType.getShape().back() == 1)
return VectorMemoryAccessKind::Gather;
// 2. Assume that it's a gather load when reading _from_ a tensor for which
@@ -972,7 +964,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
if (inputShape.getShape()[i] == 1)
continue;
- leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
+ leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
}
if (!leadingIdxsLoopInvariant) {
@@ -989,7 +981,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
// 4a. Scalar broadcast load
// If the trailing index is loop invariant then this is a scalar load.
if (leadingIdxsLoopInvariant &&
- isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
+ isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
LDBG("Found scalar broadcast load: " << extractOp);
return VectorMemoryAccessKind::ScalarBroadcast;
@@ -1001,7 +993,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
// This is what the following bool captures.
bool foundIndexOp = false;
bool isContiguousLoad =
- isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
+ isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp, resType);
isContiguousLoad &= foundIndexOp;
if (isContiguousLoad) {
@@ -1042,7 +1034,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
rewriter.create<arith::ConstantIndexOp>(loc, 0));
VectorMemoryAccessKind memAccessKind =
- getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
+ getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
// 1. Handle gather access
if (memAccessKind == VectorMemoryAccessKind::Gather) {
diff --git a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
index 4ee3088cc3778..c3a30e3ee209e 100644
--- a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
@@ -162,17 +162,14 @@ func.func @vectorize_linalg_index(%arg0: tensor<3x3x?xf32>, %arg1: tensor<1x1x?x
// CHECK-LABEL: @vectorize_linalg_index
// CHECK-SAME: %[[SRC:.*]]: tensor<3x3x?xf32>, %[[DST:.*]]: tensor<1x1x?xf32>
-// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x[4]xf32>
-// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x[4]xi1>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DST_DIM2:.*]] = tensor.dim %[[DST]], %[[C2]] : tensor<1x1x?xf32>
-// CHECK: %[[DST_MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
// CHECK: %[[INDEX_VEC:.*]] = vector.step : vector<[4]xindex>
-// CHECK: %[[INDEX_VEC_BCAST:.*]] = vector.broadcast %[[INDEX_VEC]] : vector<[4]xindex> to vector<1x1x[4]xindex>
-// CHECK: %[[GATHER:.*]] = vector.mask %[[DST_MASK]] { vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {{\[}}%[[INDEX_VEC_BCAST]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x?xf32>, vector<1x1x[4]xindex>, vector<1x1x[4]xi1>, vector<1x1x[4]xf32> into vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
-// CHECK: %[[OUT:.*]] = vector.mask %[[DST_MASK]] { vector.transfer_write %[[GATHER]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%c0, %c0, %2], %cst {in_bounds = [true, true, true]} : tensor<3x3x?xf32>, vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
+// CHECK: %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
// CHECK: return %[[OUT]] : tensor<1x1x?xf32>
module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
index 964565620fd01..31a754d934368 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -120,52 +120,54 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<?x?xf32>, %arg0: index, %extracted_slice : tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
+ %src: tensor<?x?xf32>,
+ %output : tensor<?x?xf32>,
+ %idx: index) -> tensor<?x?xf32> {
+
%c79 = arith.constant 79 : index
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
- } outs(%extracted_slice : tensor<?x?xf32>) {
+ } outs(%output : tensor<?x?xf32>) {
^bb0(%out: f32):
%2 = linalg.index 1 : index
- %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg0)
- %extracted = tensor.extract %6[%c79, %3] : tensor<?x?xf32>
+ %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+ %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
linalg.yield %extracted : f32
} -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: index,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 79 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_2]], %[[VAL_4]] : tensor<?x?xf32>
-// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32>
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<1x4xi1>
-// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_12:.*]] = vector.step : vector<4xindex>
-// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : vector<4xindex>
-// CHECK-DAG: %[[VAL_15:.*]] = arith.constant dense<true> : vector<1x4xi1>
-// CHECK-DAG: %[[VAL_16:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
-// CHECK-DAG: %[[VAL_17:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_18:.*]] = arith.constant dense<79> : vector<1x4xindex>
-// CHECK-DAG: %[[VAL_19:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_20:.*]] = tensor.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?xf32>
-// CHECK: %[[VAL_21:.*]] = vector.broadcast %[[VAL_20]] : index to vector<1x4xindex>
-// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_18]], %[[VAL_21]] : vector<1x4xindex>
-// CHECK: %[[VAL_23:.*]] = vector.broadcast %[[VAL_14]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_22]] : vector<1x4xindex>
-// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_10]] { vector.gather %[[VAL_0]]{{\[}}%[[VAL_17]], %[[VAL_17]]] {{\[}}%[[VAL_24]]], %[[VAL_15]], %[[VAL_16]] : tensor<?x?xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_26:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_27:.*]] = vector.mask %[[VAL_10]] { vector.transfer_write %[[VAL_25]], %[[VAL_2]]{{\[}}%[[VAL_26]], %[[VAL_26]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
-// CHECK: return %[[VAL_27]] : tensor<?x?xf32>
-// CHECK: }
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index)
+
+/// Create the mask
+// CHECK: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[DIM_0_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
+// CHECK: %[[DIM_1_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x4xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+
+/// Caluclate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<4xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<4xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
+
+// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+// CHECK: %[[VAL_24:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -177,6 +179,65 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+ %src: tensor<?x?xf32>,
+ %output : tensor<?x?xf32>,
+ %idx: index) -> tensor<?x?xf32> {
+
+ %c79 = arith.constant 79 : index
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } outs(%output : tensor<?x?xf32>) {
+ ^bb0(%out: f32):
+ %2 = linalg.index 1 : index
+ %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+ %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
+ linalg.yield %extracted : f32
+ } -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index)
+
+/// Create the mask
+// CHECK: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[DIM_0_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
+// CHECK: %[[DIM_1_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x[4]xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+
+/// Caluclate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<[4]xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<[4]xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
+
+// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+// CHECK: %[[VAL_24:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x[4]xf32>, tensor<?x?xf32> } : vector<1x[4]xi1> -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [1, [4]] {vectorize_nd_extract} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
func.func @masked_vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16xf32>, %arg0: index, %extracted_slice : tensor<1x3xf32>) -> tensor<1x3xf32> {
%c16 = arith.constant 16 : index
%1 = linalg.generic {
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
aae7571
to
ef8985a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, LG!
This PR removes the assumption that reading from a dynamic tensor is always a gather load: ```mlir %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32> ``` That assumption was originally introduced to simplify the implementation and to reduce the number of cases to consider. Now that the vectorisation of `tensor.extract` has been around for > 1 year and has been quite stable, we can safely relax it. This is a relatively small change - rather than using the parent linalg Op to infer the target output shape (not possible with dynamic shapes), the vectorizer will use the (previously constructed) output vector shape instead. As expected, the following test required updating (`vector.gather` -> `vector.transfer_read`): * @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous Similar test for scalable vectors is also added.
ef8985a
to
ffb505a
Compare
…#100582) This PR removes the assumption that reading from a dynamic tensor is always a gather load: ```mlir %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32> ``` That assumption was originally introduced to simplify the implementation and to reduce the number of cases to consider. Now that the vectorisation of `tensor.extract` has been around for > 1 year and has been quite stable, we can safely relax it. This is a relatively small change - rather than using the parent linalg Op to infer the target output shape (not possible with dynamic shapes), the vectorizer will use the (previously constructed) output vector shape instead. As expected, the following test required updating (`vector.gather` -> `vector.transfer_read`): * @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous Similar test for scalable vectors is also added.
This PR removes the assumption that reading from a dynamic tensor is
always a gather load:
That assumption was originally introduced to simplify the implementation
and to reduce the number of cases to consider. Now that the
vectorisation of
tensor.extract
has been around for > 1 year and hasbeen quite stable, we can safely relax it.
This is a relatively small change - rather than using the parent linalg
Op to infer the target output shape (not possible with dynamic shapes),
the vectorizer will use the (previously constructed) output vector
shape instead.
As expected, the following test required updating (
vector.gather
->vector.transfer_read
):Similar test for scalable vectors is also added.