Skip to content

Commit

Permalink
[mlir][ArmSME] Pattern to swap shape_cast(tranpose) with transpose(sh…
Browse files Browse the repository at this point in the history
…ape_cast) (#100731)

This applies when the shape_cast is simply for dropping unit dims, and
the result rank is >= 2.

This simplifies the transpose making it possible for other ArmSME
legalization patterns to handle it.

Example:

```mlir
%0 = vector.transpose %vector, [3, 0, 1, 2]
       : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
%1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
```

```mlir
%0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
%1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
```
  • Loading branch information
MacDue authored Jul 26, 2024
1 parent 49cb170 commit 88accd9
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 1 deletion.
91 changes: 90 additions & 1 deletion mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,94 @@ struct ConvertIllegalShapeCastOpsToTransposes
}
};

/// Returns an iterator over the dims (inc scalability) of a VectorType.
static auto getDims(VectorType vType) {
return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
}

/// Helper to drop (fixed-size) unit dims from a VectorType.
static VectorType dropUnitDims(VectorType vType) {
SmallVector<bool> scalableFlags;
SmallVector<int64_t> dimSizes;
for (auto dim : getDims(vType)) {
if (dim == std::make_tuple(1, false))
continue;
auto [size, scalableFlag] = dim;
dimSizes.push_back(size);
scalableFlags.push_back(scalableFlag);
}
return VectorType::get(dimSizes, vType.getElementType(), scalableFlags);
}

/// A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the
/// shape_cast only drops unit dimensions.
///
/// This simplifies the transpose making it possible for other legalization
/// rewrites to handle it.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %0 = vector.transpose %vector, [3, 0, 1, 2]
/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
/// %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
/// %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
/// ```
struct SwapShapeCastOfTranspose : public OpRewritePattern<vector::ShapeCastOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
auto transposeOp =
shapeCastOp.getSource().getDefiningOp<vector::TransposeOp>();
if (!transposeOp)
return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp");

auto resultType = shapeCastOp.getResultVectorType();
if (resultType.getRank() <= 1)
return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low");

if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType()))
return rewriter.notifyMatchFailure(
shapeCastOp, "ShapeCastOp changes non-unit dimension(s)");

auto transposeSourceVectorType = transposeOp.getSourceVectorType();
auto transposeSourceDims =
llvm::to_vector(getDims(transposeSourceVectorType));

// Construct a map from dimIdx -> number of dims dropped before dimIdx.
SmallVector<int64_t> droppedDimsBefore(transposeSourceVectorType.getRank());
int64_t droppedDims = 0;
for (auto [i, dim] : llvm::enumerate(transposeSourceDims)) {
droppedDimsBefore[i] = droppedDims;
if (dim == std::make_tuple(1, false))
++droppedDims;
}

// Drop unit dims from transpose permutation.
auto perm = transposeOp.getPermutation();
SmallVector<int64_t> newPerm;
for (int64_t idx : perm) {
if (transposeSourceDims[idx] == std::make_tuple(1, false))
continue;
newPerm.push_back(idx - droppedDimsBefore[idx]);
}

auto loc = shapeCastOp.getLoc();
auto newShapeCastOp = rewriter.create<vector::ShapeCastOp>(
loc, dropUnitDims(transposeSourceVectorType), transposeOp.getVector());
rewriter.replaceOpWithNewOp<vector::TransposeOp>(shapeCastOp,
newShapeCastOp, newPerm);
return success();
}
};

/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
/// the ZA state. This workaround rewrite to support these transposes when ZA is
/// available.
Expand Down Expand Up @@ -939,7 +1027,8 @@ struct VectorLegalizationPass
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
LiftIllegalVectorTransposeToMemory,
ConvertIllegalShapeCastOpsToTransposes,
LowerIllegalTransposeStoreViaZA>(context);
SwapShapeCastOfTranspose, LowerIllegalTransposeStoreViaZA>(
context);
// Note: These two patterns are added with a high benefit to ensure:
// - Masked outer products are handled before unmasked ones
// - Multi-tile writes are lowered as a store loop (if possible)
Expand Down
26 changes: 26 additions & 0 deletions mlir/test/Dialect/ArmSME/vector-legalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -646,3 +646,29 @@ func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vect
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>, memref<?x?xf32>
return
}

// -----

// CHECK-LABEL: @swap_shape_cast_of_transpose(
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
func.func @swap_shape_cast_of_transpose(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x4xf32> {
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
// CHECK: return %[[TRANSPOSE]]
%0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
%1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
return %1 : vector<[4]x4xf32>
}

// -----

// CHECK-LABEL: @swap_shape_cast_of_transpose_units_dims_before_and_after(
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
func.func @swap_shape_cast_of_transpose_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x4xf32> {
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
// CHECK: return %[[TRANSPOSE]]
%0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
%1 = vector.shape_cast %0 : vector<[4]x1x1x1x4x1xf32> to vector<[4]x4xf32>
return %1 : vector<[4]x4xf32>
}

0 comments on commit 88accd9

Please sign in to comment.