[mlir][tensor] Add consumer fusion for tensor.pack op. (llvm#103715)
Add the missing `getIterationDomainTileFromOperandTile` and `getTiledImplementationFromOperandTile` implementations to `tensor.pack` and enable fusing it as a consumer. Note that this currently only supports the perfect-tiling scenario, without padding semantics.
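At its core, the new `getIterationDomainTileFromOperandTile` hook maps a tile of the pack op's source, given by `offsets` and `sizes`, to the outer-dimension tile of the packed layout. A minimal sketch of that mapping for one packed dimension (illustrative names, matching the floor/ceil affine expressions in the diff below):

    outerDimOffset = floor(offset / innerTileSize)
    outerDimSize   = ceil(size / innerTileSize)

For instance, with an inner tile of 16 and a source tile of size 32 at offset %iv, the result tile starts at outer index %iv floordiv 16 and spans ceil(32 / 16) = 2 outer rows, which is exactly the affine.apply and the leading slice size of 2 that appear in the test below.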
Yun-Fly authored and cjdb committed Aug 23, 2024
1 parent 4254207 commit 26aa37b
Showing 2 changed files with 173 additions and 0 deletions.
114 changes: 114 additions & 0 deletions mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -246,6 +246,120 @@ struct PackOpTiling
return failure();
return tilingResult.value();
}

/// Method to return the position of the iteration domain tile computed by
/// the tiled operation. In the current `tensor.pack` context, `resultOffsets`
/// and `resultSizes` only cover the outer dimensions.
LogicalResult getIterationDomainTileFromOperandTile(
Operation *op, OpBuilder &b, unsigned operandNumber,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
SmallVectorImpl<OpFoldResult> &resultOffsets,
SmallVectorImpl<OpFoldResult> &resultSizes) const {
if (operandNumber != 0)
return failure();

auto packOp = cast<PackOp>(op);
// It is not trivial to infer dest tile from source tile if `packOp` has
// padding semantic.
if (packOp.getPaddingValue())
return failure();

Location loc = packOp.getLoc();

SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
packOp.getDimAndTileMapping();
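// Walk the outer dimensions in permuted order; for packed dimensions,
// rescale the operand tile's offset and size by the inner tile size.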
for (auto dim : packOp.getOuterDimsPerm()) {
if (dimAndTileMapping.count(dim)) {
FailureOr<int64_t> cstSize =
ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType::UB, sizes[dim],
/*stopCondition=*/nullptr, /*closedUB=*/true);
std::optional<int64_t> cstInnerSize =
getConstantIntValue(dimAndTileMapping[dim]);
// Fusing `packOp` as a consumer currently only supports the perfect tiling
// scenario, because even without padding semantics `packOp` may still yield
// incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>, where the
// `tileSize` from the operand of `packOp` is 5, which is not divisible by
// the `innerTile` (=6) of `packOp`. As a result:
// 1. the first slice is extracted from (0) to (4) and inserted into
// (0,0)~(0,4) in the first row.
// 2. the second slice is extracted from (5) to (9) and would have to be
// inserted into two rows of different lengths: (0,5) in the first row and
// (1,0)~(1,3) in the second row. It is hard to coordinate them, so the
// constraint below bypasses such cases for now. In other words, at the
// moment we can only support tiling with a consumer if the tile size of the
// producer is a multiple of the inner tile size for the packed dimensions.
if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
return failure();
}

using AV = affine::AffineValueExpr;
affine::AffineBuilder ab(b, loc);
AffineExpr dim0, sym;
bindDims(b.getContext(), dim0);
bindSymbols(b.getContext(), sym);
auto avOffset = AV(dim0).bind(offsets[dim]);
auto avSize = AV(dim0).bind(sizes[dim]);
auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
} else {
outerDimOffsets.push_back(offsets[dim]);
outerDimSizes.push_back(sizes[dim]);
}
}

resultOffsets = outerDimOffsets;
resultSizes = outerDimSizes;
return success();
}

/// Method to return the tiled implementation of tensor.pack as a consumer.
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
Operation *op, OpBuilder &b, unsigned operandNumber,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
if (operandNumber != 0)
return failure();

auto packOp = cast<PackOp>(op);
Location loc = packOp.getLoc();

int64_t inputRank = packOp.getSourceRank();
auto oneAttr = b.getI64IntegerAttr(1);
SmallVector<OpFoldResult> strides(inputRank, oneAttr);

SmallVector<Value> tiledOperands;
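// Slice the pack source to the operand tile provided by the caller.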
tiledOperands.push_back(b.create<ExtractSliceOp>(loc, packOp.getSource(),
offsets, sizes, strides));

SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
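// Recover the outer-dimension (iteration domain) tile from the operand tile.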
if (failed(getIterationDomainTileFromOperandTile(
op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
outerDimSizes)))
return failure();

SmallVector<OpFoldResult> outputOffsets, outputSizes;
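// Map that iteration-domain tile to the position of the corresponding
// result tile.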
if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
outputOffsets, outputSizes)))
return failure();

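// Slice the dest to the result tile; the inner (point-loop) dimensions get
// unit strides.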
strides.append(packOp.getDestRank() - inputRank, oneAttr);
auto extractSlice = b.create<ExtractSliceOp>(
loc, packOp.getDest(), outputOffsets, outputSizes, strides);
tiledOperands.push_back(extractSlice);

assert(!packOp.getPaddingValue() && "Expect no padding semantic");
for (auto tile : packOp.getInnerTiles())
tiledOperands.push_back(tile);

Operation *tiledPackOp = b.create<PackOp>(
loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());

return TilingResult{{tiledPackOp},
SmallVector<Value>(tiledPackOp->getResults())};
}
};

struct UnpackTileDimInfo {
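To make the perfect-tiling restriction concrete, the following minimal IR sketch (hypothetical operands, mirroring the tensor<30xf32> example in the comment above) would be rejected when fused as a consumer of a producer tiled with size 5 along d0, because 5 is not a multiple of the inner tile size 6:

    %pack = tensor.pack %src inner_dims_pos = [0] inner_tiles = [6]
        into %dest : tensor<30xf32> -> tensor<5x6xf32>

The `*cstSize % *cstInnerSize != 0` check catches this case and makes `getIterationDomainTileFromOperandTile` return failure, so the fusion simply does not apply.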
59 changes: 59 additions & 0 deletions mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -315,3 +315,62 @@ module attributes {transform.with_named_sequence} {
// CHECK: }
// CHECK: }
// CHECK: return %[[FINAL_RESULT]]#1 :

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
%c4 = arith.constant 4 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
%extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
%3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
^bb0(%in: f32, %in_16: f32, %out: f32):
%13 = arith.mulf %in, %in_16 : f32
%14 = arith.addf %out, %13 : f32
linalg.yield %14 : f32
} -> tensor<32x32xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
}
}
%output = tensor.empty() : tensor<4x32x16xf32>
%pack = tensor.pack %1 outer_dims_perm = [0, 1] inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32>
return %pack : tensor<4x32x16xf32>
}
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
: (!transform.any_op) -> !transform.any_op
%a, %b = transform.test.fuse_consumer %slice_op
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
// CHECK: func.func @fuse_pack_consumer_into_scf_forall(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<4x32x16xf32>
// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
// CHECK-SAME: {
// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
// CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]])
// CHECK: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
// CHECK: %[[TILED_PACK_OUT:.*]] = tensor.pack %[[GENERIC_OUT]]
// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0] inner_tiles = [16]
// CHECK-SAME: into %[[TILED_PACK_DEST]]
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: }
// CHECK: }
// CHECK: return %[[FINAL_RESULT]]#1 :
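The RUN line of the test file sits outside this hunk. As a hedged note, tests in this directory are typically exercised through the transform interpreter, along the lines of the following invocation (the exact flag set is an assumption, not shown in this diff):

    mlir-opt --transform-interpreter --split-input-file tile-and-fuse-consumer.mlir | FileCheck tile-and-fuse-consumer.mlir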
