Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][linalg] Vectorisation of tensor.extract - dynamic shapes #100582

Conversation

banach-space
Copy link
Contributor

This PR removes the assumption that reading from a dynamic tensor is
always a gather load:

%extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>

That assumption was originally introduced to simplify the implementation
and to reduce the number of cases to consider. Now that the
vectorisation of tensor.extract has been around for > 1 year and has
been quite stable, we can safely relax it.

This is a relatively small change - rather than using the parent linalg
Op to infer the target output shape (not possible with dynamic shapes),
the vectorizer will use the (previously constructed) output vector
shape instead.

As expected, the following test required updating (vector.gather ->
vector.transfer_read):

  • @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous

Similar test for scalable vectors is also added.

@llvmbot
Copy link
Collaborator

llvmbot commented Jul 25, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

This PR removes the assumption that reading from a dynamic tensor is
always a gather load:

%extracted = tensor.extract %src[%c79, %3] : tensor&lt;?x?xf32&gt;

That assumption was originally introduced to simplify the implementation
and to reduce the number of cases to consider. Now that the
vectorisation of tensor.extract has been around for > 1 year and has
been quite stable, we can safely relax it.

This is a relatively small change - rather than using the parent linalg
Op to infer the target output shape (not possible with dynamic shapes),
the vectorizer will use the (previously constructed) output vector
shape instead.

As expected, the following test required updating (vector.gather ->
vector.transfer_read):

  • @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous

Similar test for scalable vectors is also added.


Full diff: https://github.com/llvm/llvm-project/pull/100582.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+25-33)
  • (modified) mlir/test/Dialect/Linalg/vectorization-scalable.mlir (+3-6)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir (+95-34)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 9185663799e52..5c8d4a00bc35f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -808,14 +808,13 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 
-/// Checks whether /p val can be used for calculating a loop invariant index.
-static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
+/// Checks whether `val` can be used for calculating a loop invariant index.
+static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, VectorType resType) {
 
-  auto targetShape = linalgOp.getStaticLoopRanges();
-  assert(((llvm::count_if(targetShape,
+  assert(((llvm::count_if(resType.getShape(),
                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
          "n-D vectors are not yet supported");
-  assert(targetShape.back() != 1 &&
+  assert(resType.getShape().back() != 1 &&
          "1-D vectors with the trailing dim eqaual 1 are not yet supported");
 
   // Blocks outside _this_ linalg.generic are effectively loop invariant.
@@ -849,7 +848,7 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
 
   bool result = true;
   for (auto op : ancestor->getOperands())
-    result &= isLoopInvariantIdx(linalgOp, op);
+    result &= isLoopInvariantIdx(linalgOp, op, resType);
 
   return result;
 }
@@ -871,13 +870,12 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
 /// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
 /// updated to `true` when such an op is found.
 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
-                                bool &foundIndexOp) {
+                                bool &foundIndexOp, VectorType resType) {
 
-  auto targetShape = linalgOp.getStaticLoopRanges();
-  assert(((llvm::count_if(targetShape,
+  assert(((llvm::count_if(resType.getShape(),
                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
          "n-D vectors are not yet supported");
-  assert(targetShape.back() != 1 &&
+  assert(resType.getShape().back() != 1 &&
          "1-D vectors with the trailing dim 1 are not yet supported");
 
   // Blocks outside _this_ linalg.generic are effectively loop invariant.
@@ -912,46 +910,40 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
 
   bool result = false;
   for (auto op : ancestor->getOperands())
-    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
+    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
 
   return result;
 }
 
 /// Infer the memory access pattern for the input ExtractOp
 ///
-/// Based on the operation shapes and indices (usually based on the iteration
-/// space of the parent `linalgOp` operation), decides whether the input
-/// ExtractOp is a contiguous load (including a broadcast of a scalar) or a
-/// gather load.
+/// Based on the ExtratOp result shape and the access indices, decides whether
+/// this Op corresponds to a contiguous load (including a broadcast of a scalar)
+/// or a gather load. When analysing the ExtractOp indices (to identify
+/// contiguous laods), this method looks for "loop" invariant indices (e.g.
+/// block arguments) and indices that change linearly (e.g. via `linalg.index`
+/// Op).
 ///
 /// Note that it is always safe to use gather load operations for contiguous
 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
 /// that `extractOp` is a gather load.
 static VectorMemoryAccessKind
 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
-                                    LinalgOp &linalgOp) {
+                                    LinalgOp &linalgOp, VectorType resType) {
 
-  auto targetShape = linalgOp.getStaticLoopRanges();
   auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
 
-  // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
+  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
   if (inputShape.getShape().empty())
     return VectorMemoryAccessKind::ScalarBroadcast;
 
-  // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
-  // gather load.
-  // TODO: Relax this condition.
-  if (linalgOp.hasDynamicShape())
-    return VectorMemoryAccessKind::Gather;
-
   // 1. Assume that it's a gather load when reading _into_:
-  //    * an n-D "vector", like `tensor<1x2x4xi32` or `tensor<2x1x4xi32>`, or
-  //    * a 1-D "vector" with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
+  //    * an n-D "vector", like `vector<1x2x4xi32` or `vector<2x1x4xi32>`, or
+  //    * a 1-D "vector" with the trailing dim equal 1, e.g. `vector<1x4x1xi32>`.
   // TODO: Relax these conditions.
-  // FIXME: This condition assumes non-dynamic sizes.
-  if ((llvm::count_if(targetShape,
+  if ((llvm::count_if(resType.getShape(),
                       [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
-      targetShape.back() == 1)
+      resType.getShape().back() == 1)
     return VectorMemoryAccessKind::Gather;
 
   // 2. Assume that it's a gather load when reading _from_ a tensor for which
@@ -972,7 +964,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     if (inputShape.getShape()[i] == 1)
       continue;
 
-    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
+    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
   }
 
   if (!leadingIdxsLoopInvariant) {
@@ -989,7 +981,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   // 4a. Scalar broadcast load
   // If the trailing index is loop invariant then this is a scalar load.
   if (leadingIdxsLoopInvariant &&
-      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
+      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
     LDBG("Found scalar broadcast load: " << extractOp);
 
     return VectorMemoryAccessKind::ScalarBroadcast;
@@ -1001,7 +993,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   // This is what the following bool captures.
   bool foundIndexOp = false;
   bool isContiguousLoad =
-      isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
+      isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp, resType);
   isContiguousLoad &= foundIndexOp;
 
   if (isContiguousLoad) {
@@ -1042,7 +1034,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
       rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
   VectorMemoryAccessKind memAccessKind =
-      getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
+      getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
 
   // 1. Handle gather access
   if (memAccessKind == VectorMemoryAccessKind::Gather) {
diff --git a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
index 4ee3088cc3778..c3a30e3ee209e 100644
--- a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
@@ -162,17 +162,14 @@ func.func @vectorize_linalg_index(%arg0: tensor<3x3x?xf32>, %arg1: tensor<1x1x?x
 
 // CHECK-LABEL: @vectorize_linalg_index
 // CHECK-SAME: %[[SRC:.*]]: tensor<3x3x?xf32>, %[[DST:.*]]: tensor<1x1x?xf32>
-// CHECK-DAG:    %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x[4]xf32>
-// CHECK-DAG:        %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x[4]xi1>
 // CHECK-DAG:          %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:          %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:          %[[C2:.*]] = arith.constant 2 : index
 // CHECK:        %[[DST_DIM2:.*]] = tensor.dim %[[DST]], %[[C2]] : tensor<1x1x?xf32>
-// CHECK:        %[[DST_MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
+// CHECK:        %[[MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
 // CHECK:       %[[INDEX_VEC:.*]] = vector.step : vector<[4]xindex>
-// CHECK: %[[INDEX_VEC_BCAST:.*]] = vector.broadcast %[[INDEX_VEC]] : vector<[4]xindex> to vector<1x1x[4]xindex>
-// CHECK:          %[[GATHER:.*]] = vector.mask %[[DST_MASK]] { vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {{\[}}%[[INDEX_VEC_BCAST]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x?xf32>, vector<1x1x[4]xindex>, vector<1x1x[4]xi1>, vector<1x1x[4]xf32> into vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
-// CHECK:             %[[OUT:.*]] = vector.mask %[[DST_MASK]] { vector.transfer_write %[[GATHER]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
+// CHECK:            %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%c0, %c0, %2], %cst {in_bounds = [true, true, true]} : tensor<3x3x?xf32>, vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
+// CHECK:             %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
 // CHECK:           return %[[OUT]] : tensor<1x1x?xf32>
 
 module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
index 964565620fd01..31a754d934368 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -120,52 +120,54 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<?x?xf32>, %arg0: index, %extracted_slice : tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
+    %src: tensor<?x?xf32>,
+    %output : tensor<?x?xf32>,
+    %idx: index) -> tensor<?x?xf32> {
+
   %c79 = arith.constant 79 : index
   %1 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]
-  } outs(%extracted_slice : tensor<?x?xf32>) {
+  } outs(%output : tensor<?x?xf32>) {
   ^bb0(%out: f32):
     %2 = linalg.index 1 : index
-    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg0)
-    %extracted = tensor.extract %6[%c79, %3] : tensor<?x?xf32>
+    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+    %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
     linalg.yield %extracted : f32
   } -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 
 // CHECK-LABEL:   func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
-// CHECK-SAME:                                                                                       %[[VAL_0:.*]]: tensor<?x?xf32>,
-// CHECK-SAME:                                                                                       %[[VAL_1:.*]]: index,
-// CHECK-SAME:                                                                                       %[[VAL_2:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 79 : index
-// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_2]], %[[VAL_4]] : tensor<?x?xf32>
-// CHECK:           %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_7:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32>
-// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<1x4xi1>
-// CHECK:           %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK:           %[[VAL_12:.*]] = vector.step : vector<4xindex>
-// CHECK:           %[[VAL_13:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
-// CHECK:           %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : vector<4xindex>
-// CHECK-DAG:       %[[VAL_15:.*]] = arith.constant dense<true> : vector<1x4xi1>
-// CHECK-DAG:       %[[VAL_16:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
-// CHECK-DAG:       %[[VAL_17:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[VAL_18:.*]] = arith.constant dense<79> : vector<1x4xindex>
-// CHECK-DAG:       %[[VAL_19:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_20:.*]] = tensor.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?xf32>
-// CHECK:           %[[VAL_21:.*]] = vector.broadcast %[[VAL_20]] : index to vector<1x4xindex>
-// CHECK:           %[[VAL_22:.*]] = arith.muli %[[VAL_18]], %[[VAL_21]] : vector<1x4xindex>
-// CHECK:           %[[VAL_23:.*]] = vector.broadcast %[[VAL_14]] : vector<4xindex> to vector<1x4xindex>
-// CHECK:           %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_22]] : vector<1x4xindex>
-// CHECK:           %[[VAL_25:.*]] = vector.mask %[[VAL_10]] { vector.gather %[[VAL_0]]{{\[}}%[[VAL_17]], %[[VAL_17]]] {{\[}}%[[VAL_24]]], %[[VAL_15]], %[[VAL_16]] : tensor<?x?xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK:           %[[VAL_26:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_27:.*]] = vector.mask %[[VAL_10]] { vector.transfer_write %[[VAL_25]], %[[VAL_2]]{{\[}}%[[VAL_26]], %[[VAL_26]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
-// CHECK:           return %[[VAL_27]] : tensor<?x?xf32>
-// CHECK:         }
+// CHECK-SAME:      %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME:      %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME:      %[[IDX:.*]]: index)
+
+/// Create the mask
+// CHECK:           %[[C79:.*]] = arith.constant 79 : index
+// CHECK:           %[[DIM_0_IDX:.*]] = arith.constant 0 : index
+// CHECK:           %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
+// CHECK:           %[[DIM_1_IDX:.*]] = arith.constant 1 : index
+// CHECK:           %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
+// CHECK:           %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x4xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK:           vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+
+/// Caluclate the index vector
+// CHECK:           %[[STEP:.*]] = vector.step : vector<4xindex>
+// CHECK:           %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<4xindex>
+// CHECK:           %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
+// CHECK:           %[[C0:.*]] = arith.constant 0 : i32
+// CHECK:           %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
+
+/// Extract the starting point from the index vector
+// CHECK:           %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
+
+// Final read and write
+// CHECK:           %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+// CHECK:           %[[VAL_24:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -177,6 +179,65 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+    %src: tensor<?x?xf32>,
+    %output : tensor<?x?xf32>,
+    %idx: index) -> tensor<?x?xf32> {
+
+  %c79 = arith.constant 79 : index
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%output : tensor<?x?xf32>) {
+  ^bb0(%out: f32):
+    %2 = linalg.index 1 : index
+    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+    %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL:   func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+// CHECK-SAME:      %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME:      %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME:      %[[IDX:.*]]: index)
+
+/// Create the mask
+// CHECK:           %[[C79:.*]] = arith.constant 79 : index
+// CHECK:           %[[DIM_0_IDX:.*]] = arith.constant 0 : index
+// CHECK:           %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
+// CHECK:           %[[DIM_1_IDX:.*]] = arith.constant 1 : index
+// CHECK:           %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
+// CHECK:           %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x[4]xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK:           vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+
+/// Caluclate the index vector
+// CHECK:           %[[STEP:.*]] = vector.step : vector<[4]xindex>
+// CHECK:           %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<[4]xindex>
+// CHECK:           %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
+// CHECK:           %[[C0:.*]] = arith.constant 0 : i32
+// CHECK:           %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
+
+/// Extract the starting point from the index vector
+// CHECK:           %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
+
+// Final read and write
+// CHECK:           %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+// CHECK:           %[[VAL_24:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x[4]xf32>, tensor<?x?xf32> } : vector<1x[4]xi1> -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+     transform.structured.vectorize %0 vector_sizes [1, [4]] {vectorize_nd_extract} : !transform.any_op
+     transform.yield
+  }
+}
+
+// -----
+
 func.func @masked_vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16xf32>, %arg0: index, %extracted_slice : tensor<1x3xf32>) -> tensor<1x3xf32> {
   %c16 = arith.constant 16 : index
   %1 = linalg.generic {

Copy link

github-actions bot commented Jul 25, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@banach-space banach-space force-pushed the andrzej/tensor_extract_vectorisation_dynamic branch from aae7571 to ef8985a Compare July 25, 2024 16:22
@banach-space banach-space requested a review from MacDue July 26, 2024 15:52
@banach-space banach-space changed the title [mlir][linalg] Upgrade vectorisation of tensor.extract [mlir][linalg] Vectorisation of tensor.extract - dynamic shapes Aug 5, 2024
Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, LG!

This PR removes the assumption that reading from a dynamic tensor is
always a gather load:

```mlir
%extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
```

That assumption was originally introduced to simplify the implementation
and to reduce the number of cases to consider. Now that the
vectorisation of `tensor.extract` has been around for > 1 year and has
been quite stable, we can safely relax it.

This is a relatively small change - rather than using the parent linalg
Op to infer the target output shape (not possible with dynamic shapes),
the vectorizer will use the (previously constructed) output vector
shape instead.

As expected, the following test required updating (`vector.gather` ->
`vector.transfer_read`):
* @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous

Similar test for scalable vectors is also added.
@banach-space banach-space force-pushed the andrzej/tensor_extract_vectorisation_dynamic branch from ef8985a to ffb505a Compare September 19, 2024 18:31
@banach-space banach-space merged commit 315ba77 into llvm:main Sep 19, 2024
6 of 7 checks passed
@banach-space banach-space deleted the andrzej/tensor_extract_vectorisation_dynamic branch September 19, 2024 18:56
tmsri pushed a commit to tmsri/llvm-project that referenced this pull request Sep 19, 2024
…#100582)

This PR removes the assumption that reading from a dynamic tensor is
always a gather load:

```mlir
%extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
```

That assumption was originally introduced to simplify the implementation
and to reduce the number of cases to consider. Now that the
vectorisation of `tensor.extract` has been around for > 1 year and has
been quite stable, we can safely relax it.

This is a relatively small change - rather than using the parent linalg
Op to infer the target output shape (not possible with dynamic shapes),
the vectorizer will use the (previously constructed) output vector
shape instead.

As expected, the following test required updating (`vector.gather` ->
`vector.transfer_read`):
*
@masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous

Similar test for scalable vectors is also added.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants