Skip to content

Commit

Permalink
[mlir][vector] Add pattern to drop unit dims from vector.transpose (l…
Browse files Browse the repository at this point in the history
…lvm#102017)

Example:

BEFORE:
```mlir
%transpose = vector.transpose %vector, [3, 0, 1, 2]
  : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
```

AFTER:
```mlir
%dropDims = vector.shape_cast %vector
  : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
%transpose = vector.transpose %dropDims, [1, 0]
  : vector<4x[4]xf32> to vector<[4]x4xf32>
%restoreDims = vector.shape_cast %transpose
  : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
```
  • Loading branch information
MacDue authored Aug 8, 2024
1 parent 13d04fa commit da8778e
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 2 deletions.
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ inline auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
};
}

/// Returns a range over the dims (size and scalability) of a VectorType.
///
/// The range is produced by zipping `vType.getShape()` with
/// `vType.getScalableDims()`, so the i-th element pairs the size of dim `i`
/// with the flag saying whether that dim is scalable.
inline auto getDims(VectorType vType) {
// NOTE(review): both accessors are passed directly as temporaries; keep it
// that way unless the lifetime rules of `llvm::zip_equal` have been checked.
return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
}

/// A wrapper for getMixedSizes for vector.transfer_read and
/// vector.transfer_write Ops (for source and destination, respectively).
///
Expand Down
70 changes: 68 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1720,6 +1720,72 @@ struct DropUnitDimFromElementwiseOps final
}
};

/// A pattern to drop unit dims from vector.transpose.
///
/// Only non-scalable unit dims (i.e. `1`, not `[1]`) are dropped. The source
/// is shape_cast to a type without those dims, transposed with a permutation
/// that has been re-indexed accordingly, and finally shape_cast back to the
/// original result type.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %transpose = vector.transpose %vector, [3, 0, 1, 2]
///   : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %dropDims = vector.shape_cast %vector
///   : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
/// %transpose = vector.transpose %dropDims, [1, 0]
///   : vector<4x[4]xf32> to vector<[4]x4xf32>
/// %restoreDims = vector.shape_cast %transpose
///   : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
/// ```
struct DropUnitDimsFromTransposeOp final
    : OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  /// Rewrites `op` as shape_cast -> transpose -> shape_cast.
  ///
  /// Returns failure (without touching the IR) when the source type has no
  /// non-scalable unit dims to drop; returns success once `op` was replaced.
  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    VectorType sourceType = op.getSourceVectorType();
    VectorType sourceTypeWithoutUnitDims =
        dropNonScalableUnitDimFromType(sourceType);

    // Nothing to drop: bail out before modifying any IR.
    if (sourceType == sourceTypeWithoutUnitDims)
      return failure();

    // Construct a map from dimIdx -> number of dims dropped before dimIdx.
    auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
    SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
    int64_t droppedDims = 0;
    for (auto [i, dim] : llvm::enumerate(sourceDims)) {
      droppedDimsBefore[i] = droppedDims;
      // A (size == 1, scalable == false) dim is one that gets dropped.
      if (dim == std::make_tuple(1, false))
        ++droppedDims;
    }

    // Drop unit dims from transpose permutation. Every surviving index is
    // shifted down by the number of dims that were dropped in front of it, so
    // that it refers to the same dim in the reduced-rank source type.
    ArrayRef<int64_t> perm = op.getPermutation();
    SmallVector<int64_t> newPerm;
    for (int64_t idx : perm) {
      if (sourceDims[idx] == std::make_tuple(1, false))
        continue;
      newPerm.push_back(idx - droppedDimsBefore[idx]);
    }

    Location loc = op.getLoc();
    // Drop the unit dims via shape_cast.
    auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
        loc, sourceTypeWithoutUnitDims, op.getVector());
    // Create the new transpose.
    auto transposeWithoutUnitDims =
        rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
    // Restore the unit dims via shape cast.
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        op, op.getResultVectorType(), transposeWithoutUnitDims);

    // The IR was modified, so the rewrite must be reported as successful;
    // returning failure here would violate the pattern driver's contract.
    return success();
  }
};

/// Pattern to eliminate redundant zero-constants added to reduction operands.
/// It's enough for there to be one initial zero value, so we can eliminate the
/// extra ones that feed into `vector.reduction <add>`. These get created by the
Expand Down Expand Up @@ -1924,8 +1990,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,

/// Adds the patterns that drop unit dims by means of vector.shape_cast
/// (for elementwise ops and vector.transpose), together with the shape_cast
/// folder, to `patterns`. Every pattern is registered exactly once, with the
/// given `benefit`.
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
               ShapeCastOpFolder>(patterns.getContext(), benefit);
}

void mlir::vector::populateBubbleVectorBitCastOpPatterns(
Expand Down
44 changes: 44 additions & 0 deletions mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -700,3 +700,47 @@ func.func @negative_out_of_bound_transfer_write(
}
// CHECK: func.func @negative_out_of_bound_transfer_write
// CHECK-NOT: memref.collapse_shape

// -----

///----------------------------------------------------------------------------------------
/// [Pattern: DropUnitDimsFromTransposeOp]
/// TODO: Move to a dedicated file - there's no "flattening" in the following tests
///----------------------------------------------------------------------------------------

// Non-scalable unit dims in the middle of the permutation are dropped before
// the transpose and restored afterwards. SSA values are matched through
// FileCheck captures rather than hard-coded names so that the test does not
// depend on how the printer numbers values.
func.func @transpose_with_internal_unit_dims(%vec: vector<1x1x4x[4]xf32>) -> vector<[4]x1x1x4xf32> {
  %res = vector.transpose %vec, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
  return %res : vector<[4]x1x1x4xf32>
}

// CHECK-LABEL: func.func @transpose_with_internal_unit_dims(
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
// CHECK-NEXT: %[[DROP_DIMS:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[DROP_DIMS]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
// CHECK-NEXT: %[[RESTORE_DIMS:.*]] = vector.shape_cast %[[TRANSPOSE]] : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
// CHECK-NEXT: return %[[RESTORE_DIMS]] : vector<[4]x1x1x4xf32>

// -----

// Scalable unit dims (`[1]`) are not dropped; only the fixed-size unit dims
// are removed, so the reduced source type below still carries the `[1]` dim.
func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> vector<1x1x4x2x[1]xf32>
{
%res = vector.transpose %vec, [4, 1, 3, 2, 0] : vector<[1]x1x2x4x1xf32> to vector<1x1x4x2x[1]xf32>
return %res: vector<1x1x4x2x[1]xf32>
}

// CHECK-LABEL: func.func @transpose_with_scalable_unit_dims(
// CHECK-SAME: %[[VEC:.*]]: vector<[1]x1x2x4x1xf32>)
// CHECK-NEXT: %[[DROP_DIMS:.*]] = vector.shape_cast %[[VEC]] : vector<[1]x1x2x4x1xf32> to vector<[1]x2x4xf32>
// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[DROP_DIMS]], [2, 1, 0] : vector<[1]x2x4xf32> to vector<4x2x[1]xf32>
// CHECK-NEXT: %[[RESTORE_DIMS:.*]] = vector.shape_cast %[[TRANSPOSE]] : vector<4x2x[1]xf32> to vector<1x1x4x2x[1]xf32>
// CHECK-NEXT: return %[[RESTORE_DIMS]] : vector<1x1x4x2x[1]xf32>

// -----

// Negative test: the source type has no unit dims, so the transpose must be
// left alone and no shape casts may be introduced.
func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> {
%res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
return %res : vector<4x3x2xf32>
}

// CHECK-LABEL: func.func @negative_transpose_with_no_unit_dims
// CHECK-NOT: vector.shape_cast

0 comments on commit da8778e

Please sign in to comment.