[Linalg][Vectorization] Add support for linalg vectorization case with outer non unit dim
nirvedhmeshram committed Sep 18, 2024
1 parent bdf0224 commit 6c6ed79
Showing 2 changed files with 69 additions and 4 deletions.
21 changes: 17 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -810,6 +810,20 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 
+static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
+  uint64_t nonUnitDim = 0;
+  uint64_t countNonUnitDim = 0;
+  for (auto tripCount : llvm::enumerate(linalgOp.getStaticLoopRanges())) {
+    if (tripCount.value() != 1) {
+      nonUnitDim = tripCount.index();
+      countNonUnitDim++;
+    }
+  }
+  assert(countNonUnitDim == 1 &&
+         "Expected only one non unit loop dim in this linalg op");
+  return nonUnitDim;
+}
+
 /// Checks whether \p val can be used for calculating a loop invariant index.
 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
 
@@ -890,11 +904,10 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
   Operation *defOp = val.getDefiningOp();
   assert(defOp && "This is neither a block argument nor an operation result");
 
-  // Given the assumption on the loop ranges above, only the trailing loop
-  // index is not constant.
-  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
+  auto nonUnitLoopDim = getNonUnitLoopDim(linalgOp);
 
   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
-    foundIndexOp = (indexOp.getDim() == trailingLoopDim);
+    foundIndexOp = (indexOp.getDim() == nonUnitLoopDim);
     return true;
   }
 
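With the helper above, isContiguousLoadIdx can identify a contiguous linalg.index even when the single non-unit loop dimension is not the trailing one. The following is a minimal standalone sketch, not part of the commit: pickNonUnitLoopDim and the plain-integer loop ranges are illustrative stand-ins for the LinalgOp query, applied to the 8x1 iteration space of the new test below.

// Standalone C++ sketch (assumption: mirrors the helper's selection logic on
// plain integers rather than a LinalgOp). For an 8x1 iteration space the
// non-unit dim is 0 -- the *leading* dim, which the old `size() - 1`
// heuristic would have mislabelled.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

static uint64_t pickNonUnitLoopDim(const std::vector<int64_t> &loopRanges) {
  uint64_t nonUnitDim = 0;
  uint64_t countNonUnitDim = 0;
  for (uint64_t i = 0; i < loopRanges.size(); ++i) {
    if (loopRanges[i] != 1) {
      nonUnitDim = i;
      ++countNonUnitDim;
    }
  }
  assert(countNonUnitDim == 1 && "expected exactly one non-unit loop dim");
  return nonUnitDim;
}

int main() {
  std::cout << pickNonUnitLoopDim({8, 1}) << "\n"; // prints 0
  return 0;
}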
52 changes: 52 additions & 0 deletions mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -253,6 +253,58 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
+func.func @vectorize_nd_tensor_extract_without_outer_unit_dim(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<8x1xf32>
+  %1 = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%0 : tensor<8x1xf32>) {
+  ^bb0(%arg5: f32):
+    %2 = linalg.index 0 : index
+    %3 = linalg.index 1 : index
+    %4 = affine.apply #map1(%arg1, %3, %arg1)
+    %extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<8x1xf32>
+  return %1 : tensor<8x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_without_outer_unit_dim
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
+// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
+// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
+// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
+// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<1xindex>
+// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
+// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[B3:.*]] = vector.broadcast %[[B2]] : vector<1xindex> to vector<8x1xindex>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[B3]], %[[T]] : vector<8x1xindex>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+
 // -----
 
 #map = affine_map<(d0) -> (d0)>
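For reference, and not part of the commit: the gather indices checked above come from linearizing the extract indices [i, 0, %arg1 + 0 + %arg1] into the 8x128x768 source with row-major strides 128*768, 768, and 1, which is exactly what the muli-by-128, muli-by-768, and final addi chain computes. A small sketch, assuming an arbitrary example value for %arg1:

// Standalone C++ sketch (assumption: arg1 = 5 is an arbitrary example value
// for %arg1). Reproduces the linearized offsets that the vector.gather in the
// CHECK lines above reads from the flattened 8x128x768 tensor.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t arg1 = 5;
  for (int64_t i = 0; i < 8; ++i) {
    // Extract indices are [i, 0, arg1 + 0 + arg1]; row-major strides for
    // tensor<8x128x768xf32> are 128*768, 768, and 1.
    int64_t offset = i * 128 * 768 + 0 * 768 + (arg1 + 0 + arg1);
    std::cout << offset << " ";
  }
  std::cout << "\n";
  return 0;
}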
