[Linalg][Vectorization] Add support for linalg vectorization of a tensor.extract case (#107922)

In #102321 we relaxed the
vectorizer so that, when checking for contiguous loads, we no longer
require a trailing non-unit dim. For example, the test case added here has
`tensor<8x1xf32>`, which is now a valid candidate for a contiguous load.
However, the logic that checks for contiguous loads assumed that only the
trailing dim could be non-unit, so this PR updates that logic to find
the actual non-unit dim.
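
To make the behavior change concrete, here is a minimal standalone C++ sketch of the new lookup. The `findNonUnitDim` helper is hypothetical and for illustration only; the actual change operates on `LinalgOp::getStaticLoopRanges()` as shown in the diff below.

#include <cassert>
#include <cstdint>
#include <vector>

// Scan the static loop trip counts for the single non-unit entry instead of
// assuming it is the trailing one.
static uint64_t findNonUnitDim(const std::vector<int64_t> &tripCounts) {
  uint64_t nonUnitDim = 0;
  uint64_t countNonUnitDim = 0;
  for (uint64_t i = 0; i < tripCounts.size(); ++i) {
    if (tripCounts[i] != 1) {
      nonUnitDim = i;
      ++countNonUnitDim;
    }
  }
  assert(countNonUnitDim <= 1 && "expected at most one non-unit loop dim");
  return nonUnitDim;
}

int main() {
  // tensor<8x1xf32> has loop ranges [8, 1]: the old trailing-dim assumption
  // would pick dim 1, but the actual non-unit dim is 0.
  assert(findNonUnitDim({8, 1}) == 0);
  // A trailing non-unit dim, e.g. ranges [1, 8], still resolves to dim 1.
  assert(findNonUnitDim({1, 8}) == 1);
  return 0;
}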
nirvedhmeshram committed Sep 21, 2024
1 parent c3f9b73 commit e45fc51
Showing 2 changed files with 79 additions and 4 deletions.
31 changes: 27 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -810,6 +810,28 @@ static Value calculateGatherOffset(RewriterBase &rewriter,

enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };

/// Find the non-unit dim in a linalgOp.
/// When executing this hook, it is expected that only one dim will be non-unit.
/// Other cases (i.e. reading n-D vectors) should've been labelled as gather
/// loads before calling this method. This is used for finding contiguous loads
/// (represented as `tensor.extract`) within `linalg.generic` Ops. Note that
/// this condition is expected to hold for statically shaped Linalg Ops only.
static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
  uint64_t nonUnitDim = 0;
  uint64_t countNonUnitDim = 0;
  for (auto tripCount : llvm::enumerate(linalgOp.getStaticLoopRanges())) {
    if (tripCount.value() != 1) {
      nonUnitDim = tripCount.index();
      countNonUnitDim++;
    }
  }

  assert((linalgOp.hasDynamicShape() || countNonUnitDim == 1) &&
         "For statically shaped Linalg Ops, only one "
         "non-unit loop dim is expected");
  return nonUnitDim;
}

/// Checks whether `val` can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
VectorType resType) {
@@ -889,11 +911,12 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
  Operation *defOp = val.getDefiningOp();
  assert(defOp && "This is neither a block argument nor an operation result");

-  // Given the assumption on the loop ranges above, only the trailing loop
-  // index is not constant.
-  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
+  // Given the assumption on the loop ranges above, we expect only 1 non-unit
+  // loop dim.
+  auto nonUnitLoopDim = getNonUnitLoopDim(linalgOp);

  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
-    foundIndexOp = (indexOp.getDim() == trailingLoopDim);
+    foundIndexOp = (indexOp.getDim() == nonUnitLoopDim);
    return true;
  }

52 changes: 52 additions & 0 deletions mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -253,6 +253,58 @@ module attributes {transform.with_named_sequence} {
    transform.yield
  }
}

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
  %c0 = arith.constant 0 : index
  %0 = tensor.empty() : tensor<8x1xf32>
  %1 = linalg.generic {
    indexing_maps = [#map],
    iterator_types = ["parallel", "parallel"]
  } outs(%0 : tensor<8x1xf32>) {
  ^bb0(%arg5: f32):
    %2 = linalg.index 0 : index
    %3 = linalg.index 1 : index
    %4 = affine.apply #map1(%arg1, %3, %arg1)
    %extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32>
    linalg.yield %extracted : f32
  } -> tensor<8x1xf32>
  return %1 : tensor<8x1xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load
// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<1xindex>
// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// CHECK: %[[B3:.*]] = vector.broadcast %[[B2]] : vector<1xindex> to vector<8x1xindex>
// CHECK: %[[ADDI:.*]] = arith.addi %[[B3]], %[[T]] : vector<8x1xindex>
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
// CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>

// -----

#map = affine_map<(d0) -> (d0)>
