-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][tensor] Add consumer fusion for tensor.pack
op.
#103715
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir-scf Author: None (Yun-Fly) ChangesAdd missing Full diff: https://github.com/llvm/llvm-project/pull/103715.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 361340a4e62f2d..51c232ae77fe6c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -246,6 +246,97 @@ struct PackOpTiling
return failure();
return tilingResult.value();
}
+
+ /// Method to return the position of iteration domain tile computed by the
+ /// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
+ /// `resultSizes` only cover outer dimensions.
+ LogicalResult getIterationDomainTileFromOperandTile(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) const {
+ auto packOp = cast<PackOp>(op);
+ Location loc = packOp.getLoc();
+
+ SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
+ DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+ packOp.getDimAndTileMapping();
+ for (auto dim : packOp.getOuterDimsPerm()) {
+ if (dimAndTileMapping.count(dim)) {
+ FailureOr<int64_t> cstSize =
+ ValueBoundsConstraintSet::computeConstantBound(
+ presburger::BoundType::UB, sizes[dim],
+ /*stopCondition=*/nullptr, /*closedUB=*/true);
+ std::optional<int64_t> cstInnerSize =
+ getConstantIntValue(dimAndTileMapping[dim]);
+ // Currently only expect perfect tiling cases.
+ if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
+ return failure();
+ }
+
+ using AV = affine::AffineValueExpr;
+ affine::AffineBuilder ab(b, loc);
+ AffineExpr dim0, sym;
+ bindDims(b.getContext(), dim0);
+ bindSymbols(b.getContext(), sym);
+ auto avOffset = AV(dim0).bind(offsets[dim]);
+ auto avSize = AV(dim0).bind(sizes[dim]);
+ auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
+ outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
+ outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
+ } else {
+ outerDimOffsets.push_back(offsets[dim]);
+ outerDimSizes.push_back(sizes[dim]);
+ }
+ }
+
+ resultOffsets = outerDimOffsets;
+ resultSizes = outerDimSizes;
+ return success();
+ }
+
+ /// Method to return the tiled implementation of tensor.pack as a consumer.
+ FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+ auto packOp = cast<PackOp>(op);
+ Location loc = packOp.getLoc();
+
+ int64_t inputRank = packOp.getSourceRank();
+ auto oneAttr = b.getI64IntegerAttr(1);
+ SmallVector<OpFoldResult> strides(inputRank, oneAttr);
+
+ SmallVector<Value> tiledOperands;
+ tiledOperands.push_back(b.create<ExtractSliceOp>(loc, packOp.getSource(),
+ offsets, sizes, strides));
+
+ SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
+ if (failed(getIterationDomainTileFromOperandTile(
+ op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
+ outerDimSizes)))
+ return failure();
+
+ SmallVector<OpFoldResult> outputOffsets, outputSizes;
+ if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
+ outputOffsets, outputSizes)))
+ return failure();
+
+ strides.append(packOp.getDestRank() - inputRank, oneAttr);
+ auto extractSlice = b.create<ExtractSliceOp>(
+ loc, packOp.getDest(), outputOffsets, outputSizes, strides);
+ tiledOperands.push_back(extractSlice);
+
+ if (auto val = packOp.getPaddingValue())
+ tiledOperands.push_back(val);
+ for (auto tile : packOp.getInnerTiles())
+ tiledOperands.push_back(tile);
+
+ Operation *tiledPackOp = b.create<PackOp>(
+ loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
+
+ return TilingResult{{tiledPackOp},
+ SmallVector<Value>(tiledPackOp->getResults())};
+ }
};
struct UnpackTileDimInfo {
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index 400b558e37fcda..741dfbfb1cd5c2 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -315,3 +315,62 @@ module attributes {transform.with_named_sequence} {
// CHECK: }
// CHECK: }
// CHECK: return %[[FINAL_RESULT]]#1 :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+ func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+ %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+ ^bb0(%in: f32, %in_16: f32, %out: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.addf %out, %13 : f32
+ linalg.yield %14 : f32
+ } -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+ }
+ }
+ %output = tensor.empty() : tensor<4x32x16xf32>
+ %pack = tensor.pack %1 outer_dims_perm = [0, 1] inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32>
+ return %pack : tensor<4x32x16xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %a, %b = transform.test.fuse_consumer %slice_op
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK: func.func @fuse_pack_consumer_into_scf_forall(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
+// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<4x32x16xf32>
+// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME: {
+// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
+// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
+// CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]])
+// CHECK: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
+// CHECK: %[[TILED_PACK_OUT:.*]] = tensor.pack %[[GENERIC_OUT]]
+// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0] inner_tiles = [16]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
+// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[FINAL_RESULT]]#1 :
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Few questions, but this looks mostly OK to me.
tensor.pack
as consumertensor.pack
op.
Hi, @MaheshRavishankar @hanhanW. Let me try to rephrase the corner case we may encounter:
If we do not consider tiling, the
Then, if we take tiling into consideration, i.e. the
I am seeking your advice about how to deal with this? Based on current coordination( BTW, this issue exists even if without padding, e.g.
where the |
Ok, so I was trying something. I initially thought that fusing producer with consumer, or consumer with producer would give the same result. So I tried tiling pack and fusing the producer. Input IR:
which on using
But to get this I need to specify tile sizes in terms of the result dimension of the pack. So that works for tile consumer + fuse producer, but if you tile the producer first (the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, this makes sense to me now. Thanks.
d7fe9c7
to
cdbf9ed
Compare
Yes, as I mentioned above, the mirroring case of fusing pack as a consumer is actually fusing unpack as a producer. @hanhanW contributed a lot on this and had more experience to deal with incomplete tile before. It would be perfect if you can help to amend the restrictions left in this patch with an unified interface to address incomplete tile issue caused by either |
Thanks @MaheshRavishankar for taking over the review! I did not find cycles to do deep review, but I can imagine the problem now. Hopefully I can get back to this and add the support later. And thanks a lot for pushing on this, @Yun-Fly ! |
Thanks for all of your involving in busy schedules! |
Add missing `getIterationDomainTileFromOperandTile` and `getTiledImplementationFromOperandTile` to `tensor.pack` and enable fusing it as a consumer. NOTE that, it only expects perfect tiling scenario without padding semantic currently.
Add missing `getIterationDomainTileFromOperandTile` and `getTiledImplementationFromOperandTile` to `tensor.pack` and enable fusing it as a consumer. NOTE that, it only expects perfect tiling scenario without padding semantic currently.
Add missing `getIterationDomainTileFromOperandTile` and `getTiledImplementationFromOperandTile` to `tensor.pack` and enable fusing it as a consumer. NOTE that, it only expects perfect tiling scenario without padding semantic currently.
Add missing `getIterationDomainTileFromOperandTile` and `getTiledImplementationFromOperandTile` to `tensor.pack` and enable fusing it as a consumer. NOTE that, it only expects perfect tiling scenario without padding semantic currently.
Add missing
getIterationDomainTileFromOperandTile
andgetTiledImplementationFromOperandTile
totensor.pack
and enable fusing it as a consumer.